19 Commits

Author SHA1 Message Date
Jeremy Reizenstein
0eac8299d4 MKL version fix in CI (#1820)
Summary:
Fix for "undefined symbol: iJIT_NotifyEvent" build issue,

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1820

Differential Revision: D58685326
2024-06-20 09:24:07 -07:00
vedrenne
b0462d8079 Allow indexing for classes inheriting Transform3d (#1801)
Summary:
Currently, it is not possible to access a sub-transform using an indexer for all 3d transforms inheriting the `Transforms3d` class.
For instance:

```python
from pytorch3d import transforms

N = 10
r = transforms.random_rotations(N)
T = transforms.Transform3d().rotate(R=r)
R = transforms.Rotate(r)

x = T[0]  # ok
x = R[0]  # TypeError: __init__() got an unexpected keyword argument 'matrix'
```

This is because all these classes (namely `Rotate`, `Translate`, `Scale`, `RotateAxisAngle`) inherit the `__getitem__()` method from `Transform3d` which has the [following code on line 201](https://github.com/facebookresearch/pytorch3d/blob/main/pytorch3d/transforms/transform3d.py#L201):

```python
return self.__class__(matrix=self.get_matrix()[index])
```

The four classes inheriting `Transform3d` are not initialized through a matrix argument, hence they error.
I propose to modify the `__getitem__()` method of the `Transform3d` class to fix this behavior. The least invasive way to do it I can think of consists of creating an empty instance of the current class, then setting the `_matrix` attribute manually. Thus, instead of
```python
return self.__class__(matrix=self.get_matrix()[index])
```
I propose to do:
```python
instance = self.__class__.__new__(self.__class__)
instance._matrix = self.get_matrix()[index]
return instance
```

As far as I can tell, this modification occurs no modification whatsoever for the user, except for the ability to index all 3d transforms.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1801

Reviewed By: MichaelRamamonjisoa

Differential Revision: D58410389

Pulled By: bottler

fbshipit-source-id: f371e4c63d2ae4c927a7ad48c2de8862761078de
2024-06-17 07:48:18 -07:00
Jeremy Reizenstein
b66d17a324 Undo c10=>std optional rename
Summary: Undoes the pytorch3d changes in D57294278 because they break builds for for PyTorch<2.1 .

Reviewed By: MichaelRamamonjisoa

Differential Revision: D57379779

fbshipit-source-id: 47a12511abcec4c3f4e2f62eff5ba99deb2fab4c
2024-06-17 07:09:30 -07:00
Kyle Vedder
717493cb79 Fixed last dimension size check so that it doesn't trivially pass. (#1815)
Summary:
Currently, it checks that the `2`th dimension of `p2` is the same size as the `2`th dimension of `p2` instead of `p1`.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1815

Reviewed By: MichaelRamamonjisoa

Differential Revision: D58586966

Pulled By: bottler

fbshipit-source-id: d4f723fa264f90fe368c10825c1acdfdc4c406dc
2024-06-17 06:00:13 -07:00
Jeremy Reizenstein
302da69461 builds for PyTorch 2.2.1 2.2.2 2.3.0 2.3.1
Summary: Build for new pytorch versions

Reviewed By: MichaelRamamonjisoa

Differential Revision: D58668956

fbshipit-source-id: 7fdfb377b370448d6147daded6a21b8db87586fb
2024-06-17 05:57:59 -07:00
Roman Shapovalov
4ae25bfce7 Moving ray bundle to float dtype
Summary: We can now move ray bundle to float dtype (e.g. from fp16 like types).

Reviewed By: bottler

Differential Revision: D57493109

fbshipit-source-id: 4e18a427e968b646fe5feafbff653811cd007981
2024-05-30 10:06:38 -07:00
Richard Barnes
bd52f4a408 c10::optional -> std::optional in tensorboard/adhoc/Adhoc.h +9
Summary: `c10::optional` was switched to be `std::optional` after PyTorch moved to C++17. Let's eliminate `c10::optional`, if we can.

Reviewed By: albanD

Differential Revision: D57294278

fbshipit-source-id: f6f26133c43f8d92a4588f59df7d689e7909a0cd
2024-05-13 16:40:34 -07:00
generatedunixname89002005307016
17117106e4 upgrade pyre version in fbcode/vision - batch 2
Differential Revision: D57183103

fbshipit-source-id: 7e2f42ddc6a1fa02abc27a451987d67a00264cbb
2024-05-10 01:18:43 -07:00
Richard Barnes
aec76bb4c8 Remove unused-but-set variables in vision/fair/pytorch3d/pytorch3d/csrc/pulsar/include/renderer.render.device.h +1
Summary:
This diff removes a variable that was set, but which was not used.

LLVM-15 has a warning `-Wunused-but-set-variable` which we treat as an error because it's so often diagnostic of a code issue. Unused but set variables often indicate a programming mistake, but can also just be unnecessary cruft that harms readability and performance.

Removing this variable will not change how your code works, but the unused variable may indicate your code isn't working the way you thought it was. I've gone through each of these by hand, but mistakes may have slipped through. If you feel the diff needs changes before landing, **please commandeer** and make appropriate changes: there are hundreds of these and responding to them individually is challenging.

For questions/comments, contact r-barnes.

 - If you approve of this diff, please use the "Accept & Ship" button :-)

Reviewed By: bottler

Differential Revision: D56886956

fbshipit-source-id: 0c515ed98b812b1c106a59e19ec90751ce32e8c0
2024-05-02 13:58:05 -07:00
Andres Suarez
47d5dc8824 Apply clang-format 18
Summary: Previously this code conformed from clang-format 12.

Reviewed By: igorsugak

Differential Revision: D56065247

fbshipit-source-id: f5a985dd8f8b84f2f9e1818b3719b43c5a1b05b3
2024-04-14 11:28:32 -07:00
generatedunixname89002005307016
fe0b1bae49 upgrade pyre version in fbcode/vision - batch 2
Differential Revision: D55650177

fbshipit-source-id: d5faa4d805bb40fe3dea70b0601e7a1382b09f3a
2024-04-02 18:11:50 -07:00
Ruishen Lyu
ccf22911d4 Optimize list_to_packed to avoid for loop (#1737)
Summary:
For larger N and Mi value (e.g. N=154, Mi=238) I notice list_to_packed() has become a bottleneck for my application. By removing the for loop and running on GPU, i see a 10-20 x speedup.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1737

Reviewed By: MichaelRamamonjisoa

Differential Revision: D54187993

Pulled By: bottler

fbshipit-source-id: 16399a24cb63b48c30460c7d960abef603b115d0
2024-04-02 07:50:25 -07:00
Ashim Dahal
128be02fc0 feat: adjusted sample_nums (#1768)
Summary:
adjusted sample_nums to match the number of columns in the image grid. It originally produced image grid with 5 axes but only 3 images and after this fix, the block would work as intended.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1768

Reviewed By: MichaelRamamonjisoa

Differential Revision: D55632872

Pulled By: bottler

fbshipit-source-id: 44d633a8068076889e49d49b8a7910dba0db37a7
2024-04-02 06:02:48 -07:00
Roeia Kishk
31e3488a51 Changed tutorials' pip searching
Summary:
### Generalise tutorials' pip searching:
## Required Information:
This diff contains changes to several PyTorch3D tutorials.

**Purpose of this diff:**
Replace the current installation code with a more streamlined approach that tries to install the wheel first and falls back to installing from source if the wheel is not found.

**Why this diff is required:**
This diff makes it easier to cope with new PyTorch releases and reduce the need for manual intervention, as the current process involves checking the version of PyTorch in Colab and building a new wheel if it doesn't match the expected version, which generates additional work each time there is a a new PyTorch version in Colab.

**Changes:**
Before:
```
    if torch.__version__.startswith("2.1.") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{pyt_version_str}"
        ])
        !pip install fvcore iopath
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
```
After:
```
    pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
    version_str="".join([
        f"py3{sys.version_info.minor}_cu",
        torch.version.cuda.replace(".",""),
        f"_pyt{pyt_version_str}"
    ])
    !pip install fvcore iopath
    if sys.platform.startswith("linux"):
      # We try to install PyTorch3D via a released wheel.
      !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
      pip_list = !pip freeze
      need_pytorch3d = not any(i.startswith("pytorch3d==") for  i in pip_list)

    if need_pytorch3d:
        # We try to install PyTorch3D from source.
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
```

Reviewed By: bottler

Differential Revision: D55431832

fbshipit-source-id: a8de9162470698320241ae8401427dcb1ce17c37
2024-03-28 11:24:43 -07:00
generatedunixname89002005307016
b215776f2d upgrade pyre version in fbcode/vision - batch 2
Differential Revision: D55395614

fbshipit-source-id: 71677892b5d6f219f6df25b4efb51fb0f6b1441b
2024-03-26 22:02:22 -07:00
Cijo Jose
38cf0dc1c5 TexturesUV multiple maps
Summary: Implements the  the TexturesUV with multiple map ids.

Reviewed By: bottler

Differential Revision: D53944063

fbshipit-source-id: 06c25eb6d69f72db0484f16566dd2ca32a560b82
2024-03-12 06:59:31 -07:00
Jaap Suter
7566530669 CUDA marching_cubes fix
Summary:
Fix an inclusive vs exclusive scan mix-up that was accidentally introduced when removing the Thrust dependency (`Thrust::exclusive_scan`) and reimplementing it using `at::cumsum` (which does an inclusive scan).

This fixes two Github reported issues:

 * https://github.com/facebookresearch/pytorch3d/issues/1731
 * https://github.com/facebookresearch/pytorch3d/issues/1751

Reviewed By: bottler

Differential Revision: D54605545

fbshipit-source-id: da9e92f3f8a9a35f7b7191428d0b9a9ca03e0d4d
2024-03-07 15:38:24 -08:00
Conner Nilsen
a27755db41 Pyre Configurationless migration for] [batch:85/112] [shard:6/N]
Reviewed By: inseokhwang

Differential Revision: D54438157

fbshipit-source-id: a6acfe146ed29fff82123b5e458906d4b4cee6a2
2024-03-04 18:30:37 -08:00
Amethyst Reese
3da7703c5a apply Black 2024 style in fbcode (4/16)
Summary:
Formats the covered files with pyfmt.

paintitblack

Reviewed By: aleivag

Differential Revision: D54447727

fbshipit-source-id: 8844b1caa08de94d04ac4df3c768dbf8c865fd2f
2024-03-02 17:31:19 -08:00
242 changed files with 1578 additions and 378 deletions

View File

@@ -302,6 +302,34 @@ workflows:
name: linux_conda_py38_cu121_pyt220 name: linux_conda_py38_cu121_pyt220
python_version: '3.8' python_version: '3.8'
pytorch_version: 2.2.0 pytorch_version: 2.2.0
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN
cu_version: cu118
name: linux_conda_py38_cu118_pyt222
python_version: '3.8'
pytorch_version: 2.2.2
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda121
context: DOCKERHUB_TOKEN
cu_version: cu121
name: linux_conda_py38_cu121_pyt222
python_version: '3.8'
pytorch_version: 2.2.2
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN
cu_version: cu118
name: linux_conda_py38_cu118_pyt231
python_version: '3.8'
pytorch_version: 2.3.1
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda121
context: DOCKERHUB_TOKEN
cu_version: cu121
name: linux_conda_py38_cu121_pyt231
python_version: '3.8'
pytorch_version: 2.3.1
- binary_linux_conda: - binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda113 conda_docker_image: pytorch/conda-builder:cuda113
context: DOCKERHUB_TOKEN context: DOCKERHUB_TOKEN
@@ -442,6 +470,34 @@ workflows:
name: linux_conda_py39_cu121_pyt220 name: linux_conda_py39_cu121_pyt220
python_version: '3.9' python_version: '3.9'
pytorch_version: 2.2.0 pytorch_version: 2.2.0
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN
cu_version: cu118
name: linux_conda_py39_cu118_pyt222
python_version: '3.9'
pytorch_version: 2.2.2
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda121
context: DOCKERHUB_TOKEN
cu_version: cu121
name: linux_conda_py39_cu121_pyt222
python_version: '3.9'
pytorch_version: 2.2.2
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN
cu_version: cu118
name: linux_conda_py39_cu118_pyt231
python_version: '3.9'
pytorch_version: 2.3.1
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda121
context: DOCKERHUB_TOKEN
cu_version: cu121
name: linux_conda_py39_cu121_pyt231
python_version: '3.9'
pytorch_version: 2.3.1
- binary_linux_conda: - binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda113 conda_docker_image: pytorch/conda-builder:cuda113
context: DOCKERHUB_TOKEN context: DOCKERHUB_TOKEN
@@ -582,6 +638,34 @@ workflows:
name: linux_conda_py310_cu121_pyt220 name: linux_conda_py310_cu121_pyt220
python_version: '3.10' python_version: '3.10'
pytorch_version: 2.2.0 pytorch_version: 2.2.0
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN
cu_version: cu118
name: linux_conda_py310_cu118_pyt222
python_version: '3.10'
pytorch_version: 2.2.2
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda121
context: DOCKERHUB_TOKEN
cu_version: cu121
name: linux_conda_py310_cu121_pyt222
python_version: '3.10'
pytorch_version: 2.2.2
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN
cu_version: cu118
name: linux_conda_py310_cu118_pyt231
python_version: '3.10'
pytorch_version: 2.3.1
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda121
context: DOCKERHUB_TOKEN
cu_version: cu121
name: linux_conda_py310_cu121_pyt231
python_version: '3.10'
pytorch_version: 2.3.1
- binary_linux_conda: - binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118 conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN context: DOCKERHUB_TOKEN
@@ -638,6 +722,34 @@ workflows:
name: linux_conda_py311_cu121_pyt220 name: linux_conda_py311_cu121_pyt220
python_version: '3.11' python_version: '3.11'
pytorch_version: 2.2.0 pytorch_version: 2.2.0
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN
cu_version: cu118
name: linux_conda_py311_cu118_pyt222
python_version: '3.11'
pytorch_version: 2.2.2
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda121
context: DOCKERHUB_TOKEN
cu_version: cu121
name: linux_conda_py311_cu121_pyt222
python_version: '3.11'
pytorch_version: 2.2.2
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN
cu_version: cu118
name: linux_conda_py311_cu118_pyt231
python_version: '3.11'
pytorch_version: 2.3.1
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda121
context: DOCKERHUB_TOKEN
cu_version: cu121
name: linux_conda_py311_cu121_pyt231
python_version: '3.11'
pytorch_version: 2.3.1
- binary_linux_conda: - binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118 conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN context: DOCKERHUB_TOKEN
@@ -652,6 +764,34 @@ workflows:
name: linux_conda_py312_cu121_pyt220 name: linux_conda_py312_cu121_pyt220
python_version: '3.12' python_version: '3.12'
pytorch_version: 2.2.0 pytorch_version: 2.2.0
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN
cu_version: cu118
name: linux_conda_py312_cu118_pyt222
python_version: '3.12'
pytorch_version: 2.2.2
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda121
context: DOCKERHUB_TOKEN
cu_version: cu121
name: linux_conda_py312_cu121_pyt222
python_version: '3.12'
pytorch_version: 2.2.2
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda118
context: DOCKERHUB_TOKEN
cu_version: cu118
name: linux_conda_py312_cu118_pyt231
python_version: '3.12'
pytorch_version: 2.3.1
- binary_linux_conda:
conda_docker_image: pytorch/conda-builder:cuda121
context: DOCKERHUB_TOKEN
cu_version: cu121
name: linux_conda_py312_cu121_pyt231
python_version: '3.12'
pytorch_version: 2.3.1
- binary_linux_conda_cuda: - binary_linux_conda_cuda:
name: testrun_conda_cuda_py310_cu117_pyt201 name: testrun_conda_cuda_py310_cu117_pyt201
context: DOCKERHUB_TOKEN context: DOCKERHUB_TOKEN

View File

@@ -29,6 +29,8 @@ CONDA_CUDA_VERSIONS = {
"2.1.1": ["cu118", "cu121"], "2.1.1": ["cu118", "cu121"],
"2.1.2": ["cu118", "cu121"], "2.1.2": ["cu118", "cu121"],
"2.2.0": ["cu118", "cu121"], "2.2.0": ["cu118", "cu121"],
"2.2.2": ["cu118", "cu121"],
"2.3.1": ["cu118", "cu121"],
} }

View File

@@ -9,7 +9,7 @@ The core library is written in PyTorch. Several components have underlying imple
- Linux or macOS or Windows - Linux or macOS or Windows
- Python 3.8, 3.9 or 3.10 - Python 3.8, 3.9 or 3.10
- PyTorch 1.12.0, 1.12.1, 1.13.0, 2.0.0, 2.0.1, 2.1.0, 2.1.1, 2.1.2 or 2.2.0. - PyTorch 1.12.0, 1.12.1, 1.13.0, 2.0.0, 2.0.1, 2.1.0, 2.1.1, 2.1.2, 2.2.0, 2.2.1, 2.2.2, 2.3.0 or 2.3.1.
- torchvision that matches the PyTorch installation. You can install them together as explained at pytorch.org to make sure of this. - torchvision that matches the PyTorch installation. You can install them together as explained at pytorch.org to make sure of this.
- gcc & g++ ≥ 4.9 - gcc & g++ ≥ 4.9
- [fvcore](https://github.com/facebookresearch/fvcore) - [fvcore](https://github.com/facebookresearch/fvcore)

View File

@@ -83,25 +83,31 @@
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"import torch\n", "import torch\n",
"import subprocess\n",
"need_pytorch3d=False\n", "need_pytorch3d=False\n",
"try:\n", "try:\n",
" import pytorch3d\n", " import pytorch3d\n",
"except ModuleNotFoundError:\n", "except ModuleNotFoundError:\n",
" need_pytorch3d=True\n", " need_pytorch3d=True\n",
"if need_pytorch3d:\n", "if need_pytorch3d:\n",
" if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n", " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
" # We try to install PyTorch3D via a released wheel.\n", " version_str=\"\".join([\n",
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", " f\"py3{sys.version_info.minor}_cu\",\n",
" version_str=\"\".join([\n", " torch.version.cuda.replace(\".\",\"\"),\n",
" f\"py3{sys.version_info.minor}_cu\",\n", " f\"_pyt{pyt_version_str}\"\n",
" torch.version.cuda.replace(\".\",\"\"),\n", " ])\n",
" f\"_pyt{pyt_version_str}\"\n", " !pip install fvcore iopath\n",
" ])\n", " if sys.platform.startswith(\"linux\"):\n",
" !pip install fvcore iopath\n", " print(\"Trying to install wheel for PyTorch3D\")\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
" else:\n", " pip_list = !pip freeze\n",
" # We try to install PyTorch3D from source.\n", " need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" " if need_pytorch3d:\n",
" print(f\"failed to find/install wheel for {version_str}\")\n",
"if need_pytorch3d:\n",
" print(\"Installing PyTorch3D from source\")\n",
" !pip install ninja\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
] ]
}, },
{ {

View File

@@ -70,25 +70,31 @@
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"import torch\n", "import torch\n",
"import subprocess\n",
"need_pytorch3d=False\n", "need_pytorch3d=False\n",
"try:\n", "try:\n",
" import pytorch3d\n", " import pytorch3d\n",
"except ModuleNotFoundError:\n", "except ModuleNotFoundError:\n",
" need_pytorch3d=True\n", " need_pytorch3d=True\n",
"if need_pytorch3d:\n", "if need_pytorch3d:\n",
" if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n", " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
" # We try to install PyTorch3D via a released wheel.\n", " version_str=\"\".join([\n",
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", " f\"py3{sys.version_info.minor}_cu\",\n",
" version_str=\"\".join([\n", " torch.version.cuda.replace(\".\",\"\"),\n",
" f\"py3{sys.version_info.minor}_cu\",\n", " f\"_pyt{pyt_version_str}\"\n",
" torch.version.cuda.replace(\".\",\"\"),\n", " ])\n",
" f\"_pyt{pyt_version_str}\"\n", " !pip install fvcore iopath\n",
" ])\n", " if sys.platform.startswith(\"linux\"):\n",
" !pip install fvcore iopath\n", " print(\"Trying to install wheel for PyTorch3D\")\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
" else:\n", " pip_list = !pip freeze\n",
" # We try to install PyTorch3D from source.\n", " need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" " if need_pytorch3d:\n",
" print(f\"failed to find/install wheel for {version_str}\")\n",
"if need_pytorch3d:\n",
" print(\"Installing PyTorch3D from source\")\n",
" !pip install ninja\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
] ]
}, },
{ {

View File

@@ -45,25 +45,31 @@
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"import torch\n", "import torch\n",
"import subprocess\n",
"need_pytorch3d=False\n", "need_pytorch3d=False\n",
"try:\n", "try:\n",
" import pytorch3d\n", " import pytorch3d\n",
"except ModuleNotFoundError:\n", "except ModuleNotFoundError:\n",
" need_pytorch3d=True\n", " need_pytorch3d=True\n",
"if need_pytorch3d:\n", "if need_pytorch3d:\n",
" if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n", " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
" # We try to install PyTorch3D via a released wheel.\n", " version_str=\"\".join([\n",
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", " f\"py3{sys.version_info.minor}_cu\",\n",
" version_str=\"\".join([\n", " torch.version.cuda.replace(\".\",\"\"),\n",
" f\"py3{sys.version_info.minor}_cu\",\n", " f\"_pyt{pyt_version_str}\"\n",
" torch.version.cuda.replace(\".\",\"\"),\n", " ])\n",
" f\"_pyt{pyt_version_str}\"\n", " !pip install fvcore iopath\n",
" ])\n", " if sys.platform.startswith(\"linux\"):\n",
" !pip install fvcore iopath\n", " print(\"Trying to install wheel for PyTorch3D\")\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
" else:\n", " pip_list = !pip freeze\n",
" # We try to install PyTorch3D from source.\n", " need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" " if need_pytorch3d:\n",
" print(f\"failed to find/install wheel for {version_str}\")\n",
"if need_pytorch3d:\n",
" print(\"Installing PyTorch3D from source\")\n",
" !pip install ninja\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
] ]
}, },
{ {
@@ -405,7 +411,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"random_model_images = shapenet_dataset.render(\n", "random_model_images = shapenet_dataset.render(\n",
" sample_nums=[3],\n", " sample_nums=[5],\n",
" device=device,\n", " device=device,\n",
" cameras=cameras,\n", " cameras=cameras,\n",
" raster_settings=raster_settings,\n", " raster_settings=raster_settings,\n",

View File

@@ -84,25 +84,31 @@
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"import torch\n", "import torch\n",
"import subprocess\n",
"need_pytorch3d=False\n", "need_pytorch3d=False\n",
"try:\n", "try:\n",
" import pytorch3d\n", " import pytorch3d\n",
"except ModuleNotFoundError:\n", "except ModuleNotFoundError:\n",
" need_pytorch3d=True\n", " need_pytorch3d=True\n",
"if need_pytorch3d:\n", "if need_pytorch3d:\n",
" if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n", " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
" # We try to install PyTorch3D via a released wheel.\n", " version_str=\"\".join([\n",
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", " f\"py3{sys.version_info.minor}_cu\",\n",
" version_str=\"\".join([\n", " torch.version.cuda.replace(\".\",\"\"),\n",
" f\"py3{sys.version_info.minor}_cu\",\n", " f\"_pyt{pyt_version_str}\"\n",
" torch.version.cuda.replace(\".\",\"\"),\n", " ])\n",
" f\"_pyt{pyt_version_str}\"\n", " !pip install fvcore iopath\n",
" ])\n", " if sys.platform.startswith(\"linux\"):\n",
" !pip install fvcore iopath\n", " print(\"Trying to install wheel for PyTorch3D\")\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
" else:\n", " pip_list = !pip freeze\n",
" # We try to install PyTorch3D from source.\n", " need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" " if need_pytorch3d:\n",
" print(f\"failed to find/install wheel for {version_str}\")\n",
"if need_pytorch3d:\n",
" print(\"Installing PyTorch3D from source\")\n",
" !pip install ninja\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
] ]
}, },
{ {

View File

@@ -50,25 +50,31 @@
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"import torch\n", "import torch\n",
"import subprocess\n",
"need_pytorch3d=False\n", "need_pytorch3d=False\n",
"try:\n", "try:\n",
" import pytorch3d\n", " import pytorch3d\n",
"except ModuleNotFoundError:\n", "except ModuleNotFoundError:\n",
" need_pytorch3d=True\n", " need_pytorch3d=True\n",
"if need_pytorch3d:\n", "if need_pytorch3d:\n",
" if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n", " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
" # We try to install PyTorch3D via a released wheel.\n", " version_str=\"\".join([\n",
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", " f\"py3{sys.version_info.minor}_cu\",\n",
" version_str=\"\".join([\n", " torch.version.cuda.replace(\".\",\"\"),\n",
" f\"py3{sys.version_info.minor}_cu\",\n", " f\"_pyt{pyt_version_str}\"\n",
" torch.version.cuda.replace(\".\",\"\"),\n", " ])\n",
" f\"_pyt{pyt_version_str}\"\n", " !pip install fvcore iopath\n",
" ])\n", " if sys.platform.startswith(\"linux\"):\n",
" !pip install fvcore iopath\n", " print(\"Trying to install wheel for PyTorch3D\")\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
" else:\n", " pip_list = !pip freeze\n",
" # We try to install PyTorch3D from source.\n", " need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" " if need_pytorch3d:\n",
" print(f\"failed to find/install wheel for {version_str}\")\n",
"if need_pytorch3d:\n",
" print(\"Installing PyTorch3D from source\")\n",
" !pip install ninja\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
] ]
}, },
{ {

View File

@@ -62,25 +62,31 @@
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"import torch\n", "import torch\n",
"import subprocess\n",
"need_pytorch3d=False\n", "need_pytorch3d=False\n",
"try:\n", "try:\n",
" import pytorch3d\n", " import pytorch3d\n",
"except ModuleNotFoundError:\n", "except ModuleNotFoundError:\n",
" need_pytorch3d=True\n", " need_pytorch3d=True\n",
"if need_pytorch3d:\n", "if need_pytorch3d:\n",
" if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n", " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
" # We try to install PyTorch3D via a released wheel.\n", " version_str=\"\".join([\n",
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", " f\"py3{sys.version_info.minor}_cu\",\n",
" version_str=\"\".join([\n", " torch.version.cuda.replace(\".\",\"\"),\n",
" f\"py3{sys.version_info.minor}_cu\",\n", " f\"_pyt{pyt_version_str}\"\n",
" torch.version.cuda.replace(\".\",\"\"),\n", " ])\n",
" f\"_pyt{pyt_version_str}\"\n", " !pip install fvcore iopath\n",
" ])\n", " if sys.platform.startswith(\"linux\"):\n",
" !pip install fvcore iopath\n", " print(\"Trying to install wheel for PyTorch3D\")\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
" else:\n", " pip_list = !pip freeze\n",
" # We try to install PyTorch3D from source.\n", " need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" " if need_pytorch3d:\n",
" print(f\"failed to find/install wheel for {version_str}\")\n",
"if need_pytorch3d:\n",
" print(\"Installing PyTorch3D from source\")\n",
" !pip install ninja\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
] ]
}, },
{ {

View File

@@ -41,25 +41,31 @@
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"import torch\n", "import torch\n",
"import subprocess\n",
"need_pytorch3d=False\n", "need_pytorch3d=False\n",
"try:\n", "try:\n",
" import pytorch3d\n", " import pytorch3d\n",
"except ModuleNotFoundError:\n", "except ModuleNotFoundError:\n",
" need_pytorch3d=True\n", " need_pytorch3d=True\n",
"if need_pytorch3d:\n", "if need_pytorch3d:\n",
" if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n", " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
" # We try to install PyTorch3D via a released wheel.\n", " version_str=\"\".join([\n",
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", " f\"py3{sys.version_info.minor}_cu\",\n",
" version_str=\"\".join([\n", " torch.version.cuda.replace(\".\",\"\"),\n",
" f\"py3{sys.version_info.minor}_cu\",\n", " f\"_pyt{pyt_version_str}\"\n",
" torch.version.cuda.replace(\".\",\"\"),\n", " ])\n",
" f\"_pyt{pyt_version_str}\"\n", " !pip install fvcore iopath\n",
" ])\n", " if sys.platform.startswith(\"linux\"):\n",
" !pip install fvcore iopath\n", " print(\"Trying to install wheel for PyTorch3D\")\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
" else:\n", " pip_list = !pip freeze\n",
" # We try to install PyTorch3D from source.\n", " need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" " if need_pytorch3d:\n",
" print(f\"failed to find/install wheel for {version_str}\")\n",
"if need_pytorch3d:\n",
" print(\"Installing PyTorch3D from source\")\n",
" !pip install ninja\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
] ]
}, },
{ {

View File

@@ -72,25 +72,31 @@
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"import torch\n", "import torch\n",
"import subprocess\n",
"need_pytorch3d=False\n", "need_pytorch3d=False\n",
"try:\n", "try:\n",
" import pytorch3d\n", " import pytorch3d\n",
"except ModuleNotFoundError:\n", "except ModuleNotFoundError:\n",
" need_pytorch3d=True\n", " need_pytorch3d=True\n",
"if need_pytorch3d:\n", "if need_pytorch3d:\n",
" if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n", " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
" # We try to install PyTorch3D via a released wheel.\n", " version_str=\"\".join([\n",
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", " f\"py3{sys.version_info.minor}_cu\",\n",
" version_str=\"\".join([\n", " torch.version.cuda.replace(\".\",\"\"),\n",
" f\"py3{sys.version_info.minor}_cu\",\n", " f\"_pyt{pyt_version_str}\"\n",
" torch.version.cuda.replace(\".\",\"\"),\n", " ])\n",
" f\"_pyt{pyt_version_str}\"\n", " !pip install fvcore iopath\n",
" ])\n", " if sys.platform.startswith(\"linux\"):\n",
" !pip install fvcore iopath\n", " print(\"Trying to install wheel for PyTorch3D\")\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
" else:\n", " pip_list = !pip freeze\n",
" # We try to install PyTorch3D from source.\n", " need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" " if need_pytorch3d:\n",
" print(f\"failed to find/install wheel for {version_str}\")\n",
"if need_pytorch3d:\n",
" print(\"Installing PyTorch3D from source\")\n",
" !pip install ninja\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
] ]
}, },
{ {

View File

@@ -66,25 +66,31 @@
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"import torch\n", "import torch\n",
"import subprocess\n",
"need_pytorch3d=False\n", "need_pytorch3d=False\n",
"try:\n", "try:\n",
" import pytorch3d\n", " import pytorch3d\n",
"except ModuleNotFoundError:\n", "except ModuleNotFoundError:\n",
" need_pytorch3d=True\n", " need_pytorch3d=True\n",
"if need_pytorch3d:\n", "if need_pytorch3d:\n",
" if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n", " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
" # We try to install PyTorch3D via a released wheel.\n", " version_str=\"\".join([\n",
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", " f\"py3{sys.version_info.minor}_cu\",\n",
" version_str=\"\".join([\n", " torch.version.cuda.replace(\".\",\"\"),\n",
" f\"py3{sys.version_info.minor}_cu\",\n", " f\"_pyt{pyt_version_str}\"\n",
" torch.version.cuda.replace(\".\",\"\"),\n", " ])\n",
" f\"_pyt{pyt_version_str}\"\n", " !pip install fvcore iopath\n",
" ])\n", " if sys.platform.startswith(\"linux\"):\n",
" !pip install fvcore iopath\n", " print(\"Trying to install wheel for PyTorch3D\")\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
" else:\n", " pip_list = !pip freeze\n",
" # We try to install PyTorch3D from source.\n", " need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" " if need_pytorch3d:\n",
" print(f\"failed to find/install wheel for {version_str}\")\n",
"if need_pytorch3d:\n",
" print(\"Installing PyTorch3D from source\")\n",
" !pip install ninja\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
] ]
}, },
{ {

View File

@@ -44,25 +44,31 @@
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"import torch\n", "import torch\n",
"import subprocess\n",
"need_pytorch3d=False\n", "need_pytorch3d=False\n",
"try:\n", "try:\n",
" import pytorch3d\n", " import pytorch3d\n",
"except ModuleNotFoundError:\n", "except ModuleNotFoundError:\n",
" need_pytorch3d=True\n", " need_pytorch3d=True\n",
"if need_pytorch3d:\n", "if need_pytorch3d:\n",
" if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n", " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
" # We try to install PyTorch3D via a released wheel.\n", " version_str=\"\".join([\n",
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", " f\"py3{sys.version_info.minor}_cu\",\n",
" version_str=\"\".join([\n", " torch.version.cuda.replace(\".\",\"\"),\n",
" f\"py3{sys.version_info.minor}_cu\",\n", " f\"_pyt{pyt_version_str}\"\n",
" torch.version.cuda.replace(\".\",\"\"),\n", " ])\n",
" f\"_pyt{pyt_version_str}\"\n", " !pip install fvcore iopath\n",
" ])\n", " if sys.platform.startswith(\"linux\"):\n",
" !pip install fvcore iopath\n", " print(\"Trying to install wheel for PyTorch3D\")\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
" else:\n", " pip_list = !pip freeze\n",
" # We try to install PyTorch3D from source.\n", " need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" " if need_pytorch3d:\n",
" print(f\"failed to find/install wheel for {version_str}\")\n",
"if need_pytorch3d:\n",
" print(\"Installing PyTorch3D from source\")\n",
" !pip install ninja\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
] ]
}, },
{ {

View File

@@ -51,25 +51,31 @@
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"import torch\n", "import torch\n",
"import subprocess\n",
"need_pytorch3d=False\n", "need_pytorch3d=False\n",
"try:\n", "try:\n",
" import pytorch3d\n", " import pytorch3d\n",
"except ModuleNotFoundError:\n", "except ModuleNotFoundError:\n",
" need_pytorch3d=True\n", " need_pytorch3d=True\n",
"if need_pytorch3d:\n", "if need_pytorch3d:\n",
" if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n", " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
" # We try to install PyTorch3D via a released wheel.\n", " version_str=\"\".join([\n",
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", " f\"py3{sys.version_info.minor}_cu\",\n",
" version_str=\"\".join([\n", " torch.version.cuda.replace(\".\",\"\"),\n",
" f\"py3{sys.version_info.minor}_cu\",\n", " f\"_pyt{pyt_version_str}\"\n",
" torch.version.cuda.replace(\".\",\"\"),\n", " ])\n",
" f\"_pyt{pyt_version_str}\"\n", " !pip install fvcore iopath\n",
" ])\n", " if sys.platform.startswith(\"linux\"):\n",
" !pip install fvcore iopath\n", " print(\"Trying to install wheel for PyTorch3D\")\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
" else:\n", " pip_list = !pip freeze\n",
" # We try to install PyTorch3D from source.\n", " need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" " if need_pytorch3d:\n",
" print(f\"failed to find/install wheel for {version_str}\")\n",
"if need_pytorch3d:\n",
" print(\"Installing PyTorch3D from source\")\n",
" !pip install ninja\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
] ]
}, },
{ {

View File

@@ -67,25 +67,31 @@
"import os\n", "import os\n",
"import sys\n", "import sys\n",
"import torch\n", "import torch\n",
"import subprocess\n",
"need_pytorch3d=False\n", "need_pytorch3d=False\n",
"try:\n", "try:\n",
" import pytorch3d\n", " import pytorch3d\n",
"except ModuleNotFoundError:\n", "except ModuleNotFoundError:\n",
" need_pytorch3d=True\n", " need_pytorch3d=True\n",
"if need_pytorch3d:\n", "if need_pytorch3d:\n",
" if torch.__version__.startswith(\"2.2.\") and sys.platform.startswith(\"linux\"):\n", " pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
" # We try to install PyTorch3D via a released wheel.\n", " version_str=\"\".join([\n",
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n", " f\"py3{sys.version_info.minor}_cu\",\n",
" version_str=\"\".join([\n", " torch.version.cuda.replace(\".\",\"\"),\n",
" f\"py3{sys.version_info.minor}_cu\",\n", " f\"_pyt{pyt_version_str}\"\n",
" torch.version.cuda.replace(\".\",\"\"),\n", " ])\n",
" f\"_pyt{pyt_version_str}\"\n", " !pip install fvcore iopath\n",
" ])\n", " if sys.platform.startswith(\"linux\"):\n",
" !pip install fvcore iopath\n", " print(\"Trying to install wheel for PyTorch3D\")\n",
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n", " !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
" else:\n", " pip_list = !pip freeze\n",
" # We try to install PyTorch3D from source.\n", " need_pytorch3d = not any(i.startswith(\"pytorch3d==\") for i in pip_list)\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'" " if need_pytorch3d:\n",
" print(f\"failed to find/install wheel for {version_str}\")\n",
"if need_pytorch3d:\n",
" print(\"Installing PyTorch3D from source\")\n",
" !pip install ninja\n",
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
] ]
}, },
{ {

View File

@@ -80,6 +80,12 @@ def setup_cuda():
def setup_conda_pytorch_constraint() -> List[str]: def setup_conda_pytorch_constraint() -> List[str]:
pytorch_constraint = f"- pytorch=={PYTORCH_VERSION}" pytorch_constraint = f"- pytorch=={PYTORCH_VERSION}"
os.environ["CONDA_PYTORCH_CONSTRAINT"] = pytorch_constraint os.environ["CONDA_PYTORCH_CONSTRAINT"] = pytorch_constraint
if pytorch_major_minor < (2, 2):
os.environ["CONDA_PYTORCH_MKL_CONSTRAINT"] = "- mkl!=2024.1.0"
os.environ["SETUPTOOLS_CONSTRAINT"] = "- setuptools<70"
else:
os.environ["CONDA_PYTORCH_MKL_CONSTRAINT"] = ""
os.environ["SETUPTOOLS_CONSTRAINT"] = "- setuptools"
os.environ["CONDA_PYTORCH_BUILD_CONSTRAINT"] = pytorch_constraint os.environ["CONDA_PYTORCH_BUILD_CONSTRAINT"] = pytorch_constraint
os.environ["PYTORCH_VERSION_NODOT"] = PYTORCH_VERSION.replace(".", "") os.environ["PYTORCH_VERSION_NODOT"] = PYTORCH_VERSION.replace(".", "")

View File

@@ -12,8 +12,9 @@ requirements:
host: host:
- python - python
- setuptools {{ environ.get('SETUPTOOLS_CONSTRAINT') }}
{{ environ.get('CONDA_PYTORCH_BUILD_CONSTRAINT') }} {{ environ.get('CONDA_PYTORCH_BUILD_CONSTRAINT') }}
{{ environ.get('CONDA_PYTORCH_MKL_CONSTRAINT') }}
{{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }} {{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}
{{ environ.get('CONDA_CPUONLY_FEATURE') }} {{ environ.get('CONDA_CPUONLY_FEATURE') }}

View File

@@ -3,3 +3,5 @@
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe

View File

@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
"""" """"
This file is the entry point for launching experiments with Implicitron. This file is the entry point for launching experiments with Implicitron.

View File

@@ -3,3 +3,5 @@
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import logging import logging
import os import os
from typing import Optional from typing import Optional

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import inspect import inspect
import logging import logging
import os import os

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import logging import logging
import os import os
import time import time

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import random import random

View File

@@ -3,3 +3,5 @@
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import os import os
import tempfile import tempfile
import unittest import unittest

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import logging import logging
import os import os
import unittest import unittest

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import os import os
import unittest import unittest

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import contextlib import contextlib
import logging import logging
import os import os

View File

@@ -5,6 +5,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
""" """
Script to visualize a previously trained model. Example call: Script to visualize a previously trained model. Example call:

View File

@@ -343,12 +343,14 @@ class RadianceFieldRenderer(torch.nn.Module):
# For a full render pass concatenate the output chunks, # For a full render pass concatenate the output chunks,
# and reshape to image size. # and reshape to image size.
out = { out = {
k: torch.cat( k: (
[ch_o[k] for ch_o in chunk_outputs], torch.cat(
dim=1, [ch_o[k] for ch_o in chunk_outputs],
).view(-1, *self._image_size, 3) dim=1,
if chunk_outputs[0][k] is not None ).view(-1, *self._image_size, 3)
else None if chunk_outputs[0][k] is not None
else None
)
for k in ("rgb_fine", "rgb_coarse", "rgb_gt") for k in ("rgb_fine", "rgb_coarse", "rgb_gt")
} }
else: else:

View File

@@ -4,4 +4,6 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
__version__ = "0.7.6" __version__ = "0.7.6"

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from .datatypes import Device, get_device, make_device from .datatypes import Device, get_device, make_device

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Sequence, Tuple, Union from typing import Sequence, Tuple, Union
import torch import torch

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Optional, Union from typing import Optional, Union
import torch import torch

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import math import math
from typing import Tuple from typing import Tuple

View File

@@ -4,5 +4,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from .symeig3x3 import symeig3x3 from .symeig3x3 import symeig3x3
from .utils import _safe_det_3x3 from .utils import _safe_det_3x3

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import math import math
from typing import Optional, Tuple from typing import Optional, Tuple

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import torch import torch

View File

@@ -338,7 +338,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2."); TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");
TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension"); TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension");
auto long_dtype = lengths1.options().dtype(at::kLong); auto long_dtype = lengths1.options().dtype(at::kLong);
auto idxs = at::zeros({N, P1, K}, long_dtype); auto idxs = at::zeros({N, P1, K}, long_dtype);
auto dists = at::zeros({N, P1, K}, p1.options()); auto dists = at::zeros({N, P1, K}, p1.options());

View File

@@ -382,6 +382,44 @@ __global__ void GenerateFacesKernel(
} // end for grid-strided kernel } // end for grid-strided kernel
} }
// ATen/Torch does not have an exclusive-scan operator. Additionally, in the
// code below we need to get the "total number of items to work on" after
// a scan, which with an inclusive-scan would simply be the value of the last
// element in the tensor.
//
// This utility function hits two birds with one stone, by running
// an inclusive-scan into a right-shifted view of a tensor that's
// allocated to be one element bigger than the input tensor.
//
// Note; return tensor is `int64_t` per element, even if the input
// tensor is only 32-bit. Also, the return tensor is one element bigger
// than the input one.
//
// Secondary optional argument is an output argument that gets the
// value of the last element of the return tensor (because you almost
// always need this CPU-side right after this function anyway).
static at::Tensor ExclusiveScanAndTotal(
const at::Tensor& inTensor,
int64_t* optTotal = nullptr) {
const auto inSize = inTensor.sizes()[0];
auto retTensor = at::zeros({inSize + 1}, at::kLong).to(inTensor.device());
using at::indexing::None;
using at::indexing::Slice;
auto rightShiftedView = retTensor.index({Slice(1, None)});
// Do an (inclusive-scan) cumulative sum in to the view that's
// shifted one element to the right...
at::cumsum_out(rightShiftedView, inTensor, 0, at::kLong);
if (optTotal) {
*optTotal = retTensor[inSize].cpu().item<int64_t>();
}
// ...so that the not-shifted tensor holds the exclusive-scan
return retTensor;
}
// Entrance for marching cubes cuda extension. Marching Cubes is an algorithm to // Entrance for marching cubes cuda extension. Marching Cubes is an algorithm to
// create triangle meshes from an implicit function (one of the form f(x, y, z) // create triangle meshes from an implicit function (one of the form f(x, y, z)
// = 0). It works by iteratively checking a grid of cubes superimposed over a // = 0). It works by iteratively checking a grid of cubes superimposed over a
@@ -444,20 +482,18 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
using at::indexing::Slice; using at::indexing::Slice;
auto d_voxelVerts = auto d_voxelVerts =
at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt)) at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
.to(vol.device()); .to(vol.device());
auto d_voxelVerts_ = d_voxelVerts.index({Slice(1, None)});
auto d_voxelOccupied = auto d_voxelOccupied =
at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt)) at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
.to(vol.device()); .to(vol.device());
auto d_voxelOccupied_ = d_voxelOccupied.index({Slice(1, None)});
// Execute "ClassifyVoxelKernel" kernel to precompute // Execute "ClassifyVoxelKernel" kernel to precompute
// two arrays - d_voxelOccupied and d_voxelVertices to global memory, // two arrays - d_voxelOccupied and d_voxelVertices to global memory,
// which stores the occupancy state and number of voxel vertices per voxel. // which stores the occupancy state and number of voxel vertices per voxel.
ClassifyVoxelKernel<<<grid, threads, 0, stream>>>( ClassifyVoxelKernel<<<grid, threads, 0, stream>>>(
d_voxelVerts_.packed_accessor32<int, 1, at::RestrictPtrTraits>(), d_voxelVerts.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupied_.packed_accessor32<int, 1, at::RestrictPtrTraits>(), d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(), vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
isolevel); isolevel);
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
@@ -467,12 +503,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
// count for voxels in the grid and compute the number of active voxels. // count for voxels in the grid and compute the number of active voxels.
// If the number of active voxels is 0, return zero tensor for verts and // If the number of active voxels is 0, return zero tensor for verts and
// faces. // faces.
int64_t activeVoxels = 0;
auto d_voxelOccupiedScan = at::cumsum(d_voxelOccupied, 0); auto d_voxelOccupiedScan =
auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)}); ExclusiveScanAndTotal(d_voxelOccupied, &activeVoxels);
// number of active voxels
int64_t activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int64_t>();
const int device_id = vol.device().index(); const int device_id = vol.device().index();
auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id); auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
@@ -487,24 +520,21 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
return std::make_tuple(verts, faces, ids); return std::make_tuple(verts, faces, ids);
} }
// Execute "CompactVoxelsKernel" kernel to compress voxels for accleration. // Execute "CompactVoxelsKernel" kernel to compress voxels for acceleration.
// This allows us to run triangle generation on only the occupied voxels. // This allows us to run triangle generation on only the occupied voxels.
auto d_compVoxelArray = at::zeros({activeVoxels}, opt); auto d_compVoxelArray = at::zeros({activeVoxels}, opt);
CompactVoxelsKernel<<<grid, threads, 0, stream>>>( CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(), d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(), d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupiedScan_ d_voxelOccupiedScan
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
numVoxels); numVoxels);
AT_CUDA_CHECK(cudaGetLastError()); AT_CUDA_CHECK(cudaGetLastError());
cudaDeviceSynchronize(); cudaDeviceSynchronize();
// Scan d_voxelVerts array to generate offsets of vertices for each voxel // Scan d_voxelVerts array to generate offsets of vertices for each voxel
auto d_voxelVertsScan = at::cumsum(d_voxelVerts, 0); int64_t totalVerts = 0;
auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)}); auto d_voxelVertsScan = ExclusiveScanAndTotal(d_voxelVerts, &totalVerts);
// total number of vertices
int64_t totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int64_t>();
// Execute "GenerateFacesKernel" kernel // Execute "GenerateFacesKernel" kernel
// This runs only on the occupied voxels. // This runs only on the occupied voxels.
@@ -524,7 +554,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(), faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(), ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(), d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelVertsScan_.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), d_voxelVertsScan.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
activeVoxels, activeVoxels,
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(), vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(), faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),

