34 Commits

bottler
62a2031dd4 Revert "Fix CUDA kernel index data type in vision/fair/pytorch3d/pytorch3d/csrc/compositing/alpha_composite.cu +10"
This reverts commit 3987612062.
2025-03-27 05:28:03 -07:00
Richard Barnes
3987612062 Fix CUDA kernel index data type in vision/fair/pytorch3d/pytorch3d/csrc/compositing/alpha_composite.cu +10
Summary:
CUDA kernel variables matching the pattern `(thread|block|grid).(Idx|Dim).(x|y|z)` [have the data type `uint`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#built-in-variables).

Many programmers mistakenly use implicit casts to turn these data types into `int`. In fact, the [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/) itself is inconsistent and incorrect in its use of data types in programming examples.

The result of these implicit casts is that our kernels may give unexpected results when exposed to large datasets, i.e., those exceeding roughly 2 billion items.
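
For illustration (not part of the original diff), a minimal Python sketch of the failure mode: an index that fits comfortably in 64 bits wraps negative when squeezed through a signed 32-bit `int`. The launch numbers here are hypothetical.

```python
import ctypes

# A flattened CUDA index is typically blockIdx.x * blockDim.x + threadIdx.x.
# blockIdx/blockDim/threadIdx are uint; an implicit cast to a signed 32-bit
# int wraps once the index passes 2**31 - 1 (the ~2 billion threshold above).
block_idx, block_dim, thread_idx = 8_388_608, 256, 0  # hypothetical launch

idx_64 = block_idx * block_dim + thread_idx  # correct 64-bit arithmetic
idx_32 = ctypes.c_int32(idx_64).value        # what the implicit cast yields

print(idx_64)  # 2147483648 -- one past INT_MAX
print(idx_32)  # -2147483648 -- wrapped negative: an out-of-bounds index
```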

While we now have linters in place to prevent simple mistakes (D71236150), our codebase has many problematic instances. This diff fixes some of them.

Reviewed By: dtolnay

Differential Revision: D71355356

fbshipit-source-id: cea44891416d9efd2f466d6c45df4e36008fa036
2025-03-19 13:21:43 -07:00
Alexandros Benetatos
06a76ef8dd Correct "fast" matrix_to_axis_angle near pi (#1953)
Summary:
A continuation of https://github.com/facebookresearch/pytorch3d/issues/1948 -- this commit fixes a small numerical issue with `matrix_to_axis_angle(..., fast=True)` near `pi`.
bottler, feel free to check this out; it's a single-line change.
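
For anyone wanting to sanity-check the fix, a round-trip sketch (assuming a pytorch3d build that includes the `fast` flag introduced in #1948):

```python
import math

import torch
from pytorch3d.transforms import axis_angle_to_matrix, matrix_to_axis_angle

# Build a rotation by an angle just below pi around the z axis.
axis_angle = torch.tensor([[0.0, 0.0, 1.0]]) * (math.pi - 1e-4)
R = axis_angle_to_matrix(axis_angle)

# Round-trip through the fast path; before this fix, results this close
# to pi could drift from the input.
recovered = matrix_to_axis_angle(R, fast=True)
print(torch.allclose(recovered, axis_angle, atol=1e-3))
```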

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1953

Reviewed By: MichaelRamamonjisoa

Differential Revision: D70088251

Pulled By: bottler

fbshipit-source-id: 54cc7f946283db700cec2cd5575cf918456b7f32
2025-03-11 12:25:59 -07:00
Richard Barnes
21205730d9 Fix unused-variable issues, mostly relating to AMD/HIP
Reviewed By: meyering

Differential Revision: D70845538

fbshipit-source-id: 8e52b5e1f1d96b86404fc3b8cbc6fb952e2cb1a6
2025-03-08 13:03:17 -08:00
Richard Barnes
7e09505538 Enable -Wunused-value in vision/PACKAGE +1
Summary:
This diff enables compilation warning flags for the directory in question. Further details are in [this workplace post](https://fb.workplace.com/permalink.php?story_fbid=pfbid02XaWNiCVk69r1ghfvDVpujB8Hr9Y61uDvNakxiZFa2jwiPHscVdEQwCBHrmWZSyMRl&id=100051201402394).

This is a low-risk diff. There are **no run-time effects** and the diff has already been observed to compile locally. **If the code compiles, it works; test errors are spurious.**

Differential Revision: D70282347

fbshipit-source-id: e2fa55c002d7124b13450c812165d244b8a53f4e
2025-03-04 17:49:30 -08:00
Nicholas Ormrod
20bd8b33f6 facebook-unused-include-check in fbcode/vision
Summary:
Remove headers flagged by facebook-unused-include-check over fbcode.vision.

+ format and autodeps

This is a codemod. It was automatically generated and will be landed once it is approved and tests are passing in sandcastle.
You have been added as a reviewer by Sentinel or Butterfly.

Autodiff project: uiv
Autodiff partition: fbcode.vision
Autodiff bookmark: ad.uiv.fbcode.vision

Reviewed By: dtolnay

Differential Revision: D70403619

fbshipit-source-id: d109c15774eeb3d809875f75fa2a26ed20d7f9a6
2025-02-28 18:08:12 -08:00
alex-bene
7a3c0cbc9d Increase performance for conversions including axis angles (#1948)
Summary:
This is an extension of https://github.com/facebookresearch/pytorch3d/issues/1544 with various speed, stability, and readability improvements. (I could not find a way to make a commit to the existing PR.) This PR is still based on [Rodrigues' rotation formula](https://en.wikipedia.org/wiki/Rotation_formalisms_in_three_dimensions#Rotation_matrix_%E2%86%94_Euler_axis/angle).

The motivation is the same; this change speeds up the conversions by up to 10x, depending on the device, batch size, etc.
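
For reference, a minimal sketch of the underlying formula (not the PR's optimized implementation): for a rotation vector with angle t and unit axis k, Rodrigues' formula gives R = I + sin(t) K + (1 - cos(t)) K^2, where K is the cross-product matrix of k.

```python
import torch

def axis_angle_to_matrix_rodrigues(axis_angle: torch.Tensor) -> torch.Tensor:
    """Rodrigues' formula for a batch of (..., 3) rotation vectors."""
    angle = axis_angle.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    x, y, z = (axis_angle / angle).unbind(-1)
    zero = torch.zeros_like(x)
    # K is the cross-product (skew-symmetric) matrix of the unit axis.
    K = torch.stack(
        [zero, -z, y, z, zero, -x, -y, x, zero], dim=-1
    ).reshape(axis_angle.shape[:-1] + (3, 3))
    angle = angle.unsqueeze(-1)
    eye = torch.eye(3, dtype=axis_angle.dtype, device=axis_angle.device)
    return eye + torch.sin(angle) * K + (1 - torch.cos(angle)) * (K @ K)
```

A closed form like this vectorizes cleanly over large batches, which is presumably where the reported speedup comes from.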

### Notes
- As the angles get very close to `π`, the existing implementation and the proposed one start to differ. However, (my understanding is that) this is not a problem, as the axis cannot be stably inferred from the rotation matrix in this case in general.
- bottler, I tried to follow conventions similar to existing functions to deal with weird angles; let me know if something needs to be changed to merge this.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1948

Reviewed By: MichaelRamamonjisoa

Differential Revision: D69193009

Pulled By: bottler

fbshipit-source-id: e5ed34b45b625114ec4419bb89e22a6aefad4eeb
2025-02-07 07:37:42 -08:00
Roman Shapovalov
215590b497 In FrameDataBuilder, set all path even if we don’t load blobs
Summary:
This is a somewhat backward-incompatible change: some None paths will be replaced by metadata paths, even when they were not used for data loading.

Moreover, this removes the legacy fix to the paths in the old CO3D release.

Reviewed By: bottler

Differential Revision: D69048238

fbshipit-source-id: 2a8b26d7b9f5e2adf39c65888b5863a5a9de1996
2025-02-06 09:41:44 -08:00
Antoine Toisoul
43cd681d4f Updates to Implicitron dataset, metrics and tools
Summary: Update Pytorch3D to be able to run assetgen (see later diffs in the stack)

Reviewed By: shapovalov

Differential Revision: D65942513

fbshipit-source-id: 1d01141c9f7e106608fa591be6e0d3262cb5944f
2025-01-27 09:43:42 -08:00
Roman Shapovalov
42a4a7d432 Generalising SqlIndexDataset to support subtypes of SqlSequenceAnnotation
Summary: We rarely needed to extend sequence-level metadata before, but for applications like text-to-3D/video we now need to store captions and similar fields.

Reviewed By: bottler

Differential Revision: D68269926

fbshipit-source-id: f8af308adce51863d719a335d85cd2558943bd4c
2025-01-20 03:39:06 -08:00
generatedunixname89002005307016
699bc671ca Add missing Pyre mode headers
Differential Revision: D68316763

fbshipit-source-id: fb3e1e1a17786f6f681f1b11b48b4efd7a8ac311
2025-01-17 12:41:56 -08:00
Roman Shapovalov
49cf5a0f37 Loading fg probability from the alpha channel of image_rgb
Summary:
It is often easier to store the mask together with RGB, especially for renders. The logic in this diff (see the sketch after this list):
* if load_mask and mask_path are provided, take the mask from mask_path,
* otherwise, check whether the image has an alpha channel and take that as the mask.
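
A rough sketch of that priority order (illustrative only; the function and argument names here are hypothetical, not the actual FrameDataBuilder code):

```python
import numpy as np
import torch
from PIL import Image

def load_fg_probability(image_path, mask_path=None, load_mask=True):
    # Prefer an explicit mask file when one is provided.
    if load_mask and mask_path is not None:
        mask = Image.open(mask_path).convert("L")
        return torch.from_numpy(np.array(mask)).float() / 255.0
    # Otherwise fall back to the alpha channel of an RGBA image.
    image = Image.open(image_path)
    if image.mode == "RGBA":
        alpha = image.getchannel("A")
        return torch.from_numpy(np.array(alpha)).float() / 255.0
    return None
```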

Reviewed By: antoinetlc

Differential Revision: D68160212

fbshipit-source-id: d9b6779f90027a4987ba96800983f441edff9c74
2025-01-15 11:53:30 -08:00
Roman Shapovalov
89b851e64c Refactor a utility function for bbox conversion
Summary: This function makes it easier to extend FrameData class with new channels; brushing it up a bit.

Reviewed By: bottler

Differential Revision: D67816470

fbshipit-source-id: 6575415c864d0f539e283889760cd2331bf226a7
2025-01-06 04:17:57 -08:00
Roman Shapovalov
5247f6ad74 Fixing type hints in FrameData
Summary: As the title says.

Reviewed By: bottler

Differential Revision: D67791200

fbshipit-source-id: c2db01c94718102618f4c8bc5c5130c65ee1d81f
2025-01-06 04:17:57 -08:00
Roman Shapovalov
e41aff47db Adding default values to FrameData for internal usage
Summary: Ensuring all fields in FrameData have defaults.

Reviewed By: bottler

Differential Revision: D67762780

fbshipit-source-id: b680d29a1a11689850905978df544cdb4eb7ddcd
2025-01-06 04:17:57 -08:00
Roman Shapovalov
64a5bfadc8 Adding SQL Dataset related files to the build script
Summary: Now that we have SQLAlchemy 2.0, we can fully use these files.

Reviewed By: bottler

Differential Revision: D66920096

fbshipit-source-id: 25c0ea1c4f7361e66348035519627dc961b9e6e6
2024-12-23 16:05:26 -08:00
Thomas Polasek
055ab3a2e3 Convert directory fbcode/vision to use the Ruff Formatter
Summary:
Converts the directory specified to use the Ruff formatter in pyfmt

ruff_dog

If this diff causes merge conflicts when rebasing, please run
`hg status -n -0 --change . -I '**/*.{py,pyi}' | xargs -0 arc pyfmt`
on your diff, and amend any changes before rebasing onto latest.
That should help reduce or eliminate any merge conflicts.

allow-large-files

Reviewed By: bottler

Differential Revision: D66472063

fbshipit-source-id: 35841cb397e4f8e066e2159550d2f56b403b1bef
2024-11-26 02:38:20 -08:00
Edward Yang
f6c2ca6bfc Prepare for "Fix type-safety of torch.nn.Module instances": wave 2
Summary: See D52890934

Reviewed By: malfet, r-barnes

Differential Revision: D66245100

fbshipit-source-id: 019058106ac7eaacf29c1c55912922ea55894d23
2024-11-21 11:08:51 -08:00
Jeremy Reizenstein
e20cbe9b0e test fixes and lints
Summary:
- follow-up to recent pyre change D63415925
- make tests remove temporary files
- weights_only=True in torch.load (see the sketch after this list)
- lint fixes
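
For context, the `weights_only` change looks like this in user code (a generic sketch with a hypothetical path, not the exact diff):

```python
import torch

# weights_only=True restricts unpickling to tensors and other allowlisted
# types, so loading an untrusted checkpoint cannot execute arbitrary code.
checkpoint = torch.load("checkpoint.pth", weights_only=True)

# Old behaviour (weights_only=False) runs arbitrary pickle bytecode:
# checkpoint = torch.load("checkpoint.pth")  # unsafe on untrusted files
```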

3 test fixes from VRehnberg in https://github.com/facebookresearch/pytorch3d/issues/1914
- imageio channels fix
- frozen decorator in test_config
- load_blobs positional

Reviewed By: MichaelRamamonjisoa

Differential Revision: D66162167

fbshipit-source-id: 7737e174691b62f1708443a4fae07343cec5bfeb
2024-11-20 09:15:51 -08:00
Jeremy Reizenstein
c17e6f947a run CI tests on main
Reviewed By: MichaelRamamonjisoa

Differential Revision: D66162168

fbshipit-source-id: 90268c1925fa9439b876df143035c9d3c3a74632
2024-11-20 05:06:52 -08:00
Yann Noutary
91c9f34137 Add safeguard in case num_tris diverges
Summary:
This PR adds a safeguard preventing `num_tris` from overflowing the `MAX_TRIS`-length arrays. The update rule of `num_tris` is bounded:

 - max(num_tris(t)) = 2 * num_tris(t-1)
 - num_tris(0) = 12
 - t <= 6

So:
 - max(num_tris) = 2^6 * 12 = 768

Reviewed By: bottler

Differential Revision: D66162573

fbshipit-source-id: e269a79c75c6cc33306986b1f1256cffbe96c730
2024-11-20 01:24:28 -08:00
Jeremy Reizenstein
81d82980bc Fix ogl test hang
Summary: See https://github.com/facebookresearch/pytorch3d/issues/1908

Reviewed By: MichaelRamamonjisoa

Differential Revision: D65280253

fbshipit-source-id: ec05902c5f2f7eb9ddd92bda0045cc3564b8c091
2024-11-06 11:40:42 -08:00
Jeremy Reizenstein
8fe6934885 fix subdivide_meshes with empty mesh #1788
Summary:
Simplify code

fixes https://github.com/facebookresearch/pytorch3d/issues/1788

Reviewed By: MichaelRamamonjisoa

Differential Revision: D61847675

fbshipit-source-id: 48400875d1d885bb3615bc9f4b3c7c3d822b67e7
2024-11-06 11:40:26 -08:00
bottler
c434957b2a Run tests in github action (#1896)
Summary: Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1896

Reviewed By: MichaelRamamonjisoa

Differential Revision: D65272512

Pulled By: bottler

fbshipit-source-id: 3bcfab43acd2d6be5444ff25178381510ddac015
2024-11-06 11:15:34 -08:00
Jeremy Reizenstein
dd2a11b5fc Fix OFF for new numpy errors
Summary: Error messages have changed around numpy version 2, making existing code fail.

Reviewed By: MichaelRamamonjisoa

Differential Revision: D65280674

fbshipit-source-id: b3ae613ea8f0f4ae20fb6e5e816314b8c10e6c65
2024-11-06 11:13:59 -08:00
Richard Barnes
9563ef79ca c10::optional -> std::optional in some files
Reviewed By: jermenkoo

Differential Revision: D65425234

fbshipit-source-id: 1e7707d6b6aab640cc1fdd3bd71a3b50f77a0909
2024-11-04 12:03:51 -08:00
generatedunixname89002005287564
008c7ab58c Pre-silence Pyre errors for upcoming upgrade
Reviewed By: MaggieMoss

Differential Revision: D65290095

fbshipit-source-id: ced87d096aa8939700de5599ce6984cd7ae93912
2024-10-31 16:26:25 -07:00
Jeremy Reizenstein
9eaed4c495 Fix K>1 in multimap UV sampling
Summary:
Fixes https://github.com/facebookresearch/pytorch3d/issues/1897
"Wrong dimension on gather".

Reviewed By: cijose

Differential Revision: D65280675

fbshipit-source-id: 1d587036887972bb2a2ea56d40df19cbf1aeb6cc
2024-10-31 16:05:10 -07:00
Richard Barnes
e13848265d at::optional -> std::optional (#1170)
Summary: Pull Request resolved: https://github.com/pytorch/ao/pull/1170

Reviewed By: gineshidalgo99

Differential Revision: D64938040

fbshipit-source-id: 57f98b90676ad0164a6975ea50e4414fd85ae6c4
2024-10-25 06:37:57 -07:00
generatedunixname89002005307016
58566963d6 Add type error suppressions for upcoming upgrade
Reviewed By: MaggieMoss

Differential Revision: D64502797

fbshipit-source-id: cee9a54dfa8a005d5912b895d0bd094f352c5c6f
2024-10-16 19:22:01 -07:00
Suresh Babu Kolla
e17ed5cd50 Hipify Pulsar for PyTorch3D
Summary:
- Hipified PyTorch Pulsar
  - Created separate target for Pulsar tests and enabled RE testing
  - PyTorch3D full test suite requires additional work, like fixing EGL dependencies on AMD

Reviewed By: danzimm

Differential Revision: D61339912

fbshipit-source-id: 0d10bc966e4de4a959f3834a386bad24e449dc1f
2024-10-09 14:38:42 -07:00
Richard Barnes
8ed0c7a002 c10::optional -> std::optional
Summary: `c10::optional` is an alias for `std::optional`. Let's remove the alias and use the real thing.

Reviewed By: meyering

Differential Revision: D63402341

fbshipit-source-id: 241383e7ca4b2f3f1f9cac3af083056123dfd02b
2024-10-03 14:38:37 -07:00
Richard Barnes
2da913c7e6 c10::optional -> std::optional
Summary: `c10::optional` is an alias for `std::optional`. Let's remove the alias and use the real thing.

Reviewed By: palmje

Differential Revision: D63409387

fbshipit-source-id: fb6db59a14db9e897e2e6b6ad378f33bf2af86e8
2024-10-02 11:09:29 -07:00
generatedunixname89002005307016
fca83e6369 Convert .pyre_configuration.local to fast-by-default architecture
Reviewed By: connernilsen

Differential Revision: D63415925

fbshipit-source-id: c3e28405c70f9edcf8c21457ac4faf7315b07322
2024-09-25 17:34:03 -07:00
173 changed files with 993 additions and 642 deletions

View File

@@ -88,7 +88,6 @@ def workflow_pair(
     upload=False,
     filter_branch,
 ):
-
     w = []
     py = python_version.replace(".", "")
     pyt = pytorch_version.replace(".", "")
@@ -127,7 +126,6 @@ def generate_base_workflow(
     btype,
     filter_branch=None,
 ):
-
     d = {
         "name": base_workflow_name,
         "python_version": python_version,

.github/workflows/build.yml (new file)
View File

@@ -0,0 +1,23 @@
+name: facebookresearch/pytorch3d/build_and_test
+on:
+  pull_request:
+    branches:
+      - main
+  push:
+    branches:
+      - main
+jobs:
+  binary_linux_conda_cuda:
+    runs-on: 4-core-ubuntu-gpu-t4
+    env:
+      PYTHON_VERSION: "3.12"
+      BUILD_VERSION: "${{ github.run_number }}"
+      PYTORCH_VERSION: "2.4.1"
+      CU_VERSION: "cu121"
+      JUST_TESTRUN: 1
+    steps:
+      - uses: actions/checkout@v4
+      - name: Build and run tests
+        run: |-
+          conda create --name env --yes --quiet conda-build
+          conda run --no-capture-output --name env python3 ./packaging/build_conda.py --use-conda-cuda

View File

@@ -36,5 +36,5 @@ then
   echo "Running pyre..."
   echo "To restart/kill pyre server, run 'pyre restart' or 'pyre kill' in fbcode/"
-  ( cd ~/fbsource/fbcode; pyre -l vision/fair/pytorch3d/ )
+  ( cd ~/fbsource/fbcode; arc pyre check //vision/fair/pytorch3d/... )
 fi

View File

@@ -10,6 +10,7 @@ This example demonstrates the most trivial, direct interface of the pulsar
 sphere renderer. It renders and saves an image with 10 random spheres.
 Output: basic.png.
 """
+
 import logging
 import math
 from os import path

View File

@@ -11,6 +11,7 @@ interface for sphere renderering. It renders and saves an image with
 10 random spheres.
 Output: basic-pt3d.png.
 """
+
 import logging
 from os import path

View File

@@ -14,6 +14,7 @@ distorted. Gradient-based optimization is used to converge towards the
 original camera parameters.
 Output: cam.gif.
 """
+
 import logging
 import math
 from os import path

View File

@@ -14,6 +14,7 @@ distorted. Gradient-based optimization is used to converge towards the
 original camera parameters.
 Output: cam-pt3d.gif
 """
+
 import logging
 from os import path

View File

@@ -18,6 +18,7 @@ This example is not available yet through the 'unified' interface,
 because opacity support has not landed in PyTorch3D for general data
 structures yet.
 """
+
 import logging
 import math
 from os import path

View File

@@ -13,6 +13,7 @@ The scene is initialized with random spheres. Gradient-based
 optimization is used to converge towards a faithful
 scene representation.
 """
+
 import logging
 import math

View File

@@ -13,6 +13,7 @@ The scene is initialized with random spheres. Gradient-based
 optimization is used to converge towards a faithful
 scene representation.
 """
+
 import logging
 import math

View File

@@ -4,10 +4,11 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import argparse
 import os.path
 import runpy
 import subprocess
-from typing import List
+from typing import List, Tuple

 # required env vars:
 # CU_VERSION: E.g. cu112
@@ -23,7 +24,7 @@ pytorch_major_minor = tuple(int(i) for i in PYTORCH_VERSION.split(".")[:2])
 source_root_dir = os.environ["PWD"]

-def version_constraint(version):
+def version_constraint(version) -> str:
     """
     Given version "11.3" returns " >=11.3,<11.4"
     """
@@ -32,7 +33,7 @@ def version_constraint(version):
     return f" >={version},<{upper}"

-def get_cuda_major_minor():
+def get_cuda_major_minor() -> Tuple[str, str]:
     if CU_VERSION == "cpu":
         raise ValueError("fn only for cuda builds")
     if len(CU_VERSION) != 5 or CU_VERSION[:2] != "cu":
@@ -42,11 +43,10 @@ def get_cuda_major_minor():
     return major, minor

-def setup_cuda():
+def setup_cuda(use_conda_cuda: bool) -> List[str]:
     if CU_VERSION == "cpu":
-        return
+        return []
     major, minor = get_cuda_major_minor()
-    os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
     os.environ["FORCE_CUDA"] = "1"
     basic_nvcc_flags = (
@@ -75,6 +75,15 @@ def setup_cuda():
     if os.environ.get("JUST_TESTRUN", "0") != "1":
         os.environ["NVCC_FLAGS"] = nvcc_flags
+    if use_conda_cuda:
+        os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1"] = "- cuda-toolkit"
+        os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2"] = (
+            f"- cuda-version={major}.{minor}"
+        )
+        return ["-c", f"nvidia/label/cuda-{major}.{minor}.0"]
+    else:
+        os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
+        return []

 def setup_conda_pytorch_constraint() -> List[str]:
@@ -95,7 +104,7 @@ def setup_conda_pytorch_constraint() -> List[str]:
     return ["-c", "pytorch", "-c", "nvidia"]

-def setup_conda_cudatoolkit_constraint():
+def setup_conda_cudatoolkit_constraint() -> None:
     if CU_VERSION == "cpu":
         os.environ["CONDA_CPUONLY_FEATURE"] = "- cpuonly"
         os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = ""
@@ -116,7 +125,7 @@ def setup_conda_cudatoolkit_constraint():
     os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = toolkit

-def do_build(start_args: List[str]):
+def do_build(start_args: List[str]) -> None:
     args = start_args.copy()
     test_flag = os.environ.get("TEST_FLAG")
@@ -132,8 +141,16 @@ def do_build(start_args: List[str]):
 if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Build the conda package.")
+    parser.add_argument(
+        "--use-conda-cuda",
+        action="store_true",
+        help="get cuda from conda ignoring local cuda",
+    )
+    our_args = parser.parse_args()
     args = ["conda", "build"]
-    setup_cuda()
+    args += setup_cuda(use_conda_cuda=our_args.use_conda_cuda)
     init_path = source_root_dir + "/pytorch3d/__init__.py"
     build_version = runpy.run_path(init_path)["__version__"]

View File

@@ -8,10 +8,13 @@ source:
 requirements:
   build:
     - {{ compiler('c') }} # [win]
+    {{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1', '') }}
+    {{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2', '') }}
     {{ environ.get('CONDA_CUB_CONSTRAINT') }}
   host:
     - python
+    - mkl =2023 # [x86_64]
     {{ environ.get('SETUPTOOLS_CONSTRAINT') }}
     {{ environ.get('CONDA_PYTORCH_BUILD_CONSTRAINT') }}
     {{ environ.get('CONDA_PYTORCH_MKL_CONSTRAINT') }}
@@ -22,6 +25,7 @@ requirements:
     - python
     - numpy >=1.11
     - torchvision >=0.5
+    - mkl =2023 # [x86_64]
     - iopath
     {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
     {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}
@@ -47,8 +51,11 @@ test:
     - imageio
     - hydra-core
     - accelerate
+    - matplotlib
+    - tabulate
+    - pandas
+    - sqlalchemy
   commands:
-    #pytest .
     python -m unittest discover -v -s tests -t .

View File

@@ -7,7 +7,7 @@
 # pyre-unsafe

-""""
+""" "
 This file is the entry point for launching experiments with Implicitron.

 Launch Training
@@ -44,6 +44,7 @@ The outputs of the experiment are saved and logged in multiple ways:
     config file.
 """
+
 import logging
 import os
 import warnings

View File

@@ -26,7 +26,6 @@ logger = logging.getLogger(__name__)
 class ModelFactoryBase(ReplaceableBase):
     resume: bool = True  # resume from the last checkpoint
-
     def __call__(self, **kwargs) -> ImplicitronModelBase:
@@ -116,7 +115,9 @@ class ImplicitronModelFactory(ModelFactoryBase):
             "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
         }
         model_state_dict = torch.load(
-            model_io.get_model_path(model_path), map_location=map_location
+            model_io.get_model_path(model_path),
+            map_location=map_location,
+            weights_only=True,
         )
         try:

View File

@@ -123,6 +123,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         """
         # Get the parameters to optimize
         if hasattr(model, "_get_param_groups"):  # use the model function
+            # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
            p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
         else:
             p_groups = [
@@ -241,7 +242,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             map_location = {
                 "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
             }
-            optimizer_state = torch.load(opt_path, map_location)
+            optimizer_state = torch.load(opt_path, map_location, weights_only=True)
         else:
             raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
         return optimizer_state

View File

@@ -161,7 +161,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
         for epoch in range(start_epoch, self.max_epochs):
             # automatic new_epoch and plotting of stats at every epoch start
             with stats:
-
                 # Make sure to re-seed random generators to ensure reproducibility
                 # even after restart.
                 seed_all_random_engines(seed + epoch)
@@ -395,6 +394,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
         ):
             prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
             if hasattr(model, "visualize"):
+                # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
                 model.visualize(
                     viz,
                     visdom_env_imgs,

View File

@@ -53,12 +53,8 @@ class TestExperiment(unittest.TestCase):
         cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
             "JsonIndexDatasetMapProvider"
         )
-        dataset_args = (
-            cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
-        )
-        dataloader_args = (
-            cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
-        )
+        dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
+        dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
         dataset_args.category = "skateboard"
         dataset_args.test_restrict_sequence_id = 0
         dataset_args.dataset_root = "manifold://co3d/tree/extracted"
@@ -94,12 +90,8 @@ class TestExperiment(unittest.TestCase):
         cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
             "JsonIndexDatasetMapProvider"
         )
-        dataset_args = (
-            cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
-        )
-        dataloader_args = (
-            cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
-        )
+        dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
+        dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
         dataset_args.category = "skateboard"
         dataset_args.test_restrict_sequence_id = 0
         dataset_args.dataset_root = "manifold://co3d/tree/extracted"
@@ -111,9 +103,7 @@ class TestExperiment(unittest.TestCase):
         cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
         cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
         cfg.optimizer_factory_ImplicitronOptimizerFactory_args.lr_policy = "Exponential"
-        cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = (
-            2
-        )
+        cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = 2
         if DEBUG:
             experiment.dump_cfg(cfg)

View File

@@ -81,8 +81,9 @@ class TestOptimizerFactory(unittest.TestCase):
     def test_param_overrides_self_param_group_assignment(self):
         pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
-        na, nb = Node(params=[pa]), Node(
-            params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}
+        na, nb = (
+            Node(params=[pa]),
+            Node(params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}),
         )
         root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
         param_groups = self._get_param_groups(root)

View File

@@ -84,9 +84,9 @@ def get_nerf_datasets(
     if autodownload and any(not os.path.isfile(p) for p in (cameras_path, image_path)):
         # Automatically download the data files if missing.
-        download_data((dataset_name,), data_root=data_root)
+        download_data([dataset_name], data_root=data_root)

-    train_data = torch.load(cameras_path)
+    train_data = torch.load(cameras_path, weights_only=True)
     n_cameras = train_data["cameras"]["R"].shape[0]

     _image_max_image_pixels = Image.MAX_IMAGE_PIXELS

View File

@@ -194,7 +194,6 @@ class Stats:
         it = self.it[stat_set]
-
         for stat in self.log_vars:
             if stat not in self.stats[stat_set]:
                 self.stats[stat_set][stat] = AverageMeter()

View File

@@ -24,7 +24,6 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"
 @hydra.main(config_path=CONFIG_DIR, config_name="lego")
 def main(cfg: DictConfig):
     # Device on which to run.
-
     if torch.cuda.is_available():
         device = "cuda"
@@ -63,7 +62,7 @@ def main(cfg: DictConfig):
         raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")

     print(f"Loading checkpoint {checkpoint_path}.")
-    loaded_data = torch.load(checkpoint_path)
+    loaded_data = torch.load(checkpoint_path, weights_only=True)
     # Do not load the cached xy grid.
     # - this allows setting an arbitrary evaluation image size.
     state_dict = {

View File

@@ -42,7 +42,6 @@ class TestRaysampler(unittest.TestCase):
         cameras, rays = [], []

         for _ in range(batch_size):
             R = random_rotations(1)
-
             T = torch.randn(1, 3)
             focal_length = torch.rand(1, 2) + 0.5

View File

@@ -25,7 +25,6 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"
 @hydra.main(config_path=CONFIG_DIR, config_name="lego")
 def main(cfg: DictConfig):
     # Set the relevant seeds for reproducibility.
-
     np.random.seed(cfg.seed)
     torch.manual_seed(cfg.seed)
@@ -77,7 +76,7 @@ def main(cfg: DictConfig):
     # Resume training if requested.
     if cfg.resume and os.path.isfile(checkpoint_path):
         print(f"Resuming from checkpoint {checkpoint_path}.")
-        loaded_data = torch.load(checkpoint_path)
+        loaded_data = torch.load(checkpoint_path, weights_only=True)
         model.load_state_dict(loaded_data["model"])
         stats = pickle.loads(loaded_data["stats"])
         print(f" => resuming from epoch {stats.epoch}.")
@@ -219,7 +218,6 @@ def main(cfg: DictConfig):
         # Validation
         if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
             # Sample a validation camera/image.
-
             val_batch = next(val_dataloader.__iter__())
             val_image, val_camera, camera_idx = val_batch[0].values()

View File

@@ -17,7 +17,7 @@ Some functions which depend on PyTorch or Python versions.
 def meshgrid_ij(
-    *A: Union[torch.Tensor, Sequence[torch.Tensor]]
+    *A: Union[torch.Tensor, Sequence[torch.Tensor]],
 ) -> Tuple[torch.Tensor, ...]:  # pragma: no cover
     """
     Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij

View File

@@ -7,7 +7,6 @@
  */

 #include <torch/extension.h>
-#include <queue>
 #include <tuple>

 std::tuple<at::Tensor, at::Tensor> BallQueryCpu(

View File

@@ -28,7 +28,6 @@ __global__ void alphaCompositeCudaForwardKernel(
     const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
     const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
   // clang-format on
-  const int64_t batch_size = result.size(0);
   const int64_t C = features.size(0);
   const int64_t H = points_idx.size(2);
   const int64_t W = points_idx.size(3);
@@ -79,7 +78,6 @@ __global__ void alphaCompositeCudaBackwardKernel(
     const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
     const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
   // clang-format on
-  const int64_t batch_size = points_idx.size(0);
   const int64_t C = features.size(0);
   const int64_t H = points_idx.size(2);
   const int64_t W = points_idx.size(3);

View File

@@ -28,7 +28,6 @@ __global__ void weightedSumNormCudaForwardKernel(
     const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
     const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
   // clang-format on
-  const int64_t batch_size = result.size(0);
   const int64_t C = features.size(0);
   const int64_t H = points_idx.size(2);
   const int64_t W = points_idx.size(3);
@@ -92,7 +91,6 @@ __global__ void weightedSumNormCudaBackwardKernel(
     const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
     const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
   // clang-format on
-  const int64_t batch_size = points_idx.size(0);
   const int64_t C = features.size(0);
   const int64_t H = points_idx.size(2);
   const int64_t W = points_idx.size(3);

View File

@@ -26,7 +26,6 @@ __global__ void weightedSumCudaForwardKernel(
     const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
     const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
   // clang-format on
-  const int64_t batch_size = result.size(0);
   const int64_t C = features.size(0);
   const int64_t H = points_idx.size(2);
   const int64_t W = points_idx.size(3);
@@ -74,7 +73,6 @@ __global__ void weightedSumCudaBackwardKernel(
     const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
     const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
   // clang-format on
-  const int64_t batch_size = points_idx.size(0);
   const int64_t C = features.size(0);
   const int64_t H = points_idx.size(2);
   const int64_t W = points_idx.size(3);

View File

@@ -7,15 +7,11 @@
  */

 // clang-format off
-#if !defined(USE_ROCM)
 #include "./pulsar/global.h" // Include before <torch/extension.h>.
-#endif
 #include <torch/extension.h>
 // clang-format on
-#if !defined(USE_ROCM)
 #include "./pulsar/pytorch/renderer.h"
 #include "./pulsar/pytorch/tensor_util.h"
-#endif
 #include "ball_query/ball_query.h"
 #include "blending/sigmoid_alpha_blend.h"
 #include "compositing/alpha_composite.h"
@@ -104,7 +100,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // Pulsar.
   // Pulsar not enabled on AMD.
-#if !defined(USE_ROCM)
 #ifdef PULSAR_LOGGING_ENABLED
   c10::ShowLogInfoToStderr();
 #endif
@@ -154,10 +149,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("gamma"),
           py::arg("max_depth"),
           py::arg("min_depth") /* = 0.f*/,
-          py::arg(
-              "bg_col") /* = at::nullopt not exposed properly in pytorch 1.1. */
+          py::arg("bg_col") /* = std::nullopt not exposed properly in
+                               pytorch 1.1. */
           ,
-          py::arg("opacity") /* = at::nullopt ... */,
+          py::arg("opacity") /* = std::nullopt ... */,
           py::arg("percent_allowed_difference") = 0.01f,
          py::arg("max_n_hits") = MAX_UINT,
          py::arg("mode") = 0)
@@ -189,5 +184,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.attr("MAX_UINT") = py::int_(MAX_UINT);
   m.attr("MAX_USHORT") = py::int_(MAX_USHORT);
   m.attr("PULSAR_MAX_GRAD_SPHERES") = py::int_(MAX_GRAD_SPHERES);
-#endif
 }

View File

@@ -7,10 +7,7 @@
  */

 #include <torch/extension.h>
-#include <torch/torch.h>
 #include <list>
-#include <numeric>
-#include <queue>
 #include <tuple>

 #include "iou_box3d/iou_utils.h"

View File

@@ -461,10 +461,8 @@ __device__ inline std::tuple<float3, float3> ArgMaxVerts(
 __device__ inline bool IsCoplanarTriTri(
     const FaceVerts& tri1,
     const FaceVerts& tri2) {
-  const float3 tri1_ctr = FaceCenter({tri1.v0, tri1.v1, tri1.v2});
   const float3 tri1_n = FaceNormal({tri1.v0, tri1.v1, tri1.v2});

-  const float3 tri2_ctr = FaceCenter({tri2.v0, tri2.v1, tri2.v2});
   const float3 tri2_n = FaceNormal({tri2.v0, tri2.v1, tri2.v2});

   // Check if parallel
@@ -500,7 +498,6 @@ __device__ inline bool IsCoplanarTriPlane(
     const FaceVerts& tri,
     const FaceVerts& plane,
     const float3& normal) {
-  const float3 tri_ctr = FaceCenter({tri.v0, tri.v1, tri.v2});
   const float3 nt = FaceNormal({tri.v0, tri.v1, tri.v2});

   // check if parallel
@@ -728,7 +725,7 @@ __device__ inline int BoxIntersections(
     }
   }
   // Update the face_verts_out tris
-  num_tris = offset;
+  num_tris = min(MAX_TRIS, offset);
   for (int j = 0; j < num_tris; ++j) {
     face_verts_out[j] = tri_verts_updated[j];
   }

View File

@@ -8,9 +8,7 @@
 #include <torch/csrc/autograd/VariableTypeUtils.h>
 #include <torch/extension.h>
-#include <algorithm>
 #include <cmath>
-#include <thread>
 #include <vector>

 // In the x direction, the location {0, ..., grid_size_x - 1} correspond to

View File

@@ -36,11 +36,13 @@
 #pragma nv_diag_suppress 2951
 #pragma nv_diag_suppress 2967
 #else
+#if !defined(USE_ROCM)
 #pragma diag_suppress = attribute_not_allowed
 #pragma diag_suppress = 1866
 #pragma diag_suppress = 2941
 #pragma diag_suppress = 2951
 #pragma diag_suppress = 2967
+#endif //! USE_ROCM
 #endif
 #else // __CUDACC__
 #define INLINE inline
@@ -56,7 +58,9 @@
 #pragma clang diagnostic pop
 #ifdef WITH_CUDA
 #include <ATen/cuda/CUDAContext.h>
+#if !defined(USE_ROCM)
 #include <vector_functions.h>
+#endif //! USE_ROCM
 #else
 #ifndef cudaStream_t
 typedef void* cudaStream_t;

View File

@@ -59,6 +59,11 @@ getLastCudaError(const char* errorMessage, const char* file, const int line) {
 #define SHARED __shared__
 #define ACTIVEMASK() __activemask()
 #define BALLOT(mask, val) __ballot_sync((mask), val)
+/* TODO (ROCM-6.2): None of the WARP_* are used anywhere and ROCM-6.2 natively
+ * supports __shfl_*. Disabling until the move to ROCM-6.2.
+ */
+#if !defined(USE_ROCM)
 /**
  * Find the cumulative sum within a warp up to the current
  * thread lane, with each mask thread contributing base.
@@ -115,6 +120,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
   ret.z = WARP_SUM(group, mask, base.z);
   return ret;
 }
+#endif //! USE_ROCM

 // Floating point.
 // #define FMUL(a, b) __fmul_rn((a), (b))
@@ -142,6 +148,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
 #define FMA(x, y, z) __fmaf_rn((x), (y), (z))
 #define I2F(a) __int2float_rn(a)
 #define FRCP(x) __frcp_rn(x)
+#if !defined(USE_ROCM)
 __device__ static float atomicMax(float* address, float val) {
   int* address_as_i = (int*)address;
   int old = *address_as_i, assumed;
@@ -166,6 +173,7 @@ __device__ static float atomicMin(float* address, float val) {
   } while (assumed != old);
   return __int_as_float(old);
 }
+#endif //! USE_ROCM
 #define DMAX(a, b) FMAX(a, b)
 #define DMIN(a, b) FMIN(a, b)
 #define DSQRT(a) sqrt(a)

View File

@@ -14,7 +14,7 @@
 #include "./commands.h"

 namespace pulsar {
-IHD CamGradInfo::CamGradInfo() {
+IHD CamGradInfo::CamGradInfo(int x) {
   cam_pos = make_float3(0.f, 0.f, 0.f);
   pixel_0_0_center = make_float3(0.f, 0.f, 0.f);
   pixel_dir_x = make_float3(0.f, 0.f, 0.f);

View File

@@ -63,7 +63,7 @@ inline bool operator==(const CamInfo& a, const CamInfo& b) {
 };

 struct CamGradInfo {
-  HOST DEVICE CamGradInfo();
+  HOST DEVICE CamGradInfo(int = 0);
   float3 cam_pos;
   float3 pixel_0_0_center;
   float3 pixel_dir_x;

View File

@@ -24,7 +24,7 @@
 // #pragma diag_suppress = 68
 #include <ATen/cuda/CUDAContext.h>
 // #pragma pop
-#include "../cuda/commands.h"
+#include "../gpu/commands.h"
 #else
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Weverything"

View File

@@ -46,6 +46,7 @@ IHD float3 outer_product_sum(const float3& a) {
 }

 // TODO: put intrinsics here.
+#if !defined(USE_ROCM)
 IHD float3 operator+(const float3& a, const float3& b) {
   return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
 }
@@ -93,6 +94,7 @@ IHD float3 operator*(const float3& a, const float3& b) {
 IHD float3 operator*(const float& a, const float3& b) {
   return b * a;
 }
+#endif //! USE_ROCM

 INLINE DEVICE float length(const float3& v) {
   // TODO: benchmark what's faster.

View File

@@ -283,9 +283,15 @@ GLOBAL void render(
         (percent_allowed_difference > 0.f &&
          max_closest_possible_intersection > depth_threshold) ||
         tracker.get_n_hits() >= max_n_hits;
+#if defined(__CUDACC__) && defined(__HIP_PLATFORM_AMD__)
+    unsigned long long warp_done = __ballot(done);
+    int warp_done_bit_cnt = __popcll(warp_done);
+#else
     uint warp_done = thread_warp.ballot(done);
+    int warp_done_bit_cnt = POPC(warp_done);
+#endif //__CUDACC__ && __HIP_PLATFORM_AMD__
     if (thread_warp.thread_rank() == 0)
-      ATOMICADD_B(&n_pixels_done, POPC(warp_done));
+      ATOMICADD_B(&n_pixels_done, warp_done_bit_cnt);
     // This sync is necessary to keep n_loaded until all threads are done with
     // painting.
     thread_block.sync();

View File

@@ -213,8 +213,8 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
     const float& gamma,
     const float& max_depth,
     float& min_depth,
-    const c10::optional<torch::Tensor>& bg_col,
-    const c10::optional<torch::Tensor>& opacity,
+    const std::optional<torch::Tensor>& bg_col,
+    const std::optional<torch::Tensor>& opacity,
     const float& percent_allowed_difference,
     const uint& max_n_hits,
     const uint& mode) {
@@ -668,8 +668,8 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
     const float& gamma,
     const float& max_depth,
     float min_depth,
-    const c10::optional<torch::Tensor>& bg_col,
-    const c10::optional<torch::Tensor>& opacity,
+    const std::optional<torch::Tensor>& bg_col,
+    const std::optional<torch::Tensor>& opacity,
     const float& percent_allowed_difference,
     const uint& max_n_hits,
     const uint& mode) {
@@ -888,14 +888,14 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
 };

 std::tuple<
-    at::optional<torch::Tensor>,
-    at::optional<torch::Tensor>,
-    at::optional<torch::Tensor>,
-    at::optional<torch::Tensor>,
-    at::optional<torch::Tensor>,
-    at::optional<torch::Tensor>,
-    at::optional<torch::Tensor>,
-    at::optional<torch::Tensor>>
+    std::optional<torch::Tensor>,
+    std::optional<torch::Tensor>,
+    std::optional<torch::Tensor>,
+    std::optional<torch::Tensor>,
+    std::optional<torch::Tensor>,
+    std::optional<torch::Tensor>,
+    std::optional<torch::Tensor>,
+    std::optional<torch::Tensor>>
 Renderer::backward(
     const torch::Tensor& grad_im,
     const torch::Tensor& image,
@@ -912,8 +912,8 @@ Renderer::backward(
     const float& gamma,
     const float& max_depth,
     float min_depth,
-    const c10::optional<torch::Tensor>& bg_col,
-    const c10::optional<torch::Tensor>& opacity,
+    const std::optional<torch::Tensor>& bg_col,
+    const std::optional<torch::Tensor>& opacity,
     const float& percent_allowed_difference,
     const uint& max_n_hits,
     const uint& mode,
@@ -922,7 +922,7 @@ Renderer::backward(
     const bool& dif_rad,
     const bool& dif_cam,
     const bool& dif_opy,
-    const at::optional<std::pair<uint, uint>>& dbg_pos) {
+    const std::optional<std::pair<uint, uint>>& dbg_pos) {
   this->ensure_on_device(this->device_tracker.device());
   size_t batch_size;
   size_t n_points;
@@ -1045,14 +1045,14 @@ Renderer::backward(
   }

   // Prepare the return value.
   std::tuple<
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>>
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>>
       ret;
   if (mode == 1 || (!dif_pos && !dif_col && !dif_rad && !dif_cam && !dif_opy)) {
     return ret;

View File

@@ -44,21 +44,21 @@ struct Renderer {
       const float& gamma,
       const float& max_depth,
       float min_depth,
-      const c10::optional<torch::Tensor>& bg_col,
-      const c10::optional<torch::Tensor>& opacity,
+      const std::optional<torch::Tensor>& bg_col,
+      const std::optional<torch::Tensor>& opacity,
       const float& percent_allowed_difference,
       const uint& max_n_hits,
       const uint& mode);

   std::tuple<
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>,
-      at::optional<torch::Tensor>>
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>,
+      std::optional<torch::Tensor>>
   backward(
       const torch::Tensor& grad_im,
       const torch::Tensor& image,
@@ -75,8 +75,8 @@ struct Renderer {
       const float& gamma,
       const float& max_depth,
       float min_depth,
-      const c10::optional<torch::Tensor>& bg_col,
-      const c10::optional<torch::Tensor>& opacity,
+      const std::optional<torch::Tensor>& bg_col,
+      const std::optional<torch::Tensor>& opacity,
       const float& percent_allowed_difference,
       const uint& max_n_hits,
       const uint& mode,
@@ -85,7 +85,7 @@ struct Renderer {
       const bool& dif_rad,
      const bool& dif_cam,
      const bool& dif_opy,
-      const at::optional<std::pair<uint, uint>>& dbg_pos);
+      const std::optional<std::pair<uint, uint>>& dbg_pos);

   // Infrastructure.
   /**
@@ -115,8 +115,8 @@ struct Renderer {
       const float& gamma,
       const float& max_depth,
       float& min_depth,
-      const c10::optional<torch::Tensor>& bg_col,
-      const c10::optional<torch::Tensor>& opacity,
+      const std::optional<torch::Tensor>& bg_col,
+      const std::optional<torch::Tensor>& opacity,
       const float& percent_allowed_difference,
       const uint& max_n_hits,
       const uint& mode);

View File

@@ -8,6 +8,7 @@
 #ifdef WITH_CUDA
 #include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAException.h>
 #include <cuda_runtime_api.h>
 #endif
 #include <torch/extension.h>
@@ -33,13 +34,13 @@ torch::Tensor sphere_ids_from_result_info_nograd(
           .contiguous();
   if (forw_info.device().type() == c10::DeviceType::CUDA) {
 #ifdef WITH_CUDA
-    cudaMemcpyAsync(
+    C10_CUDA_CHECK(cudaMemcpyAsync(
         result.data_ptr(),
         tmp.data_ptr(),
         sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
             tmp.size(3),
         cudaMemcpyDeviceToDevice,
-        at::cuda::getCurrentCUDAStream());
+        at::cuda::getCurrentCUDAStream()));
 #else
     throw std::runtime_error(
         "Copy on CUDA device initiated but built "

View File

@@ -7,6 +7,7 @@
 */

 #ifdef WITH_CUDA
+#include <c10/cuda/CUDAException.h>
 #include <cuda_runtime_api.h>

 namespace pulsar {
@@ -17,7 +18,8 @@ void cudaDevToDev(
     const void* src,
     const int& size,
     const cudaStream_t& stream) {
-  cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream);
+  C10_CUDA_CHECK(
+      cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream));
 }

 void cudaDevToHost(
@@ -25,7 +27,8 @@ void cudaDevToHost(
     const void* src,
     const int& size,
     const cudaStream_t& stream) {
-  cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream);
+  C10_CUDA_CHECK(
+      cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream));
 }
 } // namespace pytorch

View File

@@ -9,7 +9,6 @@
 #include <torch/extension.h>
 #include <algorithm>
 #include <list>
-#include <queue>
 #include <thread>
 #include <tuple>
 #include "ATen/core/TensorAccessor.h"

View File

@@ -35,8 +35,6 @@ __global__ void FarthestPointSamplingKernel(
__shared__ int64_t selected_store; __shared__ int64_t selected_store;
// Get constants // Get constants
const int64_t N = points.size(0);
const int64_t P = points.size(1);
const int64_t D = points.size(2); const int64_t D = points.size(2);
// Get batch index and thread index // Get batch index and thread index

View File

@@ -376,8 +376,6 @@ PointLineDistanceBackward(
float tt = t_top / t_bot; float tt = t_top / t_bot;
tt = __saturatef(tt); tt = __saturatef(tt);
const float2 p_proj = (1.0f - tt) * v0 + tt * v1; const float2 p_proj = (1.0f - tt) * v0 + tt * v1;
const float2 d = p - p_proj;
const float dist = sqrt(dot(d, d));
const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p); const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p); const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);

View File

@@ -83,7 +83,7 @@ class ShapeNetCore(ShapeNetBase): # pragma: no cover
): ):
synset_set.add(synset) synset_set.add(synset)
elif (synset in self.synset_inv.keys()) and ( elif (synset in self.synset_inv.keys()) and (
(path.isdir(path.join(data_dir, self.synset_inv[synset]))) path.isdir(path.join(data_dir, self.synset_inv[synset]))
): ):
synset_set.add(self.synset_inv[synset]) synset_set.add(self.synset_inv[synset])
else: else:

View File

@@ -36,7 +36,6 @@ def collate_batched_meshes(batch: List[Dict]): # pragma: no cover
collated_dict["mesh"] = None collated_dict["mesh"] = None
if {"verts", "faces"}.issubset(collated_dict.keys()): if {"verts", "faces"}.issubset(collated_dict.keys()):
textures = None textures = None
if "textures" in collated_dict: if "textures" in collated_dict:
textures = TexturesAtlas(atlas=collated_dict["textures"]) textures = TexturesAtlas(atlas=collated_dict["textures"])

View File

@@ -26,7 +26,7 @@ from typing import (
import numpy as np import numpy as np
import torch import torch
from pytorch3d.implicitron.dataset import types from pytorch3d.implicitron.dataset import orm_types, types
from pytorch3d.implicitron.dataset.utils import ( from pytorch3d.implicitron.dataset.utils import (
adjust_camera_to_bbox_crop_, adjust_camera_to_bbox_crop_,
adjust_camera_to_image_scale_, adjust_camera_to_image_scale_,
@@ -48,8 +48,12 @@ from pytorch3d.implicitron.dataset.utils import (
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.renderer.camera_utils import join_cameras_as_batch from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation
@dataclass @dataclass
class FrameData(Mapping[str, Any]): class FrameData(Mapping[str, Any]):
@@ -122,9 +126,9 @@ class FrameData(Mapping[str, Any]):
meta: A dict for storing additional frame information. meta: A dict for storing additional frame information.
""" """
frame_number: Optional[torch.LongTensor] frame_number: Optional[torch.LongTensor] = None
sequence_name: Union[str, List[str]] sequence_name: Union[str, List[str]] = ""
sequence_category: Union[str, List[str]] sequence_category: Union[str, List[str]] = ""
frame_timestamp: Optional[torch.Tensor] = None frame_timestamp: Optional[torch.Tensor] = None
image_size_hw: Optional[torch.LongTensor] = None image_size_hw: Optional[torch.LongTensor] = None
effective_image_size_hw: Optional[torch.LongTensor] = None effective_image_size_hw: Optional[torch.LongTensor] = None
@@ -155,7 +159,7 @@ class FrameData(Mapping[str, Any]):
new_params = {} new_params = {}
for field_name in iter(self): for field_name in iter(self):
value = getattr(self, field_name) value = getattr(self, field_name)
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)): if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
new_params[field_name] = value.to(*args, **kwargs) new_params[field_name] = value.to(*args, **kwargs)
else: else:
new_params[field_name] = value new_params[field_name] = value
@@ -417,7 +421,6 @@ class FrameData(Mapping[str, Any]):
for f in fields(elem): for f in fields(elem):
if not f.init: if not f.init:
continue continue
list_values = override_fields.get( list_values = override_fields.get(
f.name, [getattr(d, f.name) for d in batch] f.name, [getattr(d, f.name) for d in batch]
) )
@@ -426,7 +429,7 @@ class FrameData(Mapping[str, Any]):
if all(list_value is not None for list_value in list_values) if all(list_value is not None for list_value in list_values)
else None else None
) )
return cls(**collated) return type(elem)(**collated)
elif isinstance(elem, Pointclouds): elif isinstance(elem, Pointclouds):
return join_pointclouds_as_batch(batch) return join_pointclouds_as_batch(batch)
@@ -434,6 +437,8 @@ class FrameData(Mapping[str, Any]):
elif isinstance(elem, CamerasBase): elif isinstance(elem, CamerasBase):
# TODO: don't store K; enforce working in NDC space # TODO: don't store K; enforce working in NDC space
return join_cameras_as_batch(batch) return join_cameras_as_batch(batch)
elif isinstance(elem, Meshes):
return join_meshes_as_batch(batch)
else: else:
return torch.utils.data.dataloader.default_collate(batch) return torch.utils.data.dataloader.default_collate(batch)
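The switch from `cls(**collated)` to `type(elem)(**collated)` above keeps subclasses of FrameData intact through batch collation, and the new Meshes branch mirrors the existing Pointclouds one. A minimal sketch of the pattern, using a hypothetical `Sample` dataclass rather than the real FrameData:

```python
# Minimal sketch of subclass-preserving collation; `Sample` and
# `SampleSubclass` are illustrative stand-ins, not PyTorch3D classes.
from dataclasses import dataclass, fields
from typing import List, Optional
import torch

@dataclass
class Sample:
    frame_number: Optional[torch.Tensor] = None
    sequence_name: str = ""

@dataclass
class SampleSubclass(Sample):
    extra: Optional[torch.Tensor] = None

def collate(batch: List[Sample]) -> Sample:
    elem = batch[0]
    collated = {
        f.name: [getattr(d, f.name) for d in batch] for f in fields(elem) if f.init
    }
    # type(elem)(**collated) keeps the runtime subclass; cls(**collated)
    # would silently downcast SampleSubclass batches to Sample.
    return type(elem)(**collated)

batch = [SampleSubclass(extra=torch.zeros(1)), SampleSubclass(extra=torch.ones(1))]
assert isinstance(collate(batch), SampleSubclass)
```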
@@ -454,8 +459,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
    @abstractmethod
    def build(
        self,
-        frame_annotation: types.FrameAnnotation,
-        sequence_annotation: types.SequenceAnnotation,
+        frame_annotation: FrameAnnotationT,
+        sequence_annotation: SequenceAnnotationT,
        *,
        load_blobs: bool = True,
        **kwargs,
@@ -541,8 +546,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
    def build(
        self,
-        frame_annotation: types.FrameAnnotation,
-        sequence_annotation: types.SequenceAnnotation,
+        frame_annotation: FrameAnnotationT,
+        sequence_annotation: SequenceAnnotationT,
        *,
        load_blobs: bool = True,
        **kwargs,
@@ -586,58 +591,81 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
            ),
        )

-        fg_mask_np: Optional[np.ndarray] = None
+        dataset_root = self.dataset_root
        mask_annotation = frame_annotation.mask
-        if mask_annotation is not None:
-            if load_blobs and self.load_masks:
-                fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
-                frame_data.mask_path = mask_path
+        depth_annotation = frame_annotation.depth
+        image_path: str | None = None
+        mask_path: str | None = None
+        depth_path: str | None = None
+        pcl_path: str | None = None
+        if dataset_root is not None:  # set all paths even if we wont load blobs
+            if frame_annotation.image.path is not None:
+                image_path = os.path.join(dataset_root, frame_annotation.image.path)
+                frame_data.image_path = image_path
+            if mask_annotation is not None and mask_annotation.path:
+                mask_path = os.path.join(dataset_root, mask_annotation.path)
+                frame_data.mask_path = mask_path
+            if depth_annotation is not None and depth_annotation.path is not None:
+                depth_path = os.path.join(dataset_root, depth_annotation.path)
+                frame_data.depth_path = depth_path
+            if point_cloud is not None:
+                pcl_path = os.path.join(dataset_root, point_cloud.path)
+                frame_data.sequence_point_cloud_path = pcl_path

+        fg_mask_np: np.ndarray | None = None
+        bbox_xywh: tuple[float, float, float, float] | None = None
+        if mask_annotation is not None:
+            if load_blobs and self.load_masks and mask_path:
+                fg_mask_np = self._load_fg_probability(frame_annotation, mask_path)
                frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
            bbox_xywh = mask_annotation.bounding_box_xywh
-            if bbox_xywh is None and fg_mask_np is not None:
-                bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
-            frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)

        if frame_annotation.image is not None:
            image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
            frame_data.image_size_hw = image_size_hw  # original image size
            # image size after crop/resize
            frame_data.effective_image_size_hw = image_size_hw
-            image_path = None
-            dataset_root = self.dataset_root
-            if frame_annotation.image.path is not None and dataset_root is not None:
-                image_path = os.path.join(dataset_root, frame_annotation.image.path)
-                frame_data.image_path = image_path

            if load_blobs and self.load_images:
                if image_path is None:
                    raise ValueError("Image path is required to load images.")
-                image_np = load_image(self._local_path(image_path))
+                no_mask = fg_mask_np is None  # didnt read the mask file
+                image_np = load_image(
+                    self._local_path(image_path), try_read_alpha=no_mask
+                )
+                if image_np.shape[0] == 4:  # RGBA image
+                    if no_mask:
+                        fg_mask_np = image_np[3:]
+                        frame_data.fg_probability = safe_as_tensor(
+                            fg_mask_np, torch.float
+                        )
+                    image_np = image_np[:3]
                frame_data.image_rgb = self._postprocess_image(
                    image_np, frame_annotation.image.size, frame_data.fg_probability
                )

-        if (
-            load_blobs
-            and self.load_depths
-            and frame_annotation.depth is not None
-            and frame_annotation.depth.path is not None
-        ):
-            (
-                frame_data.depth_map,
-                frame_data.depth_path,
-                frame_data.depth_mask,
-            ) = self._load_mask_depth(frame_annotation, fg_mask_np)
+        if bbox_xywh is None and fg_mask_np is not None:
+            bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
+        frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)

+        if load_blobs and self.load_depths and depth_path is not None:
+            frame_data.depth_map, frame_data.depth_mask = self._load_mask_depth(
+                frame_annotation, depth_path, fg_mask_np
+            )

        if load_blobs and self.load_point_clouds and point_cloud is not None:
-            pcl_path = self._fix_point_cloud_path(point_cloud.path)
+            assert pcl_path is not None
            frame_data.sequence_point_cloud = load_pointcloud(
                self._local_path(pcl_path), max_points=self.max_points
            )
-            frame_data.sequence_point_cloud_path = pcl_path

        if frame_annotation.viewpoint is not None:
            frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
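The reworked build() above falls back to the image's alpha channel as the foreground probability when no mask file was read (`try_read_alpha=no_mask`). A self-contained sketch of just that slicing step, on a made-up CHW float array:

```python
# Sketch of the alpha-fallback logic above, on a hypothetical CHW float
# image in [0, 1]; in the real builder `image_np` comes from
# load_image(..., try_read_alpha=True).
import numpy as np

image_np = np.random.rand(4, 8, 8).astype(np.float32)  # pretend RGBA
fg_mask_np = None  # no mask annotation was loaded

if image_np.shape[0] == 4:  # RGBA image
    if fg_mask_np is None:
        fg_mask_np = image_np[3:]  # keep a 1xHxW channel axis for the mask
    image_np = image_np[:3]  # drop alpha, keep RGB

assert image_np.shape == (3, 8, 8) and fg_mask_np.shape == (1, 8, 8)
```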
@@ -653,18 +681,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
        return frame_data

-    def _load_fg_probability(
-        self, entry: types.FrameAnnotation
-    ) -> Tuple[np.ndarray, str]:
-        assert self.dataset_root is not None and entry.mask is not None
-        full_path = os.path.join(self.dataset_root, entry.mask.path)
-        fg_probability = load_mask(self._local_path(full_path))
+    def _load_fg_probability(self, entry: FrameAnnotationT, path: str) -> np.ndarray:
+        fg_probability = load_mask(self._local_path(path))
        if fg_probability.shape[-2:] != entry.image.size:
            raise ValueError(
                f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
            )
-        return fg_probability, full_path
+        return fg_probability

    def _postprocess_image(
        self,
@@ -685,14 +709,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
    def _load_mask_depth(
        self,
-        entry: types.FrameAnnotation,
+        entry: FrameAnnotationT,
+        path: str,
        fg_mask: Optional[np.ndarray],
-    ) -> Tuple[torch.Tensor, str, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
        entry_depth = entry.depth
        dataset_root = self.dataset_root
        assert dataset_root is not None
-        assert entry_depth is not None and entry_depth.path is not None
-        path = os.path.join(dataset_root, entry_depth.path)
+        assert entry_depth is not None
        depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)

        if self.mask_depths:
@@ -706,11 +730,11 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
        else:
            depth_mask = (depth_map > 0.0).astype(np.float32)

-        return torch.tensor(depth_map), path, torch.tensor(depth_mask)
+        return torch.tensor(depth_map), torch.tensor(depth_mask)

    def _get_pytorch3d_camera(
        self,
-        entry: types.FrameAnnotation,
+        entry: FrameAnnotationT,
    ) -> PerspectiveCameras:
        entry_viewpoint = entry.viewpoint
        assert entry_viewpoint is not None
@@ -739,19 +763,6 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
            T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
        )

-    def _fix_point_cloud_path(self, path: str) -> str:
-        """
-        Fix up a point cloud path from the dataset.
-        Some files in Co3Dv2 have an accidental absolute path stored.
-        """
-        unwanted_prefix = (
-            "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
-        )
-        if path.startswith(unwanted_prefix):
-            path = path[len(unwanted_prefix) :]
-        assert self.dataset_root is not None
-        return os.path.join(self.dataset_root, path)

    def _local_path(self, path: str) -> str:
        if self.path_manager is None:
            return path

View File

@@ -222,7 +222,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):
self.dataset_map = dataset_map self.dataset_map = dataset_map
def _load_category(self, category: str) -> DatasetMap: def _load_category(self, category: str) -> DatasetMap:
frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz") frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
sequence_file = os.path.join( sequence_file = os.path.join(
self.dataset_root, category, "sequence_annotations.jgz" self.dataset_root, category, "sequence_annotations.jgz"

View File

@@ -75,7 +75,6 @@ def _minify(basedir, path_manager, factors=(), resolutions=()):
def _load_data( def _load_data(
basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
): ):
poses_arr = np.load( poses_arr = np.load(
_local_path(path_manager, os.path.join(basedir, "poses_bounds.npy")) _local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
) )
@@ -164,7 +163,6 @@ def ptstocam(pts, c2w):
def poses_avg(poses): def poses_avg(poses):
hwf = poses[0, :3, -1:] hwf = poses[0, :3, -1:]
center = poses[:, :3, 3].mean(0) center = poses[:, :3, 3].mean(0)
@@ -192,7 +190,6 @@ def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
def recenter_poses(poses): def recenter_poses(poses):
poses_ = poses + 0 poses_ = poses + 0
bottom = np.reshape([0, 0, 0, 1.0], [1, 4]) bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
c2w = poses_avg(poses) c2w = poses_avg(poses)
@@ -256,7 +253,6 @@ def spherify_poses(poses, bds):
new_poses = [] new_poses = []
for th in np.linspace(0.0, 2.0 * np.pi, 120): for th in np.linspace(0.0, 2.0 * np.pi, 120):
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh]) camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
up = np.array([0, 0, -1.0]) up = np.array([0, 0, -1.0])
@@ -311,7 +307,6 @@ def load_llff_data(
path_zflat=False, path_zflat=False,
path_manager=None, path_manager=None,
): ):
poses, bds, imgs = _load_data( poses, bds, imgs = _load_data(
basedir, factor=factor, path_manager=path_manager basedir, factor=factor, path_manager=path_manager
) # factor=8 downsamples original imgs by 8x ) # factor=8 downsamples original imgs by 8x

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
# This functionality requires SQLAlchemy 2.0 or later. # This functionality requires SQLAlchemy 2.0 or later.
import math import math

View File

@@ -4,11 +4,15 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import hashlib import hashlib
import json import json
import logging import logging
import os import os
from dataclasses import dataclass
import urllib
from dataclasses import dataclass, Field, field
from typing import ( from typing import (
Any, Any,
ClassVar, ClassVar,
@@ -29,17 +33,18 @@ import sqlalchemy as sa
import torch import torch
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.frame_data import ( # noqa from pytorch3d.implicitron.dataset.frame_data import (
FrameData, FrameData,
FrameDataBuilder, FrameDataBuilder, # noqa
FrameDataBuilderBase, FrameDataBuilderBase,
) )
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
registry, registry,
ReplaceableBase, ReplaceableBase,
run_auto_creation, run_auto_creation,
) )
from sqlalchemy.orm import Session from sqlalchemy.orm import scoped_session, Session, sessionmaker
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
@@ -51,7 +56,7 @@ _SET_LISTS_TABLE: str = "set_lists"
@registry.register @registry.register
class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore class SqlIndexDataset(DatasetBase, ReplaceableBase):
""" """
A dataset with annotations stored as SQLite tables. This is an index-based dataset. A dataset with annotations stored as SQLite tables. This is an index-based dataset.
The length is returned after all sequence and frame filters are applied (see param The length is returned after all sequence and frame filters are applied (see param
@@ -88,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
engine verbatim. Dont expose it to end users of your application! engine verbatim. Dont expose it to end users of your application!
pick_categories: Restrict the dataset to the given list of categories. pick_categories: Restrict the dataset to the given list of categories.
pick_sequences: A Sequence of sequence names to restrict the dataset to. pick_sequences: A Sequence of sequence names to restrict the dataset to.
pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
exclude_sequences: A Sequence of the names of the sequences to exclude. exclude_sequences: A Sequence of the names of the sequences to exclude.
limit_sequences_per_category_to: Limit the dataset to the first up to N limit_sequences_per_category_to: Limit the dataset to the first up to N
sequences within each category (applies after all other sequence filters sequences within each category (applies after all other sequence filters
@@ -102,9 +108,16 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
more frames than that; applied after other frame-level filters. more frames than that; applied after other frame-level filters.
seed: The seed of the random generator sampling `n_frames_per_sequence` seed: The seed of the random generator sampling `n_frames_per_sequence`
random frames per sequence. random frames per sequence.
preload_metadata: If True, the metadata is preloaded into memory.
precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
scoped_session: If True, allows different parts of the code to share
a global session to access the database.
""" """
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
sequence_annotations_type: ClassVar[Type[SqlSequenceAnnotation]] = (
SqlSequenceAnnotation
)
sqlite_metadata_file: str = "" sqlite_metadata_file: str = ""
dataset_root: Optional[str] = None dataset_root: Optional[str] = None
@@ -117,6 +130,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
pick_categories: Tuple[str, ...] = () pick_categories: Tuple[str, ...] = ()
pick_sequences: Tuple[str, ...] = () pick_sequences: Tuple[str, ...] = ()
pick_sequences_sql_clause: Optional[str] = None
exclude_sequences: Tuple[str, ...] = () exclude_sequences: Tuple[str, ...] = ()
limit_sequences_per_category_to: int = 0 limit_sequences_per_category_to: int = 0
limit_sequences_to: int = 0 limit_sequences_to: int = 0
@@ -124,12 +138,22 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
n_frames_per_sequence: int = -1 n_frames_per_sequence: int = -1
seed: int = 0 seed: int = 0
remove_empty_masks_poll_whole_table_threshold: int = 300_000 remove_empty_masks_poll_whole_table_threshold: int = 300_000
preload_metadata: bool = False
precompute_seq_to_idx: bool = False
# we set it manually in the constructor # we set it manually in the constructor
# _index: pd.DataFrame = field(init=False) _index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
_sql_engine: sa.engine.Engine = field(
init=False, metadata={"omegaconf_ignore": True}
)
eval_batches: Optional[List[Any]] = field(
init=False, metadata={"omegaconf_ignore": True}
)
frame_data_builder: FrameDataBuilderBase frame_data_builder: FrameDataBuilderBase # pyre-ignore[13]
frame_data_builder_class_type: str = "FrameDataBuilder" frame_data_builder_class_type: str = "FrameDataBuilder"
scoped_session: bool = False
def __post_init__(self) -> None: def __post_init__(self) -> None:
if sa.__version__ < "2.0": if sa.__version__ < "2.0":
raise ImportError("This class requires SQL Alchemy 2.0 or later") raise ImportError("This class requires SQL Alchemy 2.0 or later")
@@ -138,19 +162,28 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
            raise ValueError("sqlite_metadata_file must be set")

        if self.dataset_root:
-            frame_builder_type = self.frame_data_builder_class_type
-            getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
-                "dataset_root"
-            ] = self.dataset_root
+            frame_args = f"frame_data_builder_{self.frame_data_builder_class_type}_args"
+            getattr(self, frame_args)["dataset_root"] = self.dataset_root
+            getattr(self, frame_args)["path_manager"] = self.path_manager

        run_auto_creation(self)
-        self.frame_data_builder.path_manager = self.path_manager

+        if self.path_manager is not None:
+            self.sqlite_metadata_file = self.path_manager.get_local_path(
+                self.sqlite_metadata_file
+            )
+            self.subset_lists_file = self.path_manager.get_local_path(
+                self.subset_lists_file
+            )

-        # pyre-ignore  # NOTE: sqlite-specific args (read-only mode).
+        # NOTE: sqlite-specific args (read-only mode).
        self._sql_engine = sa.create_engine(
-            f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
+            f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
        )

+        if self.preload_metadata:
+            self._sql_engine = self._preload_database(self._sql_engine)

        sequences = self._get_filtered_sequences_if_any()

        if self.subsets:
@@ -166,16 +199,29 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
        if len(index) == 0:
            raise ValueError(f"There are no frames in the subsets: {self.subsets}!")

-        self._index = index.set_index(["sequence_name", "frame_number"])  # pyre-ignore
+        self._index = index.set_index(["sequence_name", "frame_number"])

-        self.eval_batches = None  # pyre-ignore
+        self.eval_batches = None
        if self.eval_batches_file:
            self.eval_batches = self._load_filter_eval_batches()

        logger.info(str(self))

+        if self.scoped_session:
+            self._session_factory = sessionmaker(bind=self._sql_engine)  # pyre-ignore
+
+        if self.precompute_seq_to_idx:
+            # This is deprecated and will be removed in the future.
+            # After we backport https://github.com/facebookresearch/uco3d/pull/3
+            logger.warning(
+                "Using precompute_seq_to_idx is deprecated and will be removed in the future."
+            )
+            self._index["rowid"] = np.arange(len(self._index))
+            groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
+            self._seq_to_indices = dict(groupby.apply(list))  # pyre-ignore
+            del self._index["rowid"]

    def __len__(self) -> int:
+        # pyre-ignore[16]
        return len(self._index)

    def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
@@ -232,12 +278,18 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
            self.frame_annotations_type.frame_number
            == int(frame),  # cast from np.int64
        )
-        seq_stmt = sa.select(SqlSequenceAnnotation).where(
-            SqlSequenceAnnotation.sequence_name == seq
+        seq_stmt = sa.select(self.sequence_annotations_type).where(
+            self.sequence_annotations_type.sequence_name == seq
        )
-        with Session(self._sql_engine) as session:
-            entry = session.scalars(stmt).one()
-            seq_metadata = session.scalars(seq_stmt).one()
+        if self.scoped_session:
+            # pyre-ignore
+            with scoped_session(self._session_factory)() as session:
+                entry = session.scalars(stmt).one()
+                seq_metadata = session.scalars(seq_stmt).one()
+        else:
+            with Session(self._sql_engine) as session:
+                entry = session.scalars(stmt).one()
+                seq_metadata = session.scalars(seq_stmt).one()

        assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
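With `scoped_session=True`, `__getitem__` pulls both annotation rows through a shared sessionmaker-backed registry instead of opening a fresh Session per lookup. A generic SQLAlchemy 2.0 sketch of the two strategies, against a throwaway in-memory table (none of the real annotation tables are used here):

```python
# Generic sketch of the two session strategies shown above.
import sqlalchemy as sa
from sqlalchemy.orm import Session, scoped_session, sessionmaker

engine = sa.create_engine("sqlite:///:memory:")
with engine.connect() as conn:
    conn.execute(sa.text("CREATE TABLE t (x INTEGER)"))
    conn.execute(sa.text("INSERT INTO t VALUES (1)"))
    conn.commit()

stmt = sa.text("SELECT x FROM t")

# Per-call session: opened and closed on every lookup.
with Session(engine) as session:
    assert session.execute(stmt).scalar_one() == 1

# Scoped session: one registry shared across call sites (and threads).
factory = sessionmaker(bind=engine)
with scoped_session(factory)() as session:
    assert session.execute(stmt).scalar_one() == 1
```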
@@ -250,7 +302,6 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
        return frame_data

    def __str__(self) -> str:
-        # pyre-ignore[16]
        return f"SqlIndexDataset #frames={len(self._index)}"

    def sequence_names(self) -> Iterable[str]:
@@ -260,9 +311,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
    # override
    def category_to_sequence_names(self) -> Dict[str, List[str]]:
        stmt = sa.select(
-            SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
+            self.sequence_annotations_type.category,
+            self.sequence_annotations_type.sequence_name,
        ).where(  # we limit results to sequences that have frames after all filters
-            SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
+            self.sequence_annotations_type.sequence_name.in_(self.sequence_names())
        )
        with self._sql_engine.connect() as connection:
            cat_to_seqs = pd.read_sql(stmt, connection)
@@ -335,17 +387,31 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
        rows = self._index.index.get_loc(seq_name)
        if isinstance(rows, slice):
            assert rows.stop is not None, "Unexpected result from pandas"
-            rows = range(rows.start or 0, rows.stop, rows.step or 1)
+            rows_seq = range(rows.start or 0, rows.stop, rows.step or 1)
        else:
-            rows = np.where(rows)[0]
+            rows_seq = list(np.where(rows)[0])

        index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
-            rows, seq_name, subset_filter
+            rows_seq, seq_name, subset_filter
        )
        index_slice["idx"] = idx

        yield from index_slice.itertuples(index=False)

+    # override
+    def sequence_indices_in_order(
+        self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
+    ) -> Iterator[int]:
+        """Same as `sequence_frames_in_order` but returns the iterator over
+        only dataset indices.
+        """
+        if self.precompute_seq_to_idx and subset_filter is None:
+            # pyre-ignore
+            yield from self._seq_to_indices[seq_name]
+        else:
+            for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
+                yield idx

    # override
    def get_eval_batches(self) -> Optional[List[Any]]:
        """
@@ -379,11 +445,35 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
or self.limit_sequences_to > 0 or self.limit_sequences_to > 0
or self.limit_sequences_per_category_to > 0 or self.limit_sequences_per_category_to > 0
or len(self.pick_sequences) > 0 or len(self.pick_sequences) > 0
or self.pick_sequences_sql_clause is not None
or len(self.exclude_sequences) > 0 or len(self.exclude_sequences) > 0
or len(self.pick_categories) > 0 or len(self.pick_categories) > 0
or self.n_frames_per_sequence > 0 or self.n_frames_per_sequence > 0
) )
def _preload_database(
self, source_engine: sa.engine.base.Engine
) -> sa.engine.base.Engine:
destination_engine = sa.create_engine("sqlite:///:memory:")
metadata = sa.MetaData()
metadata.reflect(bind=source_engine)
metadata.create_all(bind=destination_engine)
with source_engine.connect() as source_conn:
with destination_engine.connect() as destination_conn:
for table_obj in metadata.tables.values():
# Select all rows from the source table
source_rows = source_conn.execute(table_obj.select())
# Insert rows into the destination table
for row in source_rows:
destination_conn.execute(table_obj.insert().values(row))
# Commit the changes for each table
destination_conn.commit()
return destination_engine
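`_preload_database` above copies every table into an in-memory engine row by row via reflected table objects. For plain SQLite files, the stdlib backup API achieves the same preload in one call; a hedged alternative sketch (not what the method above uses), with a placeholder file name:

```python
# Alternative sketch: the stdlib sqlite3 backup API copies a whole database
# into memory in one call, rather than row-by-row. "metadata.sqlite" is a
# placeholder path.
import sqlite3

src = sqlite3.connect("metadata.sqlite")
dst = sqlite3.connect(":memory:")
src.backup(dst)  # copies schema and data
src.close()
# (Wiring dst into SQLAlchemy would need create_engine("sqlite://", creator=lambda: dst).)
```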
    def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
        # maximum possible filter (if limit_sequences_per_category_to == 0):
        #   WHERE category IN 'self.pick_categories'
@@ -396,19 +486,22 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
            *self._get_pick_filters(),
            *self._get_exclude_filters(),
        ]
+        if self.pick_sequences_sql_clause:
+            print("Applying the custom SQL clause.")
+            where_conditions.append(sa.text(self.pick_sequences_sql_clause))

        def add_where(stmt):
            return stmt.where(*where_conditions) if where_conditions else stmt

        if self.limit_sequences_per_category_to <= 0:
-            stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
+            stmt = add_where(sa.select(self.sequence_annotations_type.sequence_name))
        else:
            subquery = sa.select(
-                SqlSequenceAnnotation.sequence_name,
+                self.sequence_annotations_type.sequence_name,
                sa.func.row_number()
                .over(
                    order_by=sa.text("ROWID"),  # NOTE: ROWID is SQLite-specific
-                    partition_by=SqlSequenceAnnotation.category,
+                    partition_by=self.sequence_annotations_type.category,
                )
                .label("row_number"),
            )
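Since `pick_sequences_sql_clause` is appended verbatim via `sa.text()`, it can reference any column of the sequence-annotations table. A hypothetical configuration sketch (illustrative only; real instantiation normally goes through the Implicitron config system, and the paths are placeholders):

```python
# Hypothetical args for SqlIndexDataset; the clause string below is passed
# through sa.text() unchanged.
dataset_args = {
    "sqlite_metadata_file": "metadata.sqlite",
    "dataset_root": "/data/co3d",
    "pick_sequences_sql_clause": "category = 'teddybear' AND sequence_name LIKE '34_%'",
}
```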
@@ -444,31 +537,34 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
            return []

        logger.info(f"Limiting dataset to categories: {self.pick_categories}")
-        return [SqlSequenceAnnotation.category.in_(self.pick_categories)]
+        return [self.sequence_annotations_type.category.in_(self.pick_categories)]

    def _get_pick_filters(self) -> List[sa.ColumnElement]:
        if not self.pick_sequences:
            return []

        logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
-        return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)]
+        return [self.sequence_annotations_type.sequence_name.in_(self.pick_sequences)]

    def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
        if not self.exclude_sequences:
            return []

        logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
-        return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
+        return [
+            self.sequence_annotations_type.sequence_name.notin_(self.exclude_sequences)
+        ]

    def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
-        assert self.subsets is not None
+        subsets = self.subsets
+        assert subsets is not None
        with open(subset_lists_path, "r") as f:
            subset_to_seq_frame = json.load(f)

        seq_frame_list = sum(
            (
                [(*row, subset) for row in subset_to_seq_frame[subset]]
-                for subset in self.subsets
+                for subset in subsets
            ),
            [],
        )
@@ -522,7 +618,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
        stmt = sa.select(
            self.frame_annotations_type.sequence_name,
            self.frame_annotations_type.frame_number,
-        ).where(self.frame_annotations_type._mask_mass == 0)
+        ).where(self.frame_annotations_type._mask_mass == 0)  # pyre-ignore[16]
        with Session(self._sql_engine) as session:
            to_remove = session.execute(stmt).all()
@@ -586,7 +682,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
        stmt = sa.select(
            self.frame_annotations_type.sequence_name,
            self.frame_annotations_type.frame_number,
-            self.frame_annotations_type._image_path,
+            self.frame_annotations_type._image_path,  # pyre-ignore[16]
            sa.null().label("subset"),
        )
        where_conditions = []
@@ -600,7 +696,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
            logger.info(" excluding samples with empty masks")
            where_conditions.append(
                sa.or_(
-                    self.frame_annotations_type._mask_mass.is_(None),
+                    self.frame_annotations_type._mask_mass.is_(None),  # pyre-ignore[16]
                    self.frame_annotations_type._mask_mass != 0,
                )
            )
@@ -634,7 +730,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
        assert self.eval_batches_file
        logger.info(f"Loading eval batches from {self.eval_batches_file}")

-        if not os.path.isfile(self.eval_batches_file):
+        if (
+            self.path_manager and not self.path_manager.isfile(self.eval_batches_file)
+        ) or (not self.path_manager and not os.path.isfile(self.eval_batches_file)):
            # The batch indices file does not exist.
            # Most probably the user has not specified the root folder.
            raise ValueError(
@@ -642,7 +740,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
                + "Please specify a correct dataset_root folder."
            )

-        with open(self.eval_batches_file, "r") as f:
+        eval_batches_file = self._local_path(self.eval_batches_file)
+        with open(eval_batches_file, "r") as f:
            eval_batches = json.load(f)

        # limit the dataset to sequences to allow multiple evaluations in one file
@@ -726,9 +825,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
            self.frame_annotations_type.sequence_name == seq_name,
            self.frame_annotations_type.frame_number.in_(frames),
        )

-        with self._sql_engine.connect() as connection:
-            frame_no_ts = pd.read_sql_query(stmt, connection)
+        frame_no_ts = None
+        if self.scoped_session:
+            stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
+            with scoped_session(self._session_factory)() as session:  # pyre-ignore
+                frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
+        else:
+            with self._sql_engine.connect() as connection:
+                frame_no_ts = pd.read_sql_query(stmt, connection)

        if len(frame_no_ts) != len(index_slice):
            raise ValueError(
@@ -758,11 +863,18 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
prefixes=["TEMP"], # NOTE SQLite specific! prefixes=["TEMP"], # NOTE SQLite specific!
) )
@classmethod
def pre_expand(cls) -> None:
# remove dataclass annotations that are not meant to be init params
# because they cause troubles for OmegaConf
for attr, attr_value in list(cls.__dict__.items()): # need to copy as we mutate
if isinstance(attr_value, Field) and attr_value.metadata.get(
"omegaconf_ignore", False
):
delattr(cls, attr)
del cls.__annotations__[attr]
def _seq_name_to_seed(seq_name) -> int: def _seq_name_to_seed(seq_name) -> int:
"""Generates numbers in [0, 2 ** 28)""" """Generates numbers in [0, 2 ** 28)"""
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16) return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
def _safe_as_tensor(data, dtype):
return torch.tensor(data, dtype=dtype) if data is not None else None
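`pre_expand` runs before the dataclass/OmegaConf machinery expands the class, while `field(...)` objects still sit in `cls.__dict__`, which is why it can drop the `omegaconf_ignore` fields. A standalone sketch of the same trick on a hypothetical `Cfg` class:

```python
# Sketch of the omegaconf_ignore trick used by pre_expand above; `Cfg` is a
# made-up class, decorated as a dataclass only after the cleanup runs.
from dataclasses import dataclass, field, fields, Field

class Cfg:  # not yet a dataclass: Field objects still live in __dict__
    name: str = "x"
    _cache: dict = field(init=False, metadata={"omegaconf_ignore": True})

for attr, attr_value in list(Cfg.__dict__.items()):  # copy: we mutate below
    if isinstance(attr_value, Field) and attr_value.metadata.get("omegaconf_ignore"):
        delattr(Cfg, attr)
        del Cfg.__annotations__[attr]

Cfg = dataclass(Cfg)
assert [f.name for f in fields(Cfg)] == ["name"]  # _cache is gone
```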

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+# pyre-unsafe
+
import logging
import os
@@ -43,7 +45,7 @@ logger = logging.getLogger(__name__)

@registry.register
-class SqlIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
+class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
    """
    Generates the training, validation, and testing dataset objects for
    a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
@@ -193,9 +195,9 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
    # this is a mould that is never constructed, used to build self._dataset_map values
    dataset_class_type: str = "SqlIndexDataset"
-    dataset: SqlIndexDataset
+    dataset: SqlIndexDataset  # pyre-ignore [13]

-    path_manager_factory: PathManagerFactory
+    path_manager_factory: PathManagerFactory  # pyre-ignore [13]
    path_manager_factory_class_type: str = "PathManagerFactory"

    def __post_init__(self):
@@ -282,8 +284,14 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
        logger.info(f"Val dataset: {str(val_dataset)}")

        logger.debug("Extracting test dataset.")
-        eval_batches_file = self._get_lists_file("eval_batches")
-        del common_dataset_kwargs["eval_batches_file"]
+        if self.eval_batches_path is None:
+            eval_batches_file = None
+        else:
+            eval_batches_file = self._get_lists_file("eval_batches")
+
+        if "eval_batches_file" in common_dataset_kwargs:
+            common_dataset_kwargs.pop("eval_batches_file", None)

        test_dataset = dataset_type(
            **common_dataset_kwargs,
            subsets=self._get_subsets(self.test_subsets, True),

View File

@@ -87,6 +87,15 @@ def is_train_frame(
def get_bbox_from_mask(
    mask: np.ndarray, thr: float, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
+    # these corner cases need to be handled in order to avoid an infinite loop
+    if mask.size == 0:
+        warnings.warn("Empty mask is provided for bbox extraction.", stacklevel=1)
+        return 0, 0, 1, 1
+    if not mask.min() >= 0.0:
+        warnings.warn("Negative values in the mask for bbox extraction.", stacklevel=1)
+        mask = mask.clip(min=0.0)
+
    # bbox in xywh
    masks_for_box = np.zeros_like(mask)
    while masks_for_box.sum() <= 1.0:
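The two guards above keep the threshold-lowering `while` loop from spinning forever on degenerate masks. A small demo of the corner cases, assuming `get_bbox_from_mask` is imported from pytorch3d.implicitron.dataset.utils as patched above:

```python
# Demo of the corner cases handled by the new guards.
import numpy as np
from pytorch3d.implicitron.dataset.utils import get_bbox_from_mask

empty = np.zeros((0, 0), dtype=np.float32)
assert get_bbox_from_mask(empty, thr=0.4) == (0, 0, 1, 1)  # dummy one-pixel box

negative = np.full((4, 4), -1.0, dtype=np.float32)
negative[1:3, 1:3] = 1.0
# Negative values are clipped to zero (with a warning) before box extraction.
x, y, w, h = get_bbox_from_mask(negative, thr=0.4)
```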
@@ -134,7 +143,15 @@ T = TypeVar("T", bound=torch.Tensor)
def bbox_xyxy_to_xywh(xyxy: T) -> T:
    wh = xyxy[2:] - xyxy[:2]
    xywh = torch.cat([xyxy[:2], wh])
-    return xywh  # pyre-ignore
+    return xywh  # pyre-ignore[7]

+def bbox_xywh_to_xyxy(xywh: T, clamp_size: float | int | None = None) -> T:
+    wh = xywh[2:]
+    if clamp_size is not None:
+        wh = wh.clamp(min=clamp_size)
+    xyxy = torch.cat([xywh[:2], xywh[:2] + wh])
+    return xyxy  # pyre-ignore[7]

def get_clamp_bbox(
@@ -180,16 +197,6 @@ def rescale_bbox(
    return bbox * rel_size

-def bbox_xywh_to_xyxy(
-    xywh: torch.Tensor, clamp_size: Optional[int] = None
-) -> torch.Tensor:
-    xyxy = xywh.clone()
-    if clamp_size is not None:
-        xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
-    xyxy[2:] += xyxy[:2]
-    return xyxy

def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]:
    nz = np.flatnonzero(arr)
    return nz[0], nz[-1] + 1
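The relocated `bbox_xywh_to_xyxy` clamps width/height before adding the origin, computing the same boxes as the removed version without mutating a clone. A worked example, assuming the new function is importable from pytorch3d.implicitron.dataset.utils:

```python
# Worked example for the new bbox_xywh_to_xyxy.
import torch
from pytorch3d.implicitron.dataset.utils import bbox_xywh_to_xyxy

xywh = torch.tensor([2.0, 3.0, 0.0, 5.0])      # degenerate zero width
print(bbox_xywh_to_xyxy(xywh))                 # tensor([2., 3., 2., 8.])
print(bbox_xywh_to_xyxy(xywh, clamp_size=1))   # tensor([2., 3., 3., 8.])
```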
@@ -201,18 +208,24 @@ def resize_image(
    image_width: Optional[int],
    mode: str = "bilinear",
) -> Tuple[torch.Tensor, float, torch.Tensor]:

    if isinstance(image, np.ndarray):
        image = torch.from_numpy(image)

-    if image_height is None or image_width is None:
+    if (
+        image_height is None
+        or image_width is None
+        or image.shape[-2] == 0
+        or image.shape[-1] == 0
+    ):
        # skip the resizing
        return image, 1.0, torch.ones_like(image[:1])
+
    # takes numpy array or tensor, returns pytorch tensor
    minscale = min(
        image_height / image.shape[-2],
        image_width / image.shape[-1],
    )
    imre = torch.nn.functional.interpolate(
        image[None],
        scale_factor=minscale,
@@ -220,6 +233,7 @@ def resize_image(
        align_corners=False if mode == "bilinear" else None,
        recompute_scale_factor=True,
    )[0]
+
    imre_ = torch.zeros(image.shape[0], image_height, image_width)
    imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
    mask = torch.zeros(1, image_height, image_width)
@@ -232,9 +246,21 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
    return im.astype(np.float32) / 255.0

-def load_image(path: str) -> np.ndarray:
+def load_image(
+    path: str, try_read_alpha: bool = False, pil_format: str = "RGB"
+) -> np.ndarray:
+    """
+    Load an image from a path and return it as a numpy array.
+    If try_read_alpha is True, the image is read as RGBA and the alpha channel is
+    returned as the fourth channel.
+    Otherwise, the image is read as RGB and a three-channel image is returned.
+    """
    with Image.open(path) as pil_im:
-        im = np.array(pil_im.convert("RGB"))
+        # Check if the image has an alpha channel
+        if try_read_alpha and pil_im.mode == "RGBA":
+            im = np.array(pil_im)
+        else:
+            im = np.array(pil_im.convert(pil_format))

    return transpose_normalize_image(im)
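`load_image` still routes through `transpose_normalize_image`, so it returns a CHW float32 array in [0, 1]; with `try_read_alpha=True` an RGBA file keeps its fourth channel. A usage sketch that first writes a tiny RGBA PNG (the temp path is illustrative):

```python
# Usage sketch for the extended load_image.
import numpy as np
from PIL import Image
from pytorch3d.implicitron.dataset.utils import load_image

Image.new("RGBA", (4, 4), (255, 0, 0, 128)).save("/tmp/rgba.png")

rgb = load_image("/tmp/rgba.png")                       # alpha discarded
rgba = load_image("/tmp/rgba.png", try_read_alpha=True)  # alpha kept
assert rgb.shape == (3, 4, 4) and rgba.shape == (4, 4, 4)
assert rgba.dtype == np.float32 and rgba.max() <= 1.0
```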
@@ -329,6 +355,7 @@ def adjust_camera_to_bbox_crop_(
    focal_length_px, principal_point_px = _convert_ndc_to_pixels(
        camera.focal_length[0],
+        # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
        camera.principal_point[0],
        image_size_wh,
    )
@@ -341,6 +368,7 @@ def adjust_camera_to_bbox_crop_(
    )

    camera.focal_length = focal_length[None]
+    # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
    camera.principal_point = principal_point_cropped[None]
@@ -352,6 +380,7 @@ def adjust_camera_to_image_scale_(
) -> PerspectiveCameras:
    focal_length_px, principal_point_px = _convert_ndc_to_pixels(
        camera.focal_length[0],
+        # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
        camera.principal_point[0],
        original_size_wh,
    )
@@ -368,7 +397,8 @@ def adjust_camera_to_image_scale_(
        image_size_wh_output,
    )
    camera.focal_length = focal_length_scaled[None]
-    camera.principal_point = principal_point_scaled[None]
+    # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
+    camera.principal_point = principal_point_scaled[None]  # pyre-ignore[16]

# NOTE this cache is per-worker; they are implemented as processes.

View File

@@ -299,7 +299,6 @@ def eval_batch(
) )
for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("_masked", "_fg")): for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("_masked", "_fg")):
loss_mask_now = mask_crop * loss_fg_mask loss_mask_now = mask_crop * loss_fg_mask
for rgb_metric_name, rgb_metric_fun in zip( for rgb_metric_name, rgb_metric_fun in zip(

View File

@@ -106,7 +106,7 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
self.layers = torch.nn.ModuleList() self.layers = torch.nn.ModuleList()
self.proj_layers = torch.nn.ModuleList() self.proj_layers = torch.nn.ModuleList()
for stage in range(self.max_stage): for stage in range(self.max_stage):
stage_name = f"layer{stage+1}" stage_name = f"layer{stage + 1}"
feature_name = self._get_resnet_stage_feature_name(stage) feature_name = self._get_resnet_stage_feature_name(stage)
if (stage + 1) in self.stages: if (stage + 1) in self.stages:
if ( if (
@@ -139,12 +139,18 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
self.stages = set(self.stages) # convert to set for faster "in" self.stages = set(self.stages) # convert to set for faster "in"
def _get_resnet_stage_feature_name(self, stage) -> str: def _get_resnet_stage_feature_name(self, stage) -> str:
return f"res_layer_{stage+1}" return f"res_layer_{stage + 1}"
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor: def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
# pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
# `Union[Tensor, Module]`.
# pyre-fixme[58]: `/` is not supported for operand types `Tensor` and
# `Union[Tensor, Module]`.
return (img - self._resnet_mean) / self._resnet_std return (img - self._resnet_mean) / self._resnet_std
def get_feat_dims(self) -> int: def get_feat_dims(self) -> int:
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
# not a function.
return sum(self._feat_dim.values()) return sum(self._feat_dim.values())
def forward( def forward(
@@ -183,7 +189,12 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
else: else:
imgs_normed = imgs_resized imgs_normed = imgs_resized
# is not a function. # is not a function.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
feats = self.stem(imgs_normed) feats = self.stem(imgs_normed)
# pyre-fixme[6]: For 1st argument expected `Iterable[_T1]` but got
# `Union[Tensor, Module]`.
# pyre-fixme[6]: For 2nd argument expected `Iterable[_T2]` but got
# `Union[Tensor, Module]`.
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)): for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
feats = layer(feats) feats = layer(feats)
# just a sanity check below # just a sanity check below

View File

@@ -478,6 +478,8 @@ class GenericModel(ImplicitronModelBase):
) )
custom_args["global_code"] = global_code custom_args["global_code"] = global_code
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
# function.
for func in self._implicit_functions: for func in self._implicit_functions:
func.bind_args(**custom_args) func.bind_args(**custom_args)
@@ -500,6 +502,8 @@ class GenericModel(ImplicitronModelBase):
# Unbind the custom arguments to prevent pytorch from storing # Unbind the custom arguments to prevent pytorch from storing
# large buffers of intermediate results due to points in the # large buffers of intermediate results due to points in the
# bound arguments. # bound arguments.
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
# function.
for func in self._implicit_functions: for func in self._implicit_functions:
func.unbind_args() func.unbind_args()

View File

@@ -71,6 +71,7 @@ class Autodecoder(Configurable, torch.nn.Module):
return key_map return key_map
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]: def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `weight`.
return (self._autodecoder_codes.weight**2).mean() return (self._autodecoder_codes.weight**2).mean()
def get_encoding_dim(self) -> int: def get_encoding_dim(self) -> int:
@@ -95,6 +96,7 @@ class Autodecoder(Configurable, torch.nn.Module):
# pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as # pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
# `Tensor`. # `Tensor`.
x = torch.tensor( x = torch.tensor(
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
[self._key_map[elem] for elem in x], [self._key_map[elem] for elem in x],
dtype=torch.long, dtype=torch.long,
device=next(self.parameters()).device, device=next(self.parameters()).device,
@@ -102,6 +104,7 @@ class Autodecoder(Configurable, torch.nn.Module):
except StopIteration: except StopIteration:
raise ValueError("Not enough n_instances in the autodecoder") from None raise ValueError("Not enough n_instances in the autodecoder") from None
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
return self._autodecoder_codes(x) return self._autodecoder_codes(x)
def _load_key_map_hook( def _load_key_map_hook(

View File

@@ -122,6 +122,7 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
if frame_timestamp.shape[-1] != 1: if frame_timestamp.shape[-1] != 1:
raise ValueError("Frame timestamp's last dimensions should be one.") raise ValueError("Frame timestamp's last dimensions should be one.")
time = frame_timestamp / self.time_divisor time = frame_timestamp / self.time_divisor
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
return self._harmonic_embedding(time) return self._harmonic_embedding(time)
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]: def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:

View File

@@ -232,9 +232,14 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
# if the skip tensor is None, we use `x` instead. # if the skip tensor is None, we use `x` instead.
z = x z = x
skipi = 0 skipi = 0
# pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got
# `Union[Tensor, Module]`.
for li, layer in enumerate(self.mlp): for li, layer in enumerate(self.mlp):
# pyre-fixme[58]: `in` is not supported for right operand type
# `Union[Tensor, Module]`.
if li in self._input_skips: if li in self._input_skips:
if self._skip_affine_trans: if self._skip_affine_trans:
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
y = self._apply_affine_layer(self.skip_affines[skipi], y, z) y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
else: else:
y = torch.cat((y, z), dim=-1) y = torch.cat((y, z), dim=-1)

View File

@@ -141,11 +141,16 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
self.embed_fn is None and fun_viewpool is None and global_code is None self.embed_fn is None and fun_viewpool is None and global_code is None
): ):
return torch.tensor( return torch.tensor(
[], device=rays_points_world.device, dtype=rays_points_world.dtype [],
device=rays_points_world.device,
dtype=rays_points_world.dtype,
# pyre-fixme[6]: For 2nd argument expected `Union[int, SymInt]` but got
# `Union[Module, Tensor]`.
).view(0, self.out_dim) ).view(0, self.out_dim)
embeddings = [] embeddings = []
if self.embed_fn is not None: if self.embed_fn is not None:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
embeddings.append(self.embed_fn(rays_points_world)) embeddings.append(self.embed_fn(rays_points_world))
if fun_viewpool is not None: if fun_viewpool is not None:
@@ -164,13 +169,19 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
embedding = torch.cat(embeddings, dim=-1) embedding = torch.cat(embeddings, dim=-1)
x = embedding x = embedding
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
for layer_idx in range(self.num_layers - 1): for layer_idx in range(self.num_layers - 1):
if layer_idx in self.skip_in: if layer_idx in self.skip_in:
x = torch.cat([x, embedding], dim=-1) / 2**0.5 x = torch.cat([x, embedding], dim=-1) / 2**0.5
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
x = self.linear_layers[layer_idx](x) x = self.linear_layers[layer_idx](x)
# pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
# float, int, Tensor]) -> Tensor, Module, Tensor]` is not a function.
if layer_idx < self.num_layers - 2: if layer_idx < self.num_layers - 2:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
x = self.softplus(x) x = self.softplus(x)
return x return x

View File

@@ -123,8 +123,10 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
        # Normalize the ray_directions to unit l2 norm.
        rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
        # Obtain the harmonic embedding of the normalized ray directions.
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        return self.color_layer((self.intermediate_linear(features), rays_embedding))

    @staticmethod
@@ -195,6 +197,8 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
        embeds = create_embeddings_for_implicit_function(
            xyz_world=rays_points_world,
            #  for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
+           # pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
+           #  got `Union[None, Tensor, Module]`.
            xyz_embedding_function=(
                self.harmonic_embedding_xyz if self.input_xyz else None
            ),
@@ -206,12 +210,14 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
        )
        # embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        features = self.xyz_encoder(embeds)
        # features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
        # NNs operate on the flattenned rays; reshaping to the correct spatial size
        # TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
        features = features.reshape(*rays_points_world.shape[:-1], -1)
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        raw_densities = self.density_layer(features)
        # raw_densities.shape = [minibatch x ... x 1] in [0-1]
@@ -219,6 +225,8 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
            if camera is None:
                raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
+           # pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
+           #  `Union[Tensor, Module]`.
            directions = ray_bundle.directions @ camera.R
        else:
            directions = ray_bundle.directions
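
For orientation, the `harmonic_embedding_dir` / `harmonic_embedding_xyz` calls annotated in the hunks above compute a positional (harmonic) embedding. A rough standalone sketch consistent with the `n_harmonic_functions*6+3` shape comment; the function name and power-of-two frequency spacing are illustrative assumptions, not PyTorch3D's exact implementation:

import torch

def harmonic_embedding(x: torch.Tensor, n_harmonic: int = 6) -> torch.Tensor:
    # frequencies 1, 2, 4, ..., 2^(n_harmonic - 1); exact spacing is configurable
    freqs = 2.0 ** torch.arange(n_harmonic, dtype=x.dtype)
    embed = (x[..., None] * freqs).flatten(-2)  # (..., 3 * n_harmonic)
    # sin/cos pairs plus the raw input: 3 * n_harmonic * 2 + 3 features
    return torch.cat([embed.sin(), embed.cos(), x], dim=-1)

dirs = torch.nn.functional.normalize(torch.randn(4, 3), dim=-1)
emb = harmonic_embedding(dirs)  # shape (4, 39) for n_harmonic=6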

View File

@@ -103,6 +103,8 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
        embeds = create_embeddings_for_implicit_function(
            xyz_world=rays_points_world,
+           # pyre-fixme[6]: For 2nd argument expected `Optional[(...) -> Any]` but
+           #  got `Union[Tensor, Module]`.
            xyz_embedding_function=self._harmonic_embedding,
            global_code=global_code,
            fun_viewpool=fun_viewpool,
@@ -112,6 +114,7 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
        # Before running the network, we have to resize embeds to ndims=3,
        # otherwise the SRN layers consume huge amounts of memory.
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        raymarch_features = self._net(
            embeds.view(embeds.shape[0], -1, embeds.shape[-1])
        )
@@ -166,7 +169,9 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
        # Normalize the ray_directions to unit l2 norm.
        rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
        # Obtain the harmonic embedding of the normalized ray directions.
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        rays_embedding = self._harmonic_embedding(rays_directions_normed)
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        return self._color_layer((features, rays_embedding))

    def forward(
@@ -195,6 +200,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
            denoting the color of each ray point.
        """
        # raymarch_features.shape = [minibatch x ... x pts_per_ray x 3]
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        features = self._net(raymarch_features)
        # features.shape = [minibatch x ... x self.n_hidden_units]
@@ -202,6 +208,8 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
            if camera is None:
                raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
+           # pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
+           #  `Union[Tensor, Module]`.
            directions = ray_bundle.directions @ camera.R
        else:
            directions = ray_bundle.directions
@@ -209,6 +217,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
        # NNs operate on the flattenned rays; reshaping to the correct spatial size
        features = features.reshape(*raymarch_features.shape[:-1], -1)
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        raw_densities = self._density_layer(features)
        rays_colors = self._get_colors(features, directions)
@@ -269,6 +278,7 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
            srn_raymarch_function.
        """
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        net = self._hypernet(global_code)

        # use the hyper-net generated network to instantiate the raymarch module
@@ -296,7 +306,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
        global_code=None,
        **kwargs,
    ):
-
        if global_code is None:
            raise ValueError("SRN Hypernetwork requires a non-trivial global code.")
@@ -304,6 +313,8 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
        # across LSTM iterations for the same global_code.
        if self.cached_srn_raymarch_function is None:
            # generate the raymarching network from the hypernet
+           # pyre-fixme[16]: `SRNRaymarchHyperNet` has no attribute
+           #  `cached_srn_raymarch_function`.
            self.cached_srn_raymarch_function = self._run_hypernet(global_code)
        (srn_raymarch_function,) = cast(
            Tuple[SRNRaymarchFunction], self.cached_srn_raymarch_function
@@ -331,6 +342,7 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
    def create_raymarch_function(self) -> None:
        self.raymarch_function = SRNRaymarchFunction(
            latent_dim=self.latent_dim,
+           # pyre-fixme[32]: Keyword argument must be a mapping with string keys.
            **self.raymarch_function_args,
        )
@@ -389,6 +401,7 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
        self.hypernet = SRNRaymarchHyperNet(
            latent_dim=self.latent_dim,
            latent_dim_hypernet=self.latent_dim_hypernet,
+           # pyre-fixme[32]: Keyword argument must be a mapping with string keys.
            **self.hypernet_args,
        )

View File

@@ -40,7 +40,6 @@ def create_embeddings_for_implicit_function(
    xyz_embedding_function: Optional[Callable],
    diag_cov: Optional[torch.Tensor] = None,
) -> torch.Tensor:
-
    bs, *spatial_size, pts_per_ray, _ = xyz_world.shape

    if xyz_in_camera_coords:
@@ -64,7 +63,6 @@ def create_embeddings_for_implicit_function(
                0,
            )
        else:
-
            embeds = xyz_embedding_function(ray_points_for_embed, diag_cov=diag_cov)

    embeds = embeds.reshape(
        bs,

View File

@@ -269,6 +269,7 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
            for name, tensor in vars(grid_values_with_wanted_resolution).items()
        }
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        return self.values_type(**params), True

    def get_resolution_change_epochs(self) -> Tuple[int, ...]:
@@ -882,6 +883,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
            torch.Tensor of shape (..., n_features)
        """
        locator = self._get_volume_locator()
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        grid_values = self.voxel_grid.values_type(**self.params)
        # voxel grids operate with extra n_grids dimension, which we fix to one
        return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
@@ -895,6 +897,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
                replace current parameters
        """
        if self.hold_voxel_grid_as_parameters:
+           # pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
            self.params = torch.nn.ParameterDict(
                {
                    k: torch.nn.Parameter(val)
@@ -945,6 +948,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
        Returns:
            True if parameter change has happened else False.
        """
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        grid_values = self.voxel_grid.values_type(**self.params)
        grid_values, change = self.voxel_grid.change_resolution(
            grid_values, epoch=epoch
@@ -992,16 +996,21 @@ class VoxelGridModule(Configurable, torch.nn.Module):
""" """
''' '''
new_params = {} new_params = {}
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
# function.
for name in self.params: for name in self.params:
key = prefix + "params." + name key = prefix + "params." + name
if key in state_dict: if key in state_dict:
new_params[name] = torch.zeros_like(state_dict[key]) new_params[name] = torch.zeros_like(state_dict[key])
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params)) self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))
def get_device(self) -> torch.device: def get_device(self) -> torch.device:
""" """
Returns torch.device on which module parameters are located Returns torch.device on which module parameters are located
""" """
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
# not a function.
return next(val for val in self.params.values() if val is not None).device return next(val for val in self.params.values() if val is not None).device
def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None: def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
@@ -1018,6 +1027,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
""" """
locator = self._get_volume_locator() locator = self._get_volume_locator()
# torch.nn.modules.module.Module]` is not a function. # torch.nn.modules.module.Module]` is not a function.
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
old_grid_values = self.voxel_grid.values_type(**self.params) old_grid_values = self.voxel_grid.values_type(**self.params)
new_grid_values = self.voxel_grid.crop_world( new_grid_values = self.voxel_grid.crop_world(
min_point, max_point, old_grid_values, locator min_point, max_point, old_grid_values, locator
@@ -1025,6 +1035,7 @@ class VoxelGridModule(Configurable, torch.nn.Module):
        grid_values, _ = self.voxel_grid.change_resolution(
            new_grid_values, grid_values_with_wanted_resolution=old_grid_values
        )
+       # pyre-fixme[16]: `VoxelGridModule` has no attribute `params`.
        self.params = torch.nn.ParameterDict(
            {
                k: torch.nn.Parameter(val)

View File

@@ -192,16 +192,26 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
    def __post_init__(self) -> None:
        run_auto_creation(self)
+       # pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
+       #  `voxel_grid_scaffold`.
        self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
+       # pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
+       #  `harmonic_embedder_xyz_density`.
        self.harmonic_embedder_xyz_density = HarmonicEmbedding(
            **self.harmonic_embedder_xyz_density_args
        )
+       # pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
+       #  `harmonic_embedder_xyz_color`.
        self.harmonic_embedder_xyz_color = HarmonicEmbedding(
            **self.harmonic_embedder_xyz_color_args
        )
+       # pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
+       #  `harmonic_embedder_dir_color`.
        self.harmonic_embedder_dir_color = HarmonicEmbedding(
            **self.harmonic_embedder_dir_color_args
        )
+       # pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
+       #  `_scaffold_ready`.
        self._scaffold_ready = False

    def forward(
@@ -252,6 +262,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
        # ########## filter the points using the scaffold ########## #
        if self._scaffold_ready and self.scaffold_filter_points:
            with torch.no_grad():
+               # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
                non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
            points = points[non_empty_points]
            if len(points) == 0:
@@ -363,6 +374,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
                feature dimensionality which `decoder_density` returns
        """
        embeds_density = self.voxel_grid_density(points)
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density)
        # shape = [..., density_dim]
        return self.decoder_density(harmonic_embedding_density)
@@ -397,6 +409,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
        if self.xyz_ray_dir_in_camera_coords:
            if camera is None:
                raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
+           # pyre-fixme[58]: `@` is not supported for operand types `Tensor` and
+           #  `Union[Tensor, Module]`.
            directions = directions @ camera.R

        # ########## get voxel grid output ########## #
@@ -405,11 +419,13 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
        # ########## embed with the harmonic function ########## #
        # Obtain the harmonic embedding of the voxel grid output.
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color)
        # Normalize the ray_directions to unit l2 norm.
        rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1)
        # Obtain the harmonic embedding of the normalized ray directions.
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        harmonic_embedding_dir = self.harmonic_embedder_dir_color(
            rays_directions_normed
        )
@@ -478,8 +494,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
                an object inside, else False.
        """
        # find bounding box
+       # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
+       #  `get_grid_points`.
        points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
        assert self._scaffold_ready, "Scaffold has to be calculated before cropping."
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0
        non_zero_idxs = torch.nonzero(occupancy)
        if len(non_zero_idxs) == 0:
@@ -511,6 +530,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
""" """
planes = [] planes = []
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
# `get_grid_points`.
points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch) points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
chunk_size = ( chunk_size = (
@@ -530,7 +551,10 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
            stride=1,
        )
        occupancy_cube = density_cube > self.scaffold_empty_space_threshold
+       # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `params`.
        self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float()
+       # pyre-fixme[16]: `VoxelGridImplicitFunction` has no attribute
+       #  `_scaffold_ready`.
        self._scaffold_ready = True

        return False
@@ -547,6 +571,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
                decoding function to this value.
        """
        grid_args = self.voxel_grid_density_args
+       # pyre-fixme[6]: For 1st argument expected `DictConfig` but got
+       #  `Union[Tensor, Module]`.
        grid_output_dim = VoxelGridModule.get_output_dim(grid_args)

        embedder_args = self.harmonic_embedder_xyz_density_args
@@ -575,6 +601,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
                decoding function to this value.
        """
        grid_args = self.voxel_grid_color_args
+       # pyre-fixme[6]: For 1st argument expected `DictConfig` but got
+       #  `Union[Tensor, Module]`.
        grid_output_dim = VoxelGridModule.get_output_dim(grid_args)

        embedder_args = self.harmonic_embedder_xyz_color_args
@@ -608,7 +636,9 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
            `self.voxel_grid_density`
        """
        return VoxelGridModule(
+           # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
            extents=self.voxel_grid_density_args["extents"],
+           # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
            translation=self.voxel_grid_density_args["translation"],
            voxel_grid_class_type="FullResolutionVoxelGrid",
            hold_voxel_grid_as_parameters=False,
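
The scaffold filtering annotated in the `forward` hunk above prunes empty-space points before the expensive density/color decoding. A minimal sketch of the pattern, with a toy occupancy callable standing in for `voxel_grid_scaffold`:

import torch

def scaffold(points: torch.Tensor) -> torch.Tensor:
    # toy occupancy: everything inside the unit ball counts as occupied
    return (points.norm(dim=-1, keepdim=True) < 1.0).float()

points = torch.randn(1024, 3)
with torch.no_grad():
    non_empty_points = scaffold(points)[..., 0] > 0
points = points[non_empty_points]  # only occupied-space points reach the decoders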

View File

@@ -6,7 +6,6 @@
# pyre-unsafe
-
import warnings
from typing import Any, Dict, Optional
@@ -298,9 +297,8 @@ class ViewMetrics(ViewMetricsBase):
                _rgb_metrics(
                    image_rgb,
                    image_rgb_pred,
-                   fg_probability,
-                   fg_probability_pred,
-                   mask_crop,
+                   masks=fg_probability,
+                   masks_crop=mask_crop,
                )
            )
@@ -310,9 +308,21 @@ class ViewMetrics(ViewMetricsBase):
metrics["mask_neg_iou"] = utils.neg_iou_loss( metrics["mask_neg_iou"] = utils.neg_iou_loss(
fg_probability_pred, fg_probability, mask=mask_crop fg_probability_pred, fg_probability, mask=mask_crop
) )
metrics["mask_bce"] = utils.calc_bce( if torch.is_autocast_enabled():
fg_probability_pred, fg_probability, mask=mask_crop # To avoid issues with mixed precision
) metrics["mask_bce"] = utils.calc_bce(
fg_probability_pred.logit(),
fg_probability,
mask=mask_crop,
pred_logits=True,
)
else:
metrics["mask_bce"] = utils.calc_bce(
fg_probability_pred,
fg_probability,
mask=mask_crop,
pred_logits=False,
)
if depth_map is not None and depth_map_pred is not None: if depth_map is not None and depth_map_pred is not None:
assert mask_crop is not None assert mask_crop is not None
@@ -324,7 +334,11 @@ class ViewMetrics(ViewMetricsBase):
            if fg_probability is not None:
                mask = fg_probability * mask_crop
                _, abs_ = utils.eval_depth(
-                   depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
+                   depth_map_pred,
+                   depth_map,
+                   get_best_scale=True,
+                   mask=mask,
+                   crop=0,
                )
                metrics["depth_abs_fg"] = abs_.mean()
@@ -346,18 +360,26 @@ class ViewMetrics(ViewMetricsBase):
        return metrics


-def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
+def _rgb_metrics(
+    images,
+    images_pred,
+    masks=None,
+    masks_crop=None,
+    huber_scaling: float = 0.03,
+):
    assert masks_crop is not None
    if images.shape[1] != images_pred.shape[1]:
        raise ValueError(
            f"Network output's RGB images had {images_pred.shape[1]} "
            f"channels. {images.shape[1]} expected."
        )
+   rgb_abs = ((images_pred - images).abs()).mean(dim=1, keepdim=True)
    rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
-   rgb_loss = utils.huber(rgb_squared, scaling=0.03)
+   rgb_loss = utils.huber(rgb_squared, scaling=huber_scaling)
    crop_mass = masks_crop.sum().clamp(1.0)
    results = {
        "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
+       "rgb_l1": (rgb_abs * masks_crop).sum() / crop_mass,
        "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
        "rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
    }
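
The autocast branch above exists because `F.binary_cross_entropy` on probabilities is numerically unsafe in reduced precision (and disallowed outright under CUDA autocast), while `binary_cross_entropy_with_logits` uses a stable fused log-sigmoid. A minimal sketch of the same pattern with toy tensors:

import torch
import torch.nn.functional as F

pred_probs = torch.rand(4, 1, 8, 8).clamp(0.01, 0.99)
target = (torch.rand(4, 1, 8, 8) > 0.5).float()

with torch.autocast("cpu", dtype=torch.bfloat16):
    logits = pred_probs.logit()  # invert the sigmoid before the loss
    loss = F.binary_cross_entropy_with_logits(logits, target)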

View File

@@ -135,6 +135,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
                break

            # run the lstm marcher
+           # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
            state_h, state_c = self._lstm(
                raymarch_features.view(-1, raymarch_features.shape[-1]),
                states[-1],
@@ -142,6 +143,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
            if state_h.requires_grad:
                state_h.register_hook(lambda x: x.clamp(min=-10, max=10))
            # predict the next step size
+           # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
            signed_distance = self._out_layer(state_h).view(ray_bundle_t.lengths.shape)
            # log the lstm states
            states.append((state_h, state_c))

View File

@@ -207,6 +207,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
""" """
sample_mask = None sample_mask = None
if ( if (
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[An...
self._sampling_mode[evaluation_mode] == RenderSamplingMode.MASK_SAMPLE self._sampling_mode[evaluation_mode] == RenderSamplingMode.MASK_SAMPLE
and mask is not None and mask is not None
): ):
@@ -223,6 +224,7 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
            EvaluationMode.EVALUATION: self._evaluation_raysampler,
        }[evaluation_mode]

+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        ray_bundle = raysampler(
            cameras=cameras,
            mask=sample_mask,
@@ -240,6 +242,8 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
"Heterogeneous ray bundle is not supported for conical frustum computation yet" "Heterogeneous ray bundle is not supported for conical frustum computation yet"
) )
elif self.cast_ray_bundle_as_cone: elif self.cast_ray_bundle_as_cone:
# pyre-fixme[9]: pixel_hw has type `Tuple[float, float]`; used as
# `Tuple[Union[Tensor, Module], Union[Tensor, Module]]`.
pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width) pixel_hw: Tuple[float, float] = (self.pixel_height, self.pixel_width)
pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw) pixel_radii_2d = compute_radii(cameras, ray_bundle.xys[..., :2], pixel_hw)
return ImplicitronRayBundle( return ImplicitronRayBundle(

View File

@@ -179,8 +179,10 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
        rays_densities = torch.relu(rays_densities)
        weighted_densities = deltas * rays_densities
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        capped_densities = self._capping_function(weighted_densities)
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        rays_opacities = self._capping_function(
            torch.cumsum(weighted_densities, dim=-1)
        )
@@ -190,6 +192,7 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
        )
        absorption_shifted[..., : self.surface_thickness] = 1.0
+       # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
        weights = self._weight_function(capped_densities, absorption_shifted)
        features = (weights[..., None] * rays_features).sum(dim=-2)
        depth = (weights * ray_lengths)[..., None].sum(dim=-2)
@@ -197,6 +200,8 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
        alpha = opacities if self.blend_output else 1
        if self._bg_color.shape[-1] not in [1, features.shape[-1]]:
            raise ValueError("Wrong number of background color channels.")
+       # pyre-fixme[58]: `*` is not supported for operand types `int` and
+       #  `Union[Tensor, Module]`.
        features = alpha * features + (1 - opacities) * self._bg_color

        return RendererOutput(
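
The capping/absorption lines annotated above implement standard emission-absorption ray marching. A compact sketch under the assumption of an exponential capping function (Implicitron makes this configurable), with `surface_thickness` handling omitted:

import torch

deltas = torch.full((2, 8), 0.1)              # per-step ray lengths
densities = torch.rand(2, 8)                  # post-ReLU densities
capped = 1 - torch.exp(-deltas * densities)   # per-point opacity
absorption = torch.exp(-torch.cumsum(deltas * densities, dim=-1))
absorption_shifted = torch.cat(
    [torch.ones_like(absorption[..., :1]), absorption[..., :-1]], dim=-1
)
weights = capped * absorption_shifted          # sums to at most 1 along each ray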

View File

@@ -61,6 +61,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
    def create_ray_tracer(self) -> None:
        self.ray_tracer = RayTracing(
+           # pyre-fixme[32]: Keyword argument must be a mapping with string keys.
            **self.ray_tracer_args,
            object_bounding_sphere=self.object_bounding_sphere,
        )
@@ -149,6 +150,8 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
                n_eik_points,
                3,
                #  but got `Union[device, Tensor, Module]`.
+               # pyre-fixme[6]: For 3rd argument expected `Union[None, int, str,
+               #  device]` but got `Union[device, Tensor, Module]`.
                device=self._bg_color.device,
            ).uniform_(-eik_bounding_box, eik_bounding_box)
            eikonal_pixel_points = points.clone()
@@ -205,6 +208,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
            ]
            normals_full.view(-1, 3)[surface_mask] = normals
            render_full.view(-1, self.render_features_dimensions)[surface_mask] = (
+               # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
                self._rgb_network(
                    features,
                    differentiable_surface_points[None],
@@ -216,8 +220,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
            )
            mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
                # pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
-               -self.soft_mask_alpha
-               * sdf_output[~surface_mask]
+               -self.soft_mask_alpha * sdf_output[~surface_mask]
            )

        # scatter points with surface_mask
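
The reflowed sigmoid above follows the usual SDF soft-mask convention: opacity `sigmoid(-alpha * sdf)` approaches 1 inside the surface (negative SDF) and 0 far outside. With toy values:

import torch

sdf_output = torch.tensor([-0.5, 0.0, 0.5])
soft_mask_alpha = 50.0
soft_mask = torch.sigmoid(-soft_mask_alpha * sdf_output)
# -> approximately [1.0, 0.5, 0.0]: inside, on, and outside the surface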

View File

@@ -532,6 +532,7 @@ def _get_ray_dir_dot_prods(camera: CamerasBase, pts: torch.Tensor):
    # does not produce nans randomly unlike get_camera_center() below
    cam_centers_rep = -torch.bmm(
+       # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
        camera_rep.T[:, None],
        camera_rep.R.permute(0, 2, 1),
    ).reshape(-1, *([1] * (pts.ndim - 2)), 3)

View File

@@ -209,6 +209,7 @@ def handle_seq_id(
        seq_id = torch.tensor(seq_id, dtype=torch.long, device=device)
    # pyre-fixme[16]: Item `List` of `Union[List[int], List[str], LongTensor]` has
    #  no attribute `to`.
+   # pyre-fixme[7]: Expected `LongTensor` but got `Tensor`.
    return seq_id.to(device)

View File

@@ -21,7 +21,6 @@ def cleanup_eval_depth(
    sigma: float = 0.01,
    image=None,
):
-
    ba, _, H, W = depth.shape

    pcl = point_cloud.points_padded()

View File

@@ -6,12 +6,15 @@
# pyre-unsafe

+import logging
import math
from typing import Optional, Tuple

import torch
from torch.nn import functional as F

+logger = logging.getLogger(__name__)


def eval_depth(
    pred: torch.Tensor,
@@ -21,6 +24,8 @@ def eval_depth(
    get_best_scale: bool = True,
    mask_thr: float = 0.5,
    best_scale_clamp_thr: float = 1e-4,
+   use_disparity: bool = False,
+   disparity_eps: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Evaluate the depth error between the prediction `pred` and the ground
@@ -64,6 +69,13 @@ def eval_depth(
        # s.t. we get best possible mse error
        scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr)
        pred = pred * scale_best[:, None, None, None]

+   if use_disparity:
+       gt = torch.div(1.0, (gt + disparity_eps))
+       pred = torch.div(1.0, (pred + disparity_eps))
+       scale_best = estimate_depth_scale_factor(
+           pred, gt, dmask, best_scale_clamp_thr
+       ).detach()
+       pred = pred * scale_best[:, None, None, None]

    df = gt - pred
@@ -117,6 +129,7 @@ def calc_bce(
    pred_eps: float = 0.01,
    mask: Optional[torch.Tensor] = None,
    lerp_bound: Optional[float] = None,
+   pred_logits: bool = False,
) -> torch.Tensor:
    """
    Calculates the binary cross entropy.
@@ -139,9 +152,23 @@ def calc_bce(
        weight = torch.ones_like(gt) * mask

    if lerp_bound is not None:
+       # binary_cross_entropy_lerp requires pred to be in [0, 1]
+       if pred_logits:
+           pred = F.sigmoid(pred)
        return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound)
    else:
-       return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight)
+       if pred_logits:
+           loss = F.binary_cross_entropy_with_logits(
+               pred,
+               gt,
+               reduction="none",
+               weight=weight,
+           )
+       else:
+           loss = F.binary_cross_entropy(pred, gt, reduction="none", weight=weight)
+       return loss.mean()


def binary_cross_entropy_lerp(
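
The `use_disparity` branch added to `eval_depth` compares depths in disparity space, `1 / (depth + eps)`, after re-fitting the best per-batch scale, which emphasizes near-field error over distant regions. A minimal sketch of the conversion (the scale fitting is omitted here):

import torch

def to_disparity(depth: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    return torch.div(1.0, depth + eps)

gt = torch.rand(2, 1, 4, 4) + 0.5
pred = 1.1 * gt  # prediction off by a global scale factor
err = (to_disparity(gt) - to_disparity(pred)).abs()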

View File

@@ -111,10 +111,10 @@ def load_model(fl, map_location: Optional[dict]):
    flstats = get_stats_path(fl)
    flmodel = get_model_path(fl)
    flopt = get_optimizer_path(fl)
-   model_state_dict = torch.load(flmodel, map_location=map_location)
+   model_state_dict = torch.load(flmodel, map_location=map_location, weights_only=True)
    stats = load_stats(flstats)
    if os.path.isfile(flopt):
-       optimizer = torch.load(flopt, map_location=map_location)
+       optimizer = torch.load(flopt, map_location=map_location, weights_only=True)
    else:
        optimizer = None
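
`weights_only=True` restricts `torch.load` to tensors and plain containers instead of arbitrary pickled objects, which is the safe choice for untrusted checkpoints. A self-contained round trip ("checkpoint.pth" is a placeholder path):

import torch

torch.save({"w": torch.zeros(3)}, "checkpoint.pth")
state = torch.load("checkpoint.pth", map_location="cpu", weights_only=True)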

View File

@@ -100,7 +100,6 @@ def render_point_cloud_pytorch3d(
    bin_size: Optional[int] = None,
    **kwargs,
):
-
    # feature dimension
    featdim = point_cloud.features_packed().shape[-1]

View File

@@ -37,7 +37,6 @@ class AverageMeter:
        self.count = 0

    def update(self, val, n=1, epoch=0):
-
        # make sure the history is of the same len as epoch
        while len(self.history) <= epoch:
            self.history.append([])
@@ -115,7 +114,6 @@ class Stats:
visdom_server="http://localhost", visdom_server="http://localhost",
visdom_port=8097, visdom_port=8097,
): ):
self.log_vars = log_vars self.log_vars = log_vars
self.visdom_env = visdom_env self.visdom_env = visdom_env
self.visdom_server = visdom_server self.visdom_server = visdom_server
@@ -202,7 +200,6 @@ class Stats:
            self.log_vars.append(add_log_var)

    def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
-
        if self.epoch == -1:  # uninitialized
            logger.warning(
                "epoch==-1 means uninitialized stats structure -> new_epoch() called"
@@ -219,7 +216,6 @@ class Stats:
        epoch = self.epoch
-
        for stat in self.log_vars:
            if stat not in self.stats[stat_set]:
                self.stats[stat_set][stat] = AverageMeter()
@@ -248,7 +244,6 @@ class Stats:
                self.stats[stat_set][stat].update(val, epoch=epoch, n=1)

    def get_epoch_averages(self, epoch=None):
-
        stat_sets = list(self.stats.keys())

        if epoch is None:
@@ -345,7 +340,6 @@ class Stats:
    def plot_stats(
        self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
    ):
-
        # use the cached visdom env if none supplied
        if visdom_env is None:
            visdom_env = self.visdom_env
@@ -449,7 +443,6 @@ class Stats:
warnings.warn("Cant dump stats due to insufficient permissions!") warnings.warn("Cant dump stats due to insufficient permissions!")
def synchronize_logged_vars(self, log_vars, default_val=float("NaN")): def synchronize_logged_vars(self, log_vars, default_val=float("NaN")):
stat_sets = list(self.stats.keys()) stat_sets = list(self.stats.keys())
# remove the additional log_vars # remove the additional log_vars
@@ -490,11 +483,12 @@ class Stats:
            for ep in range(lastep):
                self.stats[stat_set][stat].update(default_val, n=1, epoch=ep)
            epoch_generated = self.stats[stat_set][stat].get_epoch()
-           assert (
-               epoch_generated == self.epoch + 1
-           ), "bad epoch of synchronized log_var! %d vs %d" % (
-               self.epoch + 1,
-               epoch_generated,
+           assert epoch_generated == self.epoch + 1, (
+               "bad epoch of synchronized log_var! %d vs %d"
+               % (
+                   self.epoch + 1,
+                   epoch_generated,
+               )
            )

View File

@@ -16,8 +16,17 @@ from typing import Optional, Tuple, Union
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
+import torch
from PIL import Image

+_NO_TORCHVISION = False
+try:
+    import torchvision
+except ImportError:
+    _NO_TORCHVISION = True
+
_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")

matplotlib.use("Agg")
@@ -36,6 +45,7 @@ class VideoWriter:
        fps: int = 20,
        output_format: str = "visdom",
        rmdir_allowed: bool = False,
+       use_torchvision_video_writer: bool = False,
        **kwargs,
    ) -> None:
        """
@@ -49,6 +59,8 @@ class VideoWriter:
                is supported.
            rmdir_allowed: If `True` delete and create `cache_dir` in case
                it is not empty.
+           use_torchvision_video_writer: If `True` use `torchvision.io.write_video`
+               to write the video
        """
        self.rmdir_allowed = rmdir_allowed
        self.output_format = output_format
@@ -56,10 +68,14 @@ class VideoWriter:
        self.out_path = out_path
        self.cache_dir = cache_dir
        self.ffmpeg_bin = ffmpeg_bin
+       self.use_torchvision_video_writer = use_torchvision_video_writer
        self.frames = []
        self.regexp = "frame_%08d.png"
        self.frame_num = 0

+       if self.use_torchvision_video_writer:
+           assert not _NO_TORCHVISION, "torchvision not available"

        if self.cache_dir is not None:
            self.tmp_dir = None
            if os.path.isdir(self.cache_dir):
@@ -114,7 +130,7 @@ class VideoWriter:
            resize = im.size
        # make sure size is divisible by 2
        resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
+       # pyre-fixme[16]: Module `Image` has no attribute `ANTIALIAS`.
        im = im.resize(resize, Image.ANTIALIAS)
        im.save(outfile)
@@ -139,38 +155,56 @@ class VideoWriter:
        #  got `Optional[str]`.
        regexp = os.path.join(self.cache_dir, self.regexp)

-       if shutil.which(self.ffmpeg_bin) is None:
-           raise ValueError(
-               f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
-               + "Please set FFMPEG in the environment or ffmpeg_bin on this class."
-           )
        if self.output_format == "visdom":  # works for ppt too
-           args = [
-               self.ffmpeg_bin,
-               "-r",
-               str(self.fps),
-               "-i",
-               regexp,
-               "-vcodec",
-               "h264",
-               "-f",
-               "mp4",
-               "-y",
-               "-crf",
-               "18",
-               "-b",
-               "2000k",
-               "-pix_fmt",
-               "yuv420p",
-               self.out_path,
-           ]
-           if quiet:
-               subprocess.check_call(
-                   args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
-               )
-           else:
-               subprocess.check_call(args)
+           # Video codec parameters
+           video_codec = "h264"
+           crf = "18"
+           b = "2000k"
+           pix_fmt = "yuv420p"
+
+           if self.use_torchvision_video_writer:
+               torchvision.io.write_video(
+                   self.out_path,
+                   torch.stack(
+                       [torch.from_numpy(np.array(Image.open(f))) for f in self.frames]
+                   ),
+                   fps=self.fps,
+                   video_codec=video_codec,
+                   options={"crf": crf, "b": b, "pix_fmt": pix_fmt},
+               )
+           else:
+               if shutil.which(self.ffmpeg_bin) is None:
+                   raise ValueError(
+                       f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
+                       + "Please set FFMPEG in the environment or ffmpeg_bin on this class."
+                   )
+               args = [
+                   self.ffmpeg_bin,
+                   "-r",
+                   str(self.fps),
+                   "-i",
+                   regexp,
+                   "-vcodec",
+                   video_codec,
+                   "-f",
+                   "mp4",
+                   "-y",
+                   "-crf",
+                   crf,
+                   "-b",
+                   b,
+                   "-pix_fmt",
+                   pix_fmt,
+                   self.out_path,
+               ]
+               if quiet:
+                   subprocess.check_call(
+                       args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
+                   )
+               else:
+                   subprocess.check_call(args)
        else:
            raise ValueError("no such output type %s" % str(self.output_format))

View File

@@ -163,6 +163,8 @@ def _read_chunks(
    if binary_data is not None:
        binary_data = np.frombuffer(binary_data, dtype=np.uint8)

+   assert binary_data is not None
    return json_data, binary_data

View File

@@ -7,6 +7,7 @@
# pyre-unsafe

"""This module implements utility functions for loading .mtl files and textures."""
+
import os
import warnings
from typing import Dict, List, Optional, Tuple

View File

@@ -8,6 +8,7 @@
"""This module implements utility functions for loading and saving meshes.""" """This module implements utility functions for loading and saving meshes."""
import os import os
import warnings import warnings
from collections import namedtuple from collections import namedtuple
@@ -813,7 +814,6 @@ def _save(
    save_texture: bool = False,
    save_normals: bool = False,
) -> None:
-
    if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
        message = "'verts' should either be empty or of shape (num_verts, 3)."
        raise ValueError(message)
raise ValueError(message) raise ValueError(message)

View File

@@ -14,6 +14,7 @@ meshes as .off files.
This format is introduced, for example, at
http://www.geomview.org/docs/html/OFF.html .
"""
+
import warnings
from typing import cast, Optional, Tuple, Union
@@ -84,7 +85,7 @@ def _read_faces_lump(
            )
            data = np.loadtxt(file, dtype=np.float32, ndmin=2, max_rows=n_faces)
        except ValueError as e:
-           if n_faces > 1 and "Wrong number of columns" in e.args[0]:
+           if n_faces > 1 and "number of columns" in e.args[0]:
                file.seek(old_offset)
                return None
            raise ValueError("Not enough face data.") from None

View File

@@ -11,6 +11,7 @@
This module implements utility functions for loading and saving
meshes and point clouds as PLY files.
"""
+
import itertools
import os
import struct
@@ -1246,7 +1247,7 @@ def _save_ply(
        return

    color_np_type = np.ubyte if colors_as_uint8 else np.float32
-   verts_dtype = [("verts", np.float32, 3)]
+   verts_dtype: list = [("verts", np.float32, 3)]
    if verts_normals is not None:
        verts_dtype.append(("normals", np.float32, 3))
    if verts_colors is not None:

View File

@@ -122,12 +122,17 @@ def corresponding_cameras_alignment(
    # create a new cameras object and set the R and T accordingly
    cameras_src_aligned = cameras_src.clone()
+   # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `Union[Tensor, Module]`.
    cameras_src_aligned.R = torch.bmm(align_t_R.expand_as(cameras_src.R), cameras_src.R)
    cameras_src_aligned.T = (
        torch.bmm(
            align_t_T[:, None].repeat(cameras_src.R.shape[0], 1, 1),
+           # pyre-fixme[6]: For 2nd argument expected `Tensor` but got
+           #  `Union[Tensor, Module]`.
            cameras_src.R,
        )[:, 0]
+       # pyre-fixme[29]: `Union[(self: TensorBase, other: Union[bool, complex,
+       #  float, int, Tensor]) -> Tensor, Tensor, Module]` is not a function.
        + cameras_src.T * align_t_s
    )
@@ -175,6 +180,7 @@ def _align_camera_extrinsics(
        R_A = (U V^T)^T
    ```
    """
+   # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
    RRcov = torch.bmm(cameras_src.R, cameras_tgt.R.transpose(2, 1)).mean(0)
    U, _, V = torch.svd(RRcov)
    align_t_R = V @ U.t()
@@ -204,7 +210,11 @@ def _align_camera_extrinsics(
        T_A = mean(B) - mean(A) * s_A
    ```
    """
+   # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
+   # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, Any, ...
    A = torch.bmm(cameras_src.R, cameras_src.T[:, :, None])[:, :, 0]
+   # pyre-fixme[6]: For 1st argument expected `Tensor` but got `Union[Tensor, Module]`.
+   # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, Any, ...
    B = torch.bmm(cameras_src.R, cameras_tgt.T[:, :, None])[:, :, 0]
    Amu = A.mean(0, keepdim=True)
    Bmu = B.mean(0, keepdim=True)
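
The `R_A = (U V^T)^T` closed form in the docstring above is the orthogonal-Procrustes solution for the rotation aligning the two camera sets. A small sketch with random orthonormal stand-ins for the camera rotation stacks:

import torch

R_src = torch.linalg.qr(torch.randn(4, 3, 3)).Q
R_tgt = torch.linalg.qr(torch.randn(4, 3, 3)).Q

RRcov = torch.bmm(R_src, R_tgt.transpose(2, 1)).mean(0)
U, _, V = torch.svd(RRcov)
align_t_R = V @ U.t()  # the docstring's (U V^T)^T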

View File

@@ -62,7 +62,7 @@ def cubify(
    *,
    feats: Optional[torch.Tensor] = None,
    device=None,
-   align: str = "topleft"
+   align: str = "topleft",
) -> Meshes:
    r"""
    Converts a voxel to a mesh by replacing each occupied voxel with a cube

Some files were not shown because too many files have changed in this diff.