View File

@@ -357,11 +357,11 @@ void MAX_WS(
// //
// //
#define END_PARALLEL() \ #define END_PARALLEL() \
end_parallel:; \ end_parallel :; \
} }
#define END_PARALLEL_NORET() } #define END_PARALLEL_NORET() }
#define END_PARALLEL_2D() \ #define END_PARALLEL_2D() \
end_parallel:; \ end_parallel :; \
} \ } \
} }
#define END_PARALLEL_2D_NORET() \ #define END_PARALLEL_2D_NORET() \

View File

@@ -93,7 +93,7 @@ HOST void construct(
MALLOC(self->di_sorted_d, DrawInfo, max_num_balls); MALLOC(self->di_sorted_d, DrawInfo, max_num_balls);
MALLOC(self->region_flags_d, char, max_num_balls); MALLOC(self->region_flags_d, char, max_num_balls);
MALLOC(self->num_selected_d, size_t, 1); MALLOC(self->num_selected_d, size_t, 1);
MALLOC(self->forw_info_d, float, width* height*(3 + 2 * n_track)); MALLOC(self->forw_info_d, float, width* height * (3 + 2 * n_track));
MALLOC(self->min_max_pixels_d, IntersectInfo, 1); MALLOC(self->min_max_pixels_d, IntersectInfo, 1);
MALLOC(self->grad_pos_d, float3, max_num_balls); MALLOC(self->grad_pos_d, float3, max_num_balls);
MALLOC(self->grad_col_d, float, max_num_balls* n_channels); MALLOC(self->grad_col_d, float, max_num_balls* n_channels);

View File

@@ -99,7 +99,7 @@ GLOBAL void render(
/** Whether loading of balls is completed. */ /** Whether loading of balls is completed. */
SHARED bool loading_done; SHARED bool loading_done;
/** The number of balls loaded overall (just for statistics). */ /** The number of balls loaded overall (just for statistics). */
SHARED int n_balls_loaded; [[maybe_unused]] SHARED int n_balls_loaded;
/** The area this thread block covers. */ /** The area this thread block covers. */
SHARED IntersectInfo block_area; SHARED IntersectInfo block_area;
if (thread_block.thread_rank() == 0) { if (thread_block.thread_rank() == 0) {

View File

@@ -244,8 +244,7 @@ at::Tensor RasterizeCoarseCuda(
if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) { if (num_bins_y >= kMaxItemsPerBin || num_bins_x >= kMaxItemsPerBin) {
std::stringstream ss; std::stringstream ss;
ss << "In RasterizeCoarseCuda got num_bins_y: " << num_bins_y ss << "In RasterizeCoarseCuda got num_bins_y: " << num_bins_y
<< ", num_bins_x: " << num_bins_x << ", " << ", num_bins_x: " << num_bins_x << ", " << "; that's too many!";
<< "; that's too many!";
AT_ERROR(ss.str()); AT_ERROR(ss.str());
} }
auto opts = elems_per_batch.options().dtype(at::kInt); auto opts = elems_per_batch.options().dtype(at::kInt);

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from .r2n2 import BlenderCamera, collate_batched_R2N2, R2N2, render_cubified_voxels from .r2n2 import BlenderCamera, collate_batched_R2N2, R2N2, render_cubified_voxels
from .shapenet import ShapeNetCore from .shapenet import ShapeNetCore
from .utils import collate_batched_meshes from .utils import collate_batched_meshes

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from .r2n2 import R2N2 from .r2n2 import R2N2
from .utils import BlenderCamera, collate_batched_R2N2, render_cubified_voxels from .utils import BlenderCamera, collate_batched_R2N2, render_cubified_voxels

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import json import json
import warnings import warnings
from os import path from os import path

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import math import math
from typing import Dict, List from typing import Dict, List

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from .shapenet_core import ShapeNetCore from .shapenet_core import ShapeNetCore

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import json import json
import os import os
import warnings import warnings

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import warnings import warnings
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Dict, List from typing import Dict, List
from pytorch3d.renderer.mesh import TexturesAtlas from pytorch3d.renderer.mesh import TexturesAtlas

View File

@@ -3,3 +3,5 @@
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe

View File

@@ -3,3 +3,5 @@
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import torch import torch
from pytorch3d.implicitron.tools.config import registry from pytorch3d.implicitron.tools.config import registry

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Iterator, List, Optional, Tuple from typing import Iterator, List, Optional, Tuple

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Optional, Tuple from typing import Optional, Tuple
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from typing import ( from typing import (

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import logging import logging
import os import os
from dataclasses import dataclass from dataclasses import dataclass

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
@@ -576,11 +578,11 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
camera_quality_score=safe_as_tensor( camera_quality_score=safe_as_tensor(
sequence_annotation.viewpoint_quality_score, torch.float sequence_annotation.viewpoint_quality_score, torch.float
), ),
point_cloud_quality_score=safe_as_tensor( point_cloud_quality_score=(
point_cloud.quality_score, torch.float safe_as_tensor(point_cloud.quality_score, torch.float)
) if point_cloud is not None
if point_cloud is not None else None
else None, ),
) )
fg_mask_np: Optional[np.ndarray] = None fg_mask_np: Optional[np.ndarray] = None

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import copy import copy
import functools import functools
import gzip import gzip
@@ -124,9 +126,9 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
dimension of the cropping bounding box, relative to box size. dimension of the cropping bounding box, relative to box size.
""" """
frame_annotations_type: ClassVar[ frame_annotations_type: ClassVar[Type[types.FrameAnnotation]] = (
Type[types.FrameAnnotation] types.FrameAnnotation
] = types.FrameAnnotation )
path_manager: Any = None path_manager: Any = None
frame_annotations_file: str = "" frame_annotations_file: str = ""

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import json import json
import os import os

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import copy import copy
import json import json

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import numpy as np import numpy as np
import torch import torch

View File

@@ -1,6 +1,8 @@
# @lint-ignore-every LICENSELINT # @lint-ignore-every LICENSELINT
# Adapted from https://github.com/bmild/nerf/blob/master/load_blender.py # Adapted from https://github.com/bmild/nerf/blob/master/load_blender.py
# Copyright (c) 2020 bmild # Copyright (c) 2020 bmild
# pyre-unsafe
import json import json
import os import os

View File

@@ -1,6 +1,8 @@
# @lint-ignore-every LICENSELINT # @lint-ignore-every LICENSELINT
# Adapted from https://github.com/bmild/nerf/blob/master/load_llff.py # Adapted from https://github.com/bmild/nerf/blob/master/load_llff.py
# Copyright (c) 2020 bmild # Copyright (c) 2020 bmild
# pyre-unsafe
import logging import logging
import os import os
import warnings import warnings

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from os.path import dirname, join, realpath from os.path import dirname, join, realpath
from typing import Optional, Tuple from typing import Optional, Tuple

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import warnings import warnings
from collections import Counter from collections import Counter

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
# This file defines a base class for dataset map providers which # This file defines a base class for dataset map providers which
# provide data for a single scene. # provide data for a single scene.

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import logging import logging
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import dataclasses import dataclasses
import gzip import gzip

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import functools import functools
import warnings import warnings

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import cast, Optional, Tuple from typing import cast, Optional, Tuple
import torch import torch
@@ -88,9 +90,11 @@ def get_implicitron_sequence_pointcloud(
frame_data.camera, frame_data.camera,
frame_data.image_rgb, frame_data.image_rgb,
frame_data.depth_map, frame_data.depth_map,
(cast(torch.Tensor, frame_data.fg_probability) > 0.5).float() (
if mask_points and frame_data.fg_probability is not None (cast(torch.Tensor, frame_data.fg_probability) > 0.5).float()
else None, if mask_points and frame_data.fg_probability is not None
else None
),
) )
return point_cloud, frame_data return point_cloud, frame_data

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import dataclasses import dataclasses
import os import os

View File

@@ -3,3 +3,5 @@
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import copy import copy
import warnings import warnings
@@ -282,9 +284,9 @@ def eval_batch(
image_rgb_masked=image_rgb_masked, image_rgb_masked=image_rgb_masked,
depth_render=cloned_render["depth_render"], depth_render=cloned_render["depth_render"],
depth_map=frame_data.depth_map, depth_map=frame_data.depth_map,
depth_mask=frame_data.depth_mask[:1] depth_mask=(
if frame_data.depth_mask is not None frame_data.depth_mask[:1] if frame_data.depth_mask is not None else None
else None, ),
visdom_env=visualize_visdom_env, visdom_env=visualize_visdom_env,
) )

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import copy import copy
import json import json
import logging import logging

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
# Allows to register the models # Allows to register the models
# see: pytorch3d.implicitron.tools.config.registry:register # see: pytorch3d.implicitron.tools.config.registry:register
from pytorch3d.implicitron.models.generic_model import GenericModel from pytorch3d.implicitron.models.generic_model import GenericModel

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional

View File

@@ -4,4 +4,6 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from .feature_extractor import FeatureExtractorBase from .feature_extractor import FeatureExtractorBase

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
import torch import torch

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import logging import logging
import math import math
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
# Note: The #noqa comments below are for unused imports of pluggable implementations # Note: The #noqa comments below are for unused imports of pluggable implementations
# which are part of implicitron. They ensure that the registry is prepopulated. # which are part of implicitron. They ensure that the registry is prepopulated.
@@ -395,9 +397,11 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
n_targets = ( n_targets = (
1 1
if evaluation_mode == EvaluationMode.EVALUATION if evaluation_mode == EvaluationMode.EVALUATION
else batch_size else (
if self.n_train_target_views <= 0 batch_size
else min(self.n_train_target_views, batch_size) if self.n_train_target_views <= 0
else min(self.n_train_target_views, batch_size)
)
) )
# A helper function for selecting n_target first elements from the input # A helper function for selecting n_target first elements from the input
@@ -422,9 +426,12 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
ray_bundle: ImplicitronRayBundle = self.raysampler( ray_bundle: ImplicitronRayBundle = self.raysampler(
target_cameras, target_cameras,
evaluation_mode, evaluation_mode,
mask=mask_crop[:n_targets] mask=(
if mask_crop is not None and sampling_mode == RenderSamplingMode.MASK_SAMPLE mask_crop[:n_targets]
else None, if mask_crop is not None
and sampling_mode == RenderSamplingMode.MASK_SAMPLE
else None
),
) )
# custom_args hold additional arguments to the implicit function. # custom_args hold additional arguments to the implicit function.

View File

@@ -3,3 +3,5 @@
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import warnings import warnings
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional, Union

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import List, Optional, Union from typing import List, Optional, Union
import torch import torch

View File

@@ -3,3 +3,5 @@
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Optional

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
""" """
This file contains This file contains
- modules which get used by ImplicitFunction objects for decoding an embedding defined in - modules which get used by ImplicitFunction objects for decoding an embedding defined in

View File

@@ -2,6 +2,8 @@
# Adapted from https://github.com/lioryariv/idr/blob/main/code/model/ # Adapted from https://github.com/lioryariv/idr/blob/main/code/model/
# implicit_differentiable_renderer.py # implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv # Copyright (c) 2020 Lior Yariv
# pyre-unsafe
import math import math
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -102,9 +104,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
elif self.n_harmonic_functions_xyz >= 0 and layer_idx == 0: elif self.n_harmonic_functions_xyz >= 0 and layer_idx == 0:
torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.constant_(lin.weight[:, 3:], 0.0) torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
torch.nn.init.normal_( torch.nn.init.normal_(lin.weight[:, :3], 0.0, 2**0.5 / out_dim**0.5)
lin.weight[:, :3], 0.0, 2**0.5 / out_dim**0.5
)
elif self.n_harmonic_functions_xyz >= 0 and layer_idx in self.skip_in: elif self.n_harmonic_functions_xyz >= 0 and layer_idx in self.skip_in:
torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5) torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5)

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import logging import logging
from typing import Optional, Tuple from typing import Optional, Tuple
@@ -193,9 +195,9 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
embeds = create_embeddings_for_implicit_function( embeds = create_embeddings_for_implicit_function(
xyz_world=rays_points_world, xyz_world=rays_points_world,
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`. # for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
xyz_embedding_function=self.harmonic_embedding_xyz xyz_embedding_function=(
if self.input_xyz self.harmonic_embedding_xyz if self.input_xyz else None
else None, ),
global_code=global_code, global_code=global_code,
fun_viewpool=fun_viewpool, fun_viewpool=fun_viewpool,
xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords, xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,

View File

@@ -1,6 +1,8 @@
# @lint-ignore-every LICENSELINT # @lint-ignore-every LICENSELINT
# Adapted from https://github.com/vsitzmann/scene-representation-networks # Adapted from https://github.com/vsitzmann/scene-representation-networks
# Copyright (c) 2019 Vincent Sitzmann # Copyright (c) 2019 Vincent Sitzmann
# pyre-unsafe
from typing import Any, cast, Optional, Tuple from typing import Any, cast, Optional, Tuple
import torch import torch

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Callable, Optional from typing import Callable, Optional
import torch import torch

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
""" """
This file contains classes that implement Voxel grids, both in their full resolution This file contains classes that implement Voxel grids, both in their full resolution
as in the factorized form. There are two factorized forms implemented, Tensor rank decomposition as in the factorized form. There are two factorized forms implemented, Tensor rank decomposition

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import logging import logging
import math import math
import warnings import warnings

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
import warnings import warnings
from typing import Any, Dict, Optional from typing import Any, Dict, Optional

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple

View File

@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe
# Note: The #noqa comments below are for unused imports of pluggable implementations # Note: The #noqa comments below are for unused imports of pluggable implementations
# which are part of implicitron. They ensure that the registry is prepopulated. # which are part of implicitron. They ensure that the registry is prepopulated.
@@ -356,9 +358,12 @@ class OverfitModel(ImplicitronModelBase): # pyre-ignore: 13
ray_bundle: ImplicitronRayBundle = self.raysampler( ray_bundle: ImplicitronRayBundle = self.raysampler(
camera, camera,
evaluation_mode, evaluation_mode,
mask=mask_crop mask=(
if mask_crop is not None and sampling_mode == RenderSamplingMode.MASK_SAMPLE mask_crop
else None, if mask_crop is not None
and sampling_mode == RenderSamplingMode.MASK_SAMPLE
else None
),
) )
inputs_to_be_chunked = {} inputs_to_be_chunked = {}
@@ -381,10 +386,12 @@ class OverfitModel(ImplicitronModelBase): # pyre-ignore: 13
frame_timestamp=frame_timestamp, frame_timestamp=frame_timestamp,
) )
implicit_functions = [ implicit_functions = [
functools.partial(implicit_function, global_code=global_code) (
if isinstance(implicit_function, Callable) functools.partial(implicit_function, global_code=global_code)
else functools.partial( if isinstance(implicit_function, Callable)
implicit_function.forward, global_code=global_code else functools.partial(
implicit_function.forward, global_code=global_code
)
) )
for implicit_function in implicit_functions for implicit_function in implicit_functions
] ]

View File

@@ -3,3 +3,5 @@
# #
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
# pyre-unsafe

Some files were not shown because too many files have changed in this diff Show More