mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-03 02:35:58 +08:00
Compare commits
46 Commits
v0.7.7
...
bottler/un
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
62a2031dd4 | ||
|
|
3987612062 | ||
|
|
06a76ef8dd | ||
|
|
21205730d9 | ||
|
|
7e09505538 | ||
|
|
20bd8b33f6 | ||
|
|
7a3c0cbc9d | ||
|
|
215590b497 | ||
|
|
43cd681d4f | ||
|
|
42a4a7d432 | ||
|
|
699bc671ca | ||
|
|
49cf5a0f37 | ||
|
|
89b851e64c | ||
|
|
5247f6ad74 | ||
|
|
e41aff47db | ||
|
|
64a5bfadc8 | ||
|
|
055ab3a2e3 | ||
|
|
f6c2ca6bfc | ||
|
|
e20cbe9b0e | ||
|
|
c17e6f947a | ||
|
|
91c9f34137 | ||
|
|
81d82980bc | ||
|
|
8fe6934885 | ||
|
|
c434957b2a | ||
|
|
dd2a11b5fc | ||
|
|
9563ef79ca | ||
|
|
008c7ab58c | ||
|
|
9eaed4c495 | ||
|
|
e13848265d | ||
|
|
58566963d6 | ||
|
|
e17ed5cd50 | ||
|
|
8ed0c7a002 | ||
|
|
2da913c7e6 | ||
|
|
fca83e6369 | ||
|
|
75ebeeaea0 | ||
|
|
ab793177c6 | ||
|
|
9acdd67b83 | ||
|
|
3f428d9981 | ||
|
|
05cbea115a | ||
|
|
38afdcfc68 | ||
|
|
1e0b1d9c72 | ||
|
|
44702fdb4b | ||
|
|
7edaee71a9 | ||
|
|
d0d0e02007 | ||
|
|
4df110b0a9 | ||
|
|
51fd114d8b |
@@ -162,34 +162,6 @@ workflows:
|
|||||||
jobs:
|
jobs:
|
||||||
# - main:
|
# - main:
|
||||||
# context: DOCKERHUB_TOKEN
|
# context: DOCKERHUB_TOKEN
|
||||||
- binary_linux_conda:
|
|
||||||
conda_docker_image: pytorch/conda-builder:cuda117
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu117
|
|
||||||
name: linux_conda_py38_cu117_pyt200
|
|
||||||
python_version: '3.8'
|
|
||||||
pytorch_version: 2.0.0
|
|
||||||
- binary_linux_conda:
|
|
||||||
conda_docker_image: pytorch/conda-builder:cuda118
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu118
|
|
||||||
name: linux_conda_py38_cu118_pyt200
|
|
||||||
python_version: '3.8'
|
|
||||||
pytorch_version: 2.0.0
|
|
||||||
- binary_linux_conda:
|
|
||||||
conda_docker_image: pytorch/conda-builder:cuda117
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu117
|
|
||||||
name: linux_conda_py38_cu117_pyt201
|
|
||||||
python_version: '3.8'
|
|
||||||
pytorch_version: 2.0.1
|
|
||||||
- binary_linux_conda:
|
|
||||||
conda_docker_image: pytorch/conda-builder:cuda118
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu118
|
|
||||||
name: linux_conda_py38_cu118_pyt201
|
|
||||||
python_version: '3.8'
|
|
||||||
pytorch_version: 2.0.1
|
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
conda_docker_image: pytorch/conda-builder:cuda118
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
@@ -275,33 +247,33 @@ workflows:
|
|||||||
python_version: '3.8'
|
python_version: '3.8'
|
||||||
pytorch_version: 2.3.1
|
pytorch_version: 2.3.1
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
conda_docker_image: pytorch/conda-builder:cuda117
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu117
|
cu_version: cu118
|
||||||
name: linux_conda_py39_cu117_pyt200
|
name: linux_conda_py38_cu118_pyt240
|
||||||
python_version: '3.9'
|
python_version: '3.8'
|
||||||
pytorch_version: 2.0.0
|
pytorch_version: 2.4.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda121
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu121
|
||||||
|
name: linux_conda_py38_cu121_pyt240
|
||||||
|
python_version: '3.8'
|
||||||
|
pytorch_version: 2.4.0
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
conda_docker_image: pytorch/conda-builder:cuda118
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu118
|
cu_version: cu118
|
||||||
name: linux_conda_py39_cu118_pyt200
|
name: linux_conda_py38_cu118_pyt241
|
||||||
python_version: '3.9'
|
python_version: '3.8'
|
||||||
pytorch_version: 2.0.0
|
pytorch_version: 2.4.1
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
conda_docker_image: pytorch/conda-builder:cuda117
|
conda_docker_image: pytorch/conda-builder:cuda121
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu117
|
cu_version: cu121
|
||||||
name: linux_conda_py39_cu117_pyt201
|
name: linux_conda_py38_cu121_pyt241
|
||||||
python_version: '3.9'
|
python_version: '3.8'
|
||||||
pytorch_version: 2.0.1
|
pytorch_version: 2.4.1
|
||||||
- binary_linux_conda:
|
|
||||||
conda_docker_image: pytorch/conda-builder:cuda118
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu118
|
|
||||||
name: linux_conda_py39_cu118_pyt201
|
|
||||||
python_version: '3.9'
|
|
||||||
pytorch_version: 2.0.1
|
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
conda_docker_image: pytorch/conda-builder:cuda118
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
@@ -387,33 +359,33 @@ workflows:
|
|||||||
python_version: '3.9'
|
python_version: '3.9'
|
||||||
pytorch_version: 2.3.1
|
pytorch_version: 2.3.1
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
conda_docker_image: pytorch/conda-builder:cuda117
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu117
|
cu_version: cu118
|
||||||
name: linux_conda_py310_cu117_pyt200
|
name: linux_conda_py39_cu118_pyt240
|
||||||
python_version: '3.10'
|
python_version: '3.9'
|
||||||
pytorch_version: 2.0.0
|
pytorch_version: 2.4.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda121
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu121
|
||||||
|
name: linux_conda_py39_cu121_pyt240
|
||||||
|
python_version: '3.9'
|
||||||
|
pytorch_version: 2.4.0
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
conda_docker_image: pytorch/conda-builder:cuda118
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu118
|
cu_version: cu118
|
||||||
name: linux_conda_py310_cu118_pyt200
|
name: linux_conda_py39_cu118_pyt241
|
||||||
python_version: '3.10'
|
python_version: '3.9'
|
||||||
pytorch_version: 2.0.0
|
pytorch_version: 2.4.1
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
conda_docker_image: pytorch/conda-builder:cuda117
|
conda_docker_image: pytorch/conda-builder:cuda121
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu117
|
cu_version: cu121
|
||||||
name: linux_conda_py310_cu117_pyt201
|
name: linux_conda_py39_cu121_pyt241
|
||||||
python_version: '3.10'
|
python_version: '3.9'
|
||||||
pytorch_version: 2.0.1
|
pytorch_version: 2.4.1
|
||||||
- binary_linux_conda:
|
|
||||||
conda_docker_image: pytorch/conda-builder:cuda118
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu118
|
|
||||||
name: linux_conda_py310_cu118_pyt201
|
|
||||||
python_version: '3.10'
|
|
||||||
pytorch_version: 2.0.1
|
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
conda_docker_image: pytorch/conda-builder:cuda118
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
@@ -498,6 +470,34 @@ workflows:
|
|||||||
name: linux_conda_py310_cu121_pyt231
|
name: linux_conda_py310_cu121_pyt231
|
||||||
python_version: '3.10'
|
python_version: '3.10'
|
||||||
pytorch_version: 2.3.1
|
pytorch_version: 2.3.1
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu118
|
||||||
|
name: linux_conda_py310_cu118_pyt240
|
||||||
|
python_version: '3.10'
|
||||||
|
pytorch_version: 2.4.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda121
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu121
|
||||||
|
name: linux_conda_py310_cu121_pyt240
|
||||||
|
python_version: '3.10'
|
||||||
|
pytorch_version: 2.4.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu118
|
||||||
|
name: linux_conda_py310_cu118_pyt241
|
||||||
|
python_version: '3.10'
|
||||||
|
pytorch_version: 2.4.1
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda121
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu121
|
||||||
|
name: linux_conda_py310_cu121_pyt241
|
||||||
|
python_version: '3.10'
|
||||||
|
pytorch_version: 2.4.1
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
conda_docker_image: pytorch/conda-builder:cuda118
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
@@ -582,6 +582,34 @@ workflows:
|
|||||||
name: linux_conda_py311_cu121_pyt231
|
name: linux_conda_py311_cu121_pyt231
|
||||||
python_version: '3.11'
|
python_version: '3.11'
|
||||||
pytorch_version: 2.3.1
|
pytorch_version: 2.3.1
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu118
|
||||||
|
name: linux_conda_py311_cu118_pyt240
|
||||||
|
python_version: '3.11'
|
||||||
|
pytorch_version: 2.4.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda121
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu121
|
||||||
|
name: linux_conda_py311_cu121_pyt240
|
||||||
|
python_version: '3.11'
|
||||||
|
pytorch_version: 2.4.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu118
|
||||||
|
name: linux_conda_py311_cu118_pyt241
|
||||||
|
python_version: '3.11'
|
||||||
|
pytorch_version: 2.4.1
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda121
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu121
|
||||||
|
name: linux_conda_py311_cu121_pyt241
|
||||||
|
python_version: '3.11'
|
||||||
|
pytorch_version: 2.4.1
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
conda_docker_image: pytorch/conda-builder:cuda118
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
@@ -624,6 +652,34 @@ workflows:
|
|||||||
name: linux_conda_py312_cu121_pyt231
|
name: linux_conda_py312_cu121_pyt231
|
||||||
python_version: '3.12'
|
python_version: '3.12'
|
||||||
pytorch_version: 2.3.1
|
pytorch_version: 2.3.1
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu118
|
||||||
|
name: linux_conda_py312_cu118_pyt240
|
||||||
|
python_version: '3.12'
|
||||||
|
pytorch_version: 2.4.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda121
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu121
|
||||||
|
name: linux_conda_py312_cu121_pyt240
|
||||||
|
python_version: '3.12'
|
||||||
|
pytorch_version: 2.4.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda118
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu118
|
||||||
|
name: linux_conda_py312_cu118_pyt241
|
||||||
|
python_version: '3.12'
|
||||||
|
pytorch_version: 2.4.1
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda121
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu121
|
||||||
|
name: linux_conda_py312_cu121_pyt241
|
||||||
|
python_version: '3.12'
|
||||||
|
pytorch_version: 2.4.1
|
||||||
- binary_linux_conda_cuda:
|
- binary_linux_conda_cuda:
|
||||||
name: testrun_conda_cuda_py310_cu117_pyt201
|
name: testrun_conda_cuda_py310_cu117_pyt201
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
|
|||||||
@@ -19,14 +19,14 @@ from packaging import version
|
|||||||
# The CUDA versions which have pytorch conda packages available for linux for each
|
# The CUDA versions which have pytorch conda packages available for linux for each
|
||||||
# version of pytorch.
|
# version of pytorch.
|
||||||
CONDA_CUDA_VERSIONS = {
|
CONDA_CUDA_VERSIONS = {
|
||||||
"2.0.0": ["cu117", "cu118"],
|
|
||||||
"2.0.1": ["cu117", "cu118"],
|
|
||||||
"2.1.0": ["cu118", "cu121"],
|
"2.1.0": ["cu118", "cu121"],
|
||||||
"2.1.1": ["cu118", "cu121"],
|
"2.1.1": ["cu118", "cu121"],
|
||||||
"2.1.2": ["cu118", "cu121"],
|
"2.1.2": ["cu118", "cu121"],
|
||||||
"2.2.0": ["cu118", "cu121"],
|
"2.2.0": ["cu118", "cu121"],
|
||||||
"2.2.2": ["cu118", "cu121"],
|
"2.2.2": ["cu118", "cu121"],
|
||||||
"2.3.1": ["cu118", "cu121"],
|
"2.3.1": ["cu118", "cu121"],
|
||||||
|
"2.4.0": ["cu118", "cu121"],
|
||||||
|
"2.4.1": ["cu118", "cu121"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -88,7 +88,6 @@ def workflow_pair(
|
|||||||
upload=False,
|
upload=False,
|
||||||
filter_branch,
|
filter_branch,
|
||||||
):
|
):
|
||||||
|
|
||||||
w = []
|
w = []
|
||||||
py = python_version.replace(".", "")
|
py = python_version.replace(".", "")
|
||||||
pyt = pytorch_version.replace(".", "")
|
pyt = pytorch_version.replace(".", "")
|
||||||
@@ -127,7 +126,6 @@ def generate_base_workflow(
|
|||||||
btype,
|
btype,
|
||||||
filter_branch=None,
|
filter_branch=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
d = {
|
d = {
|
||||||
"name": base_workflow_name,
|
"name": base_workflow_name,
|
||||||
"python_version": python_version,
|
"python_version": python_version,
|
||||||
|
|||||||
23
.github/workflows/build.yml
vendored
Normal file
23
.github/workflows/build.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
name: facebookresearch/pytorch3d/build_and_test
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
jobs:
|
||||||
|
binary_linux_conda_cuda:
|
||||||
|
runs-on: 4-core-ubuntu-gpu-t4
|
||||||
|
env:
|
||||||
|
PYTHON_VERSION: "3.12"
|
||||||
|
BUILD_VERSION: "${{ github.run_number }}"
|
||||||
|
PYTORCH_VERSION: "2.4.1"
|
||||||
|
CU_VERSION: "cu121"
|
||||||
|
JUST_TESTRUN: 1
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- name: Build and run tests
|
||||||
|
run: |-
|
||||||
|
conda create --name env --yes --quiet conda-build
|
||||||
|
conda run --no-capture-output --name env python3 ./packaging/build_conda.py --use-conda-cuda
|
||||||
11
INSTALL.md
11
INSTALL.md
@@ -8,11 +8,10 @@
|
|||||||
The core library is written in PyTorch. Several components have underlying implementation in CUDA for improved performance. A subset of these components have CPU implementations in C++/PyTorch. It is advised to use PyTorch3D with GPU support in order to use all the features.
|
The core library is written in PyTorch. Several components have underlying implementation in CUDA for improved performance. A subset of these components have CPU implementations in C++/PyTorch. It is advised to use PyTorch3D with GPU support in order to use all the features.
|
||||||
|
|
||||||
- Linux or macOS or Windows
|
- Linux or macOS or Windows
|
||||||
- Python 3.8, 3.9 or 3.10
|
- Python
|
||||||
- PyTorch 2.0.0, 2.0.1, 2.1.0, 2.1.1, 2.1.2, 2.2.0, 2.2.1, 2.2.2, 2.3.0 or 2.3.1.
|
- PyTorch 2.1.0, 2.1.1, 2.1.2, 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0 or 2.4.1.
|
||||||
- torchvision that matches the PyTorch installation. You can install them together as explained at pytorch.org to make sure of this.
|
- torchvision that matches the PyTorch installation. You can install them together as explained at pytorch.org to make sure of this.
|
||||||
- gcc & g++ ≥ 4.9
|
- gcc & g++ ≥ 4.9
|
||||||
- [fvcore](https://github.com/facebookresearch/fvcore)
|
|
||||||
- [ioPath](https://github.com/facebookresearch/iopath)
|
- [ioPath](https://github.com/facebookresearch/iopath)
|
||||||
- If CUDA is to be used, use a version which is supported by the corresponding pytorch version and at least version 9.2.
|
- If CUDA is to be used, use a version which is supported by the corresponding pytorch version and at least version 9.2.
|
||||||
- If CUDA older than 11.7 is to be used and you are building from source, the CUB library must be available. We recommend version 1.10.0.
|
- If CUDA older than 11.7 is to be used and you are building from source, the CUB library must be available. We recommend version 1.10.0.
|
||||||
@@ -22,7 +21,7 @@ The runtime dependencies can be installed by running:
|
|||||||
conda create -n pytorch3d python=3.9
|
conda create -n pytorch3d python=3.9
|
||||||
conda activate pytorch3d
|
conda activate pytorch3d
|
||||||
conda install pytorch=1.13.0 torchvision pytorch-cuda=11.6 -c pytorch -c nvidia
|
conda install pytorch=1.13.0 torchvision pytorch-cuda=11.6 -c pytorch -c nvidia
|
||||||
conda install -c fvcore -c iopath -c conda-forge fvcore iopath
|
conda install -c iopath iopath
|
||||||
```
|
```
|
||||||
|
|
||||||
For the CUB build time dependency, which you only need if you have CUDA older than 11.7, if you are using conda, you can continue with
|
For the CUB build time dependency, which you only need if you have CUDA older than 11.7, if you are using conda, you can continue with
|
||||||
@@ -49,6 +48,7 @@ For developing on top of PyTorch3D or contributing, you will need to run the lin
|
|||||||
- tdqm
|
- tdqm
|
||||||
- jupyter
|
- jupyter
|
||||||
- imageio
|
- imageio
|
||||||
|
- fvcore
|
||||||
- plotly
|
- plotly
|
||||||
- opencv-python
|
- opencv-python
|
||||||
|
|
||||||
@@ -59,6 +59,7 @@ conda install jupyter
|
|||||||
pip install scikit-image matplotlib imageio plotly opencv-python
|
pip install scikit-image matplotlib imageio plotly opencv-python
|
||||||
|
|
||||||
# Tests/Linting
|
# Tests/Linting
|
||||||
|
conda install -c fvcore -c conda-forge fvcore
|
||||||
pip install black usort flake8 flake8-bugbear flake8-comprehensions
|
pip install black usort flake8 flake8-bugbear flake8-comprehensions
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -97,7 +98,7 @@ version_str="".join([
|
|||||||
torch.version.cuda.replace(".",""),
|
torch.version.cuda.replace(".",""),
|
||||||
f"_pyt{pyt_version_str}"
|
f"_pyt{pyt_version_str}"
|
||||||
])
|
])
|
||||||
!pip install fvcore iopath
|
!pip install iopath
|
||||||
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
|
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -36,5 +36,5 @@ then
|
|||||||
|
|
||||||
echo "Running pyre..."
|
echo "Running pyre..."
|
||||||
echo "To restart/kill pyre server, run 'pyre restart' or 'pyre kill' in fbcode/"
|
echo "To restart/kill pyre server, run 'pyre restart' or 'pyre kill' in fbcode/"
|
||||||
( cd ~/fbsource/fbcode; pyre -l vision/fair/pytorch3d/ )
|
( cd ~/fbsource/fbcode; arc pyre check //vision/fair/pytorch3d/... )
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ conda init bash
|
|||||||
source ~/.bashrc
|
source ~/.bashrc
|
||||||
conda create -y -n myenv python=3.8 matplotlib ipython ipywidgets nbconvert
|
conda create -y -n myenv python=3.8 matplotlib ipython ipywidgets nbconvert
|
||||||
conda activate myenv
|
conda activate myenv
|
||||||
conda install -y -c fvcore -c iopath -c conda-forge fvcore iopath
|
conda install -y -c iopath iopath
|
||||||
conda install -y -c pytorch pytorch=1.6.0 cudatoolkit=10.1 torchvision
|
conda install -y -c pytorch pytorch=1.6.0 cudatoolkit=10.1 torchvision
|
||||||
conda install -y -c pytorch3d-nightly pytorch3d
|
conda install -y -c pytorch3d-nightly pytorch3d
|
||||||
pip install plotly scikit-image
|
pip install plotly scikit-image
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ This example demonstrates the most trivial, direct interface of the pulsar
|
|||||||
sphere renderer. It renders and saves an image with 10 random spheres.
|
sphere renderer. It renders and saves an image with 10 random spheres.
|
||||||
Output: basic.png.
|
Output: basic.png.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from os import path
|
from os import path
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ interface for sphere renderering. It renders and saves an image with
|
|||||||
10 random spheres.
|
10 random spheres.
|
||||||
Output: basic-pt3d.png.
|
Output: basic-pt3d.png.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from os import path
|
from os import path
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ distorted. Gradient-based optimization is used to converge towards the
|
|||||||
original camera parameters.
|
original camera parameters.
|
||||||
Output: cam.gif.
|
Output: cam.gif.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from os import path
|
from os import path
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ distorted. Gradient-based optimization is used to converge towards the
|
|||||||
original camera parameters.
|
original camera parameters.
|
||||||
Output: cam-pt3d.gif
|
Output: cam-pt3d.gif
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from os import path
|
from os import path
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ This example is not available yet through the 'unified' interface,
|
|||||||
because opacity support has not landed in PyTorch3D for general data
|
because opacity support has not landed in PyTorch3D for general data
|
||||||
structures yet.
|
structures yet.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from os import path
|
from os import path
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ The scene is initialized with random spheres. Gradient-based
|
|||||||
optimization is used to converge towards a faithful
|
optimization is used to converge towards a faithful
|
||||||
scene representation.
|
scene representation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ The scene is initialized with random spheres. Gradient-based
|
|||||||
optimization is used to converge towards a faithful
|
optimization is used to converge towards a faithful
|
||||||
scene representation.
|
scene representation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ sphinx_rtd_theme
|
|||||||
sphinx_markdown_tables
|
sphinx_markdown_tables
|
||||||
numpy
|
numpy
|
||||||
iopath
|
iopath
|
||||||
fvcore
|
|
||||||
https://download.pytorch.org/whl/cpu/torchvision-0.15.2%2Bcpu-cp311-cp311-linux_x86_64.whl
|
https://download.pytorch.org/whl/cpu/torchvision-0.15.2%2Bcpu-cp311-cp311-linux_x86_64.whl
|
||||||
https://download.pytorch.org/whl/cpu/torch-2.0.1%2Bcpu-cp311-cp311-linux_x86_64.whl
|
https://download.pytorch.org/whl/cpu/torch-2.0.1%2Bcpu-cp311-cp311-linux_x86_64.whl
|
||||||
omegaconf
|
omegaconf
|
||||||
|
|||||||
@@ -96,7 +96,7 @@
|
|||||||
" torch.version.cuda.replace(\".\",\"\"),\n",
|
" torch.version.cuda.replace(\".\",\"\"),\n",
|
||||||
" f\"_pyt{pyt_version_str}\"\n",
|
" f\"_pyt{pyt_version_str}\"\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
" !pip install fvcore iopath\n",
|
" !pip install iopath\n",
|
||||||
" if sys.platform.startswith(\"linux\"):\n",
|
" if sys.platform.startswith(\"linux\"):\n",
|
||||||
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
||||||
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
||||||
|
|||||||
@@ -83,7 +83,7 @@
|
|||||||
" torch.version.cuda.replace(\".\",\"\"),\n",
|
" torch.version.cuda.replace(\".\",\"\"),\n",
|
||||||
" f\"_pyt{pyt_version_str}\"\n",
|
" f\"_pyt{pyt_version_str}\"\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
" !pip install fvcore iopath\n",
|
" !pip install iopath\n",
|
||||||
" if sys.platform.startswith(\"linux\"):\n",
|
" if sys.platform.startswith(\"linux\"):\n",
|
||||||
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
||||||
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
||||||
|
|||||||
@@ -58,7 +58,7 @@
|
|||||||
" torch.version.cuda.replace(\".\",\"\"),\n",
|
" torch.version.cuda.replace(\".\",\"\"),\n",
|
||||||
" f\"_pyt{pyt_version_str}\"\n",
|
" f\"_pyt{pyt_version_str}\"\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
" !pip install fvcore iopath\n",
|
" !pip install iopath\n",
|
||||||
" if sys.platform.startswith(\"linux\"):\n",
|
" if sys.platform.startswith(\"linux\"):\n",
|
||||||
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
||||||
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
||||||
|
|||||||
@@ -97,7 +97,7 @@
|
|||||||
" torch.version.cuda.replace(\".\",\"\"),\n",
|
" torch.version.cuda.replace(\".\",\"\"),\n",
|
||||||
" f\"_pyt{pyt_version_str}\"\n",
|
" f\"_pyt{pyt_version_str}\"\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
" !pip install fvcore iopath\n",
|
" !pip install iopath\n",
|
||||||
" if sys.platform.startswith(\"linux\"):\n",
|
" if sys.platform.startswith(\"linux\"):\n",
|
||||||
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
||||||
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
||||||
|
|||||||
@@ -63,7 +63,7 @@
|
|||||||
" torch.version.cuda.replace(\".\",\"\"),\n",
|
" torch.version.cuda.replace(\".\",\"\"),\n",
|
||||||
" f\"_pyt{pyt_version_str}\"\n",
|
" f\"_pyt{pyt_version_str}\"\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
" !pip install fvcore iopath\n",
|
" !pip install iopath\n",
|
||||||
" if sys.platform.startswith(\"linux\"):\n",
|
" if sys.platform.startswith(\"linux\"):\n",
|
||||||
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
||||||
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
||||||
|
|||||||
@@ -75,7 +75,7 @@
|
|||||||
" torch.version.cuda.replace(\".\",\"\"),\n",
|
" torch.version.cuda.replace(\".\",\"\"),\n",
|
||||||
" f\"_pyt{pyt_version_str}\"\n",
|
" f\"_pyt{pyt_version_str}\"\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
" !pip install fvcore iopath\n",
|
" !pip install iopath\n",
|
||||||
" if sys.platform.startswith(\"linux\"):\n",
|
" if sys.platform.startswith(\"linux\"):\n",
|
||||||
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
||||||
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
||||||
|
|||||||
@@ -54,7 +54,7 @@
|
|||||||
" torch.version.cuda.replace(\".\",\"\"),\n",
|
" torch.version.cuda.replace(\".\",\"\"),\n",
|
||||||
" f\"_pyt{pyt_version_str}\"\n",
|
" f\"_pyt{pyt_version_str}\"\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
" !pip install fvcore iopath\n",
|
" !pip install iopath\n",
|
||||||
" if sys.platform.startswith(\"linux\"):\n",
|
" if sys.platform.startswith(\"linux\"):\n",
|
||||||
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
||||||
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
||||||
|
|||||||
@@ -85,7 +85,7 @@
|
|||||||
" torch.version.cuda.replace(\".\",\"\"),\n",
|
" torch.version.cuda.replace(\".\",\"\"),\n",
|
||||||
" f\"_pyt{pyt_version_str}\"\n",
|
" f\"_pyt{pyt_version_str}\"\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
" !pip install fvcore iopath\n",
|
" !pip install iopath\n",
|
||||||
" if sys.platform.startswith(\"linux\"):\n",
|
" if sys.platform.startswith(\"linux\"):\n",
|
||||||
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
||||||
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
||||||
|
|||||||
@@ -79,7 +79,7 @@
|
|||||||
" torch.version.cuda.replace(\".\",\"\"),\n",
|
" torch.version.cuda.replace(\".\",\"\"),\n",
|
||||||
" f\"_pyt{pyt_version_str}\"\n",
|
" f\"_pyt{pyt_version_str}\"\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
" !pip install fvcore iopath\n",
|
" !pip install iopath\n",
|
||||||
" if sys.platform.startswith(\"linux\"):\n",
|
" if sys.platform.startswith(\"linux\"):\n",
|
||||||
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
||||||
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
||||||
|
|||||||
@@ -57,7 +57,7 @@
|
|||||||
" torch.version.cuda.replace(\".\",\"\"),\n",
|
" torch.version.cuda.replace(\".\",\"\"),\n",
|
||||||
" f\"_pyt{pyt_version_str}\"\n",
|
" f\"_pyt{pyt_version_str}\"\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
" !pip install fvcore iopath\n",
|
" !pip install iopath\n",
|
||||||
" if sys.platform.startswith(\"linux\"):\n",
|
" if sys.platform.startswith(\"linux\"):\n",
|
||||||
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
||||||
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
||||||
|
|||||||
@@ -64,7 +64,7 @@
|
|||||||
" torch.version.cuda.replace(\".\",\"\"),\n",
|
" torch.version.cuda.replace(\".\",\"\"),\n",
|
||||||
" f\"_pyt{pyt_version_str}\"\n",
|
" f\"_pyt{pyt_version_str}\"\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
" !pip install fvcore iopath\n",
|
" !pip install iopath\n",
|
||||||
" if sys.platform.startswith(\"linux\"):\n",
|
" if sys.platform.startswith(\"linux\"):\n",
|
||||||
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
||||||
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
||||||
|
|||||||
@@ -80,7 +80,7 @@
|
|||||||
" torch.version.cuda.replace(\".\",\"\"),\n",
|
" torch.version.cuda.replace(\".\",\"\"),\n",
|
||||||
" f\"_pyt{pyt_version_str}\"\n",
|
" f\"_pyt{pyt_version_str}\"\n",
|
||||||
" ])\n",
|
" ])\n",
|
||||||
" !pip install fvcore iopath\n",
|
" !pip install iopath\n",
|
||||||
" if sys.platform.startswith(\"linux\"):\n",
|
" if sys.platform.startswith(\"linux\"):\n",
|
||||||
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
" print(\"Trying to install wheel for PyTorch3D\")\n",
|
||||||
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
||||||
|
|||||||
@@ -4,10 +4,11 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import argparse
|
||||||
import os.path
|
import os.path
|
||||||
import runpy
|
import runpy
|
||||||
import subprocess
|
import subprocess
|
||||||
from typing import List
|
from typing import List, Tuple
|
||||||
|
|
||||||
# required env vars:
|
# required env vars:
|
||||||
# CU_VERSION: E.g. cu112
|
# CU_VERSION: E.g. cu112
|
||||||
@@ -23,7 +24,7 @@ pytorch_major_minor = tuple(int(i) for i in PYTORCH_VERSION.split(".")[:2])
|
|||||||
source_root_dir = os.environ["PWD"]
|
source_root_dir = os.environ["PWD"]
|
||||||
|
|
||||||
|
|
||||||
def version_constraint(version):
|
def version_constraint(version) -> str:
|
||||||
"""
|
"""
|
||||||
Given version "11.3" returns " >=11.3,<11.4"
|
Given version "11.3" returns " >=11.3,<11.4"
|
||||||
"""
|
"""
|
||||||
@@ -32,7 +33,7 @@ def version_constraint(version):
|
|||||||
return f" >={version},<{upper}"
|
return f" >={version},<{upper}"
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_major_minor():
|
def get_cuda_major_minor() -> Tuple[str, str]:
|
||||||
if CU_VERSION == "cpu":
|
if CU_VERSION == "cpu":
|
||||||
raise ValueError("fn only for cuda builds")
|
raise ValueError("fn only for cuda builds")
|
||||||
if len(CU_VERSION) != 5 or CU_VERSION[:2] != "cu":
|
if len(CU_VERSION) != 5 or CU_VERSION[:2] != "cu":
|
||||||
@@ -42,11 +43,10 @@ def get_cuda_major_minor():
|
|||||||
return major, minor
|
return major, minor
|
||||||
|
|
||||||
|
|
||||||
def setup_cuda():
|
def setup_cuda(use_conda_cuda: bool) -> List[str]:
|
||||||
if CU_VERSION == "cpu":
|
if CU_VERSION == "cpu":
|
||||||
return
|
return []
|
||||||
major, minor = get_cuda_major_minor()
|
major, minor = get_cuda_major_minor()
|
||||||
os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
|
|
||||||
os.environ["FORCE_CUDA"] = "1"
|
os.environ["FORCE_CUDA"] = "1"
|
||||||
|
|
||||||
basic_nvcc_flags = (
|
basic_nvcc_flags = (
|
||||||
@@ -75,6 +75,15 @@ def setup_cuda():
|
|||||||
|
|
||||||
if os.environ.get("JUST_TESTRUN", "0") != "1":
|
if os.environ.get("JUST_TESTRUN", "0") != "1":
|
||||||
os.environ["NVCC_FLAGS"] = nvcc_flags
|
os.environ["NVCC_FLAGS"] = nvcc_flags
|
||||||
|
if use_conda_cuda:
|
||||||
|
os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1"] = "- cuda-toolkit"
|
||||||
|
os.environ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2"] = (
|
||||||
|
f"- cuda-version={major}.{minor}"
|
||||||
|
)
|
||||||
|
return ["-c", f"nvidia/label/cuda-{major}.{minor}.0"]
|
||||||
|
else:
|
||||||
|
os.environ["CUDA_HOME"] = f"/usr/local/cuda-{major}.{minor}/"
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
def setup_conda_pytorch_constraint() -> List[str]:
|
def setup_conda_pytorch_constraint() -> List[str]:
|
||||||
@@ -95,7 +104,7 @@ def setup_conda_pytorch_constraint() -> List[str]:
|
|||||||
return ["-c", "pytorch", "-c", "nvidia"]
|
return ["-c", "pytorch", "-c", "nvidia"]
|
||||||
|
|
||||||
|
|
||||||
def setup_conda_cudatoolkit_constraint():
|
def setup_conda_cudatoolkit_constraint() -> None:
|
||||||
if CU_VERSION == "cpu":
|
if CU_VERSION == "cpu":
|
||||||
os.environ["CONDA_CPUONLY_FEATURE"] = "- cpuonly"
|
os.environ["CONDA_CPUONLY_FEATURE"] = "- cpuonly"
|
||||||
os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = ""
|
os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = ""
|
||||||
@@ -116,14 +125,14 @@ def setup_conda_cudatoolkit_constraint():
|
|||||||
os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = toolkit
|
os.environ["CONDA_CUDATOOLKIT_CONSTRAINT"] = toolkit
|
||||||
|
|
||||||
|
|
||||||
def do_build(start_args: List[str]):
|
def do_build(start_args: List[str]) -> None:
|
||||||
args = start_args.copy()
|
args = start_args.copy()
|
||||||
|
|
||||||
test_flag = os.environ.get("TEST_FLAG")
|
test_flag = os.environ.get("TEST_FLAG")
|
||||||
if test_flag is not None:
|
if test_flag is not None:
|
||||||
args.append(test_flag)
|
args.append(test_flag)
|
||||||
|
|
||||||
args.extend(["-c", "bottler", "-c", "fvcore", "-c", "iopath", "-c", "conda-forge"])
|
args.extend(["-c", "bottler", "-c", "iopath", "-c", "conda-forge"])
|
||||||
args.append("--no-anaconda-upload")
|
args.append("--no-anaconda-upload")
|
||||||
args.extend(["--python", os.environ["PYTHON_VERSION"]])
|
args.extend(["--python", os.environ["PYTHON_VERSION"]])
|
||||||
args.append("packaging/pytorch3d")
|
args.append("packaging/pytorch3d")
|
||||||
@@ -132,8 +141,16 @@ def do_build(start_args: List[str]):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Build the conda package.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-conda-cuda",
|
||||||
|
action="store_true",
|
||||||
|
help="get cuda from conda ignoring local cuda",
|
||||||
|
)
|
||||||
|
our_args = parser.parse_args()
|
||||||
|
|
||||||
args = ["conda", "build"]
|
args = ["conda", "build"]
|
||||||
setup_cuda()
|
args += setup_cuda(use_conda_cuda=our_args.use_conda_cuda)
|
||||||
|
|
||||||
init_path = source_root_dir + "/pytorch3d/__init__.py"
|
init_path = source_root_dir + "/pytorch3d/__init__.py"
|
||||||
build_version = runpy.run_path(init_path)["__version__"]
|
build_version = runpy.run_path(init_path)["__version__"]
|
||||||
|
|||||||
@@ -26,6 +26,6 @@ version_str="".join([
|
|||||||
torch.version.cuda.replace(".",""),
|
torch.version.cuda.replace(".",""),
|
||||||
f"_pyt{pyt_version_str}"
|
f"_pyt{pyt_version_str}"
|
||||||
])
|
])
|
||||||
!pip install fvcore iopath
|
!pip install iopath
|
||||||
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
|
!pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ do
|
|||||||
conda activate "$tag"
|
conda activate "$tag"
|
||||||
# shellcheck disable=SC2086
|
# shellcheck disable=SC2086
|
||||||
conda install -y -c pytorch $extra_channel "pytorch=$pytorch_version" "$cudatools=$CUDA_TAG"
|
conda install -y -c pytorch $extra_channel "pytorch=$pytorch_version" "$cudatools=$CUDA_TAG"
|
||||||
pip install fvcore iopath
|
pip install iopath
|
||||||
echo "python version" "$python_version" "pytorch version" "$pytorch_version" "cuda version" "$cu_version" "tag" "$tag"
|
echo "python version" "$python_version" "pytorch version" "$pytorch_version" "cuda version" "$cu_version" "tag" "$tag"
|
||||||
|
|
||||||
rm -rf dist
|
rm -rf dist
|
||||||
|
|||||||
@@ -8,10 +8,13 @@ source:
|
|||||||
requirements:
|
requirements:
|
||||||
build:
|
build:
|
||||||
- {{ compiler('c') }} # [win]
|
- {{ compiler('c') }} # [win]
|
||||||
|
{{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1', '') }}
|
||||||
|
{{ environ.get('CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2', '') }}
|
||||||
{{ environ.get('CONDA_CUB_CONSTRAINT') }}
|
{{ environ.get('CONDA_CUB_CONSTRAINT') }}
|
||||||
|
|
||||||
host:
|
host:
|
||||||
- python
|
- python
|
||||||
|
- mkl =2023 # [x86_64]
|
||||||
{{ environ.get('SETUPTOOLS_CONSTRAINT') }}
|
{{ environ.get('SETUPTOOLS_CONSTRAINT') }}
|
||||||
{{ environ.get('CONDA_PYTORCH_BUILD_CONSTRAINT') }}
|
{{ environ.get('CONDA_PYTORCH_BUILD_CONSTRAINT') }}
|
||||||
{{ environ.get('CONDA_PYTORCH_MKL_CONSTRAINT') }}
|
{{ environ.get('CONDA_PYTORCH_MKL_CONSTRAINT') }}
|
||||||
@@ -22,7 +25,7 @@ requirements:
|
|||||||
- python
|
- python
|
||||||
- numpy >=1.11
|
- numpy >=1.11
|
||||||
- torchvision >=0.5
|
- torchvision >=0.5
|
||||||
- fvcore
|
- mkl =2023 # [x86_64]
|
||||||
- iopath
|
- iopath
|
||||||
{{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
|
{{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
|
||||||
{{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}
|
{{ environ.get('CONDA_CUDATOOLKIT_CONSTRAINT') }}
|
||||||
@@ -48,8 +51,11 @@ test:
|
|||||||
- imageio
|
- imageio
|
||||||
- hydra-core
|
- hydra-core
|
||||||
- accelerate
|
- accelerate
|
||||||
|
- matplotlib
|
||||||
|
- tabulate
|
||||||
|
- pandas
|
||||||
|
- sqlalchemy
|
||||||
commands:
|
commands:
|
||||||
#pytest .
|
|
||||||
python -m unittest discover -v -s tests -t .
|
python -m unittest discover -v -s tests -t .
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
|
|
||||||
# pyre-unsafe
|
# pyre-unsafe
|
||||||
|
|
||||||
""""
|
""" "
|
||||||
This file is the entry point for launching experiments with Implicitron.
|
This file is the entry point for launching experiments with Implicitron.
|
||||||
|
|
||||||
Launch Training
|
Launch Training
|
||||||
@@ -44,6 +44,7 @@ The outputs of the experiment are saved and logged in multiple ways:
|
|||||||
config file.
|
config file.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
@@ -99,7 +100,7 @@ except ModuleNotFoundError:
|
|||||||
no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None
|
no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None
|
||||||
|
|
||||||
|
|
||||||
class Experiment(Configurable): # pyre-ignore: 13
|
class Experiment(Configurable):
|
||||||
"""
|
"""
|
||||||
This class is at the top level of Implicitron's config hierarchy. Its
|
This class is at the top level of Implicitron's config hierarchy. Its
|
||||||
members are high-level components necessary for training an implicit rende-
|
members are high-level components necessary for training an implicit rende-
|
||||||
@@ -120,12 +121,16 @@ class Experiment(Configurable): # pyre-ignore: 13
|
|||||||
will be saved here.
|
will be saved here.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `data_source` is never initialized.
|
||||||
data_source: DataSourceBase
|
data_source: DataSourceBase
|
||||||
data_source_class_type: str = "ImplicitronDataSource"
|
data_source_class_type: str = "ImplicitronDataSource"
|
||||||
|
# pyre-fixme[13]: Attribute `model_factory` is never initialized.
|
||||||
model_factory: ModelFactoryBase
|
model_factory: ModelFactoryBase
|
||||||
model_factory_class_type: str = "ImplicitronModelFactory"
|
model_factory_class_type: str = "ImplicitronModelFactory"
|
||||||
|
# pyre-fixme[13]: Attribute `optimizer_factory` is never initialized.
|
||||||
optimizer_factory: OptimizerFactoryBase
|
optimizer_factory: OptimizerFactoryBase
|
||||||
optimizer_factory_class_type: str = "ImplicitronOptimizerFactory"
|
optimizer_factory_class_type: str = "ImplicitronOptimizerFactory"
|
||||||
|
# pyre-fixme[13]: Attribute `training_loop` is never initialized.
|
||||||
training_loop: TrainingLoopBase
|
training_loop: TrainingLoopBase
|
||||||
training_loop_class_type: str = "ImplicitronTrainingLoop"
|
training_loop_class_type: str = "ImplicitronTrainingLoop"
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ModelFactoryBase(ReplaceableBase):
|
class ModelFactoryBase(ReplaceableBase):
|
||||||
|
|
||||||
resume: bool = True # resume from the last checkpoint
|
resume: bool = True # resume from the last checkpoint
|
||||||
|
|
||||||
def __call__(self, **kwargs) -> ImplicitronModelBase:
|
def __call__(self, **kwargs) -> ImplicitronModelBase:
|
||||||
@@ -45,7 +44,7 @@ class ModelFactoryBase(ReplaceableBase):
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
|
class ImplicitronModelFactory(ModelFactoryBase):
|
||||||
"""
|
"""
|
||||||
A factory class that initializes an implicit rendering model.
|
A factory class that initializes an implicit rendering model.
|
||||||
|
|
||||||
@@ -61,6 +60,7 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `model` is never initialized.
|
||||||
model: ImplicitronModelBase
|
model: ImplicitronModelBase
|
||||||
model_class_type: str = "GenericModel"
|
model_class_type: str = "GenericModel"
|
||||||
resume: bool = True
|
resume: bool = True
|
||||||
@@ -115,7 +115,9 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
|
|||||||
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
||||||
}
|
}
|
||||||
model_state_dict = torch.load(
|
model_state_dict = torch.load(
|
||||||
model_io.get_model_path(model_path), map_location=map_location
|
model_io.get_model_path(model_path),
|
||||||
|
map_location=map_location,
|
||||||
|
weights_only=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -123,6 +123,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
|||||||
"""
|
"""
|
||||||
# Get the parameters to optimize
|
# Get the parameters to optimize
|
||||||
if hasattr(model, "_get_param_groups"): # use the model function
|
if hasattr(model, "_get_param_groups"): # use the model function
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
|
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
|
||||||
else:
|
else:
|
||||||
p_groups = [
|
p_groups = [
|
||||||
@@ -241,7 +242,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
|||||||
map_location = {
|
map_location = {
|
||||||
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
||||||
}
|
}
|
||||||
optimizer_state = torch.load(opt_path, map_location)
|
optimizer_state = torch.load(opt_path, map_location, weights_only=True)
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
|
raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
|
||||||
return optimizer_state
|
return optimizer_state
|
||||||
|
|||||||
@@ -30,13 +30,13 @@ from .utils import seed_all_random_engines
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# pyre-fixme[13]: Attribute `evaluator` is never initialized.
|
|
||||||
class TrainingLoopBase(ReplaceableBase):
|
class TrainingLoopBase(ReplaceableBase):
|
||||||
"""
|
"""
|
||||||
Members:
|
Members:
|
||||||
evaluator: An EvaluatorBase instance, used to evaluate training results.
|
evaluator: An EvaluatorBase instance, used to evaluate training results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `evaluator` is never initialized.
|
||||||
evaluator: Optional[EvaluatorBase]
|
evaluator: Optional[EvaluatorBase]
|
||||||
evaluator_class_type: Optional[str] = "ImplicitronEvaluator"
|
evaluator_class_type: Optional[str] = "ImplicitronEvaluator"
|
||||||
|
|
||||||
@@ -161,7 +161,6 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
|||||||
for epoch in range(start_epoch, self.max_epochs):
|
for epoch in range(start_epoch, self.max_epochs):
|
||||||
# automatic new_epoch and plotting of stats at every epoch start
|
# automatic new_epoch and plotting of stats at every epoch start
|
||||||
with stats:
|
with stats:
|
||||||
|
|
||||||
# Make sure to re-seed random generators to ensure reproducibility
|
# Make sure to re-seed random generators to ensure reproducibility
|
||||||
# even after restart.
|
# even after restart.
|
||||||
seed_all_random_engines(seed + epoch)
|
seed_all_random_engines(seed + epoch)
|
||||||
@@ -395,6 +394,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):
|
|||||||
):
|
):
|
||||||
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
|
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
|
||||||
if hasattr(model, "visualize"):
|
if hasattr(model, "visualize"):
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
model.visualize(
|
model.visualize(
|
||||||
viz,
|
viz,
|
||||||
visdom_env_imgs,
|
visdom_env_imgs,
|
||||||
|
|||||||
@@ -53,12 +53,8 @@ class TestExperiment(unittest.TestCase):
|
|||||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
||||||
"JsonIndexDatasetMapProvider"
|
"JsonIndexDatasetMapProvider"
|
||||||
)
|
)
|
||||||
dataset_args = (
|
dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||||
)
|
|
||||||
dataloader_args = (
|
|
||||||
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
|
||||||
)
|
|
||||||
dataset_args.category = "skateboard"
|
dataset_args.category = "skateboard"
|
||||||
dataset_args.test_restrict_sequence_id = 0
|
dataset_args.test_restrict_sequence_id = 0
|
||||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||||
@@ -94,12 +90,8 @@ class TestExperiment(unittest.TestCase):
|
|||||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
||||||
"JsonIndexDatasetMapProvider"
|
"JsonIndexDatasetMapProvider"
|
||||||
)
|
)
|
||||||
dataset_args = (
|
dataset_args = cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
dataloader_args = cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||||
)
|
|
||||||
dataloader_args = (
|
|
||||||
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
|
||||||
)
|
|
||||||
dataset_args.category = "skateboard"
|
dataset_args.category = "skateboard"
|
||||||
dataset_args.test_restrict_sequence_id = 0
|
dataset_args.test_restrict_sequence_id = 0
|
||||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||||
@@ -111,9 +103,7 @@ class TestExperiment(unittest.TestCase):
|
|||||||
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
|
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
|
||||||
cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
|
cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
|
||||||
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.lr_policy = "Exponential"
|
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.lr_policy = "Exponential"
|
||||||
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = (
|
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = 2
|
||||||
2
|
|
||||||
)
|
|
||||||
|
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
experiment.dump_cfg(cfg)
|
experiment.dump_cfg(cfg)
|
||||||
|
|||||||
@@ -81,8 +81,9 @@ class TestOptimizerFactory(unittest.TestCase):
|
|||||||
|
|
||||||
def test_param_overrides_self_param_group_assignment(self):
|
def test_param_overrides_self_param_group_assignment(self):
|
||||||
pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
|
pa, pb, pc = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(3)]
|
||||||
na, nb = Node(params=[pa]), Node(
|
na, nb = (
|
||||||
params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}
|
Node(params=[pa]),
|
||||||
|
Node(params=[pb], param_groups={"self": "pb_self", "p1": "pb_param"}),
|
||||||
)
|
)
|
||||||
root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
|
root = Node(children=[na, nb], params=[pc], param_groups={"m1": "pb_member"})
|
||||||
param_groups = self._get_param_groups(root)
|
param_groups = self._get_param_groups(root)
|
||||||
|
|||||||
@@ -84,9 +84,9 @@ def get_nerf_datasets(
|
|||||||
|
|
||||||
if autodownload and any(not os.path.isfile(p) for p in (cameras_path, image_path)):
|
if autodownload and any(not os.path.isfile(p) for p in (cameras_path, image_path)):
|
||||||
# Automatically download the data files if missing.
|
# Automatically download the data files if missing.
|
||||||
download_data((dataset_name,), data_root=data_root)
|
download_data([dataset_name], data_root=data_root)
|
||||||
|
|
||||||
train_data = torch.load(cameras_path)
|
train_data = torch.load(cameras_path, weights_only=True)
|
||||||
n_cameras = train_data["cameras"]["R"].shape[0]
|
n_cameras = train_data["cameras"]["R"].shape[0]
|
||||||
|
|
||||||
_image_max_image_pixels = Image.MAX_IMAGE_PIXELS
|
_image_max_image_pixels = Image.MAX_IMAGE_PIXELS
|
||||||
|
|||||||
@@ -194,7 +194,6 @@ class Stats:
|
|||||||
it = self.it[stat_set]
|
it = self.it[stat_set]
|
||||||
|
|
||||||
for stat in self.log_vars:
|
for stat in self.log_vars:
|
||||||
|
|
||||||
if stat not in self.stats[stat_set]:
|
if stat not in self.stats[stat_set]:
|
||||||
self.stats[stat_set][stat] = AverageMeter()
|
self.stats[stat_set][stat] = AverageMeter()
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"
|
|||||||
|
|
||||||
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
|
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
|
||||||
def main(cfg: DictConfig):
|
def main(cfg: DictConfig):
|
||||||
|
|
||||||
# Device on which to run.
|
# Device on which to run.
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = "cuda"
|
device = "cuda"
|
||||||
@@ -63,7 +62,7 @@ def main(cfg: DictConfig):
|
|||||||
raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")
|
raise ValueError(f"Model checkpoint {checkpoint_path} does not exist!")
|
||||||
|
|
||||||
print(f"Loading checkpoint {checkpoint_path}.")
|
print(f"Loading checkpoint {checkpoint_path}.")
|
||||||
loaded_data = torch.load(checkpoint_path)
|
loaded_data = torch.load(checkpoint_path, weights_only=True)
|
||||||
# Do not load the cached xy grid.
|
# Do not load the cached xy grid.
|
||||||
# - this allows setting an arbitrary evaluation image size.
|
# - this allows setting an arbitrary evaluation image size.
|
||||||
state_dict = {
|
state_dict = {
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ class TestRaysampler(unittest.TestCase):
|
|||||||
cameras, rays = [], []
|
cameras, rays = [], []
|
||||||
|
|
||||||
for _ in range(batch_size):
|
for _ in range(batch_size):
|
||||||
|
|
||||||
R = random_rotations(1)
|
R = random_rotations(1)
|
||||||
T = torch.randn(1, 3)
|
T = torch.randn(1, 3)
|
||||||
focal_length = torch.rand(1, 2) + 0.5
|
focal_length = torch.rand(1, 2) + 0.5
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ CONFIG_DIR = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs"
|
|||||||
|
|
||||||
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
|
@hydra.main(config_path=CONFIG_DIR, config_name="lego")
|
||||||
def main(cfg: DictConfig):
|
def main(cfg: DictConfig):
|
||||||
|
|
||||||
# Set the relevant seeds for reproducibility.
|
# Set the relevant seeds for reproducibility.
|
||||||
np.random.seed(cfg.seed)
|
np.random.seed(cfg.seed)
|
||||||
torch.manual_seed(cfg.seed)
|
torch.manual_seed(cfg.seed)
|
||||||
@@ -77,7 +76,7 @@ def main(cfg: DictConfig):
|
|||||||
# Resume training if requested.
|
# Resume training if requested.
|
||||||
if cfg.resume and os.path.isfile(checkpoint_path):
|
if cfg.resume and os.path.isfile(checkpoint_path):
|
||||||
print(f"Resuming from checkpoint {checkpoint_path}.")
|
print(f"Resuming from checkpoint {checkpoint_path}.")
|
||||||
loaded_data = torch.load(checkpoint_path)
|
loaded_data = torch.load(checkpoint_path, weights_only=True)
|
||||||
model.load_state_dict(loaded_data["model"])
|
model.load_state_dict(loaded_data["model"])
|
||||||
stats = pickle.loads(loaded_data["stats"])
|
stats = pickle.loads(loaded_data["stats"])
|
||||||
print(f" => resuming from epoch {stats.epoch}.")
|
print(f" => resuming from epoch {stats.epoch}.")
|
||||||
@@ -219,7 +218,6 @@ def main(cfg: DictConfig):
|
|||||||
|
|
||||||
# Validation
|
# Validation
|
||||||
if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
|
if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:
|
||||||
|
|
||||||
# Sample a validation camera/image.
|
# Sample a validation camera/image.
|
||||||
val_batch = next(val_dataloader.__iter__())
|
val_batch = next(val_dataloader.__iter__())
|
||||||
val_image, val_camera, camera_idx = val_batch[0].values()
|
val_image, val_camera, camera_idx = val_batch[0].values()
|
||||||
|
|||||||
@@ -6,4 +6,4 @@
|
|||||||
|
|
||||||
# pyre-unsafe
|
# pyre-unsafe
|
||||||
|
|
||||||
__version__ = "0.7.7"
|
__version__ = "0.7.8"
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ Some functions which depend on PyTorch or Python versions.
|
|||||||
|
|
||||||
|
|
||||||
def meshgrid_ij(
|
def meshgrid_ij(
|
||||||
*A: Union[torch.Tensor, Sequence[torch.Tensor]]
|
*A: Union[torch.Tensor, Sequence[torch.Tensor]],
|
||||||
) -> Tuple[torch.Tensor, ...]: # pragma: no cover
|
) -> Tuple[torch.Tensor, ...]: # pragma: no cover
|
||||||
"""
|
"""
|
||||||
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
|
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
|
||||||
|
|||||||
@@ -7,7 +7,6 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <queue>
|
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
|
|
||||||
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ __global__ void alphaCompositeCudaForwardKernel(
|
|||||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||||
// clang-format on
|
// clang-format on
|
||||||
const int64_t batch_size = result.size(0);
|
|
||||||
const int64_t C = features.size(0);
|
const int64_t C = features.size(0);
|
||||||
const int64_t H = points_idx.size(2);
|
const int64_t H = points_idx.size(2);
|
||||||
const int64_t W = points_idx.size(3);
|
const int64_t W = points_idx.size(3);
|
||||||
@@ -79,7 +78,6 @@ __global__ void alphaCompositeCudaBackwardKernel(
|
|||||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||||
// clang-format on
|
// clang-format on
|
||||||
const int64_t batch_size = points_idx.size(0);
|
|
||||||
const int64_t C = features.size(0);
|
const int64_t C = features.size(0);
|
||||||
const int64_t H = points_idx.size(2);
|
const int64_t H = points_idx.size(2);
|
||||||
const int64_t W = points_idx.size(3);
|
const int64_t W = points_idx.size(3);
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ __global__ void weightedSumNormCudaForwardKernel(
|
|||||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||||
// clang-format on
|
// clang-format on
|
||||||
const int64_t batch_size = result.size(0);
|
|
||||||
const int64_t C = features.size(0);
|
const int64_t C = features.size(0);
|
||||||
const int64_t H = points_idx.size(2);
|
const int64_t H = points_idx.size(2);
|
||||||
const int64_t W = points_idx.size(3);
|
const int64_t W = points_idx.size(3);
|
||||||
@@ -92,7 +91,6 @@ __global__ void weightedSumNormCudaBackwardKernel(
|
|||||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||||
// clang-format on
|
// clang-format on
|
||||||
const int64_t batch_size = points_idx.size(0);
|
|
||||||
const int64_t C = features.size(0);
|
const int64_t C = features.size(0);
|
||||||
const int64_t H = points_idx.size(2);
|
const int64_t H = points_idx.size(2);
|
||||||
const int64_t W = points_idx.size(3);
|
const int64_t W = points_idx.size(3);
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ __global__ void weightedSumCudaForwardKernel(
|
|||||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||||
// clang-format on
|
// clang-format on
|
||||||
const int64_t batch_size = result.size(0);
|
|
||||||
const int64_t C = features.size(0);
|
const int64_t C = features.size(0);
|
||||||
const int64_t H = points_idx.size(2);
|
const int64_t H = points_idx.size(2);
|
||||||
const int64_t W = points_idx.size(3);
|
const int64_t W = points_idx.size(3);
|
||||||
@@ -74,7 +73,6 @@ __global__ void weightedSumCudaBackwardKernel(
|
|||||||
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
const at::PackedTensorAccessor64<float, 4, at::RestrictPtrTraits> alphas,
|
||||||
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
const at::PackedTensorAccessor64<int64_t, 4, at::RestrictPtrTraits> points_idx) {
|
||||||
// clang-format on
|
// clang-format on
|
||||||
const int64_t batch_size = points_idx.size(0);
|
|
||||||
const int64_t C = features.size(0);
|
const int64_t C = features.size(0);
|
||||||
const int64_t H = points_idx.size(2);
|
const int64_t H = points_idx.size(2);
|
||||||
const int64_t W = points_idx.size(3);
|
const int64_t W = points_idx.size(3);
|
||||||
|
|||||||
@@ -99,6 +99,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
m.def("marching_cubes", &MarchingCubes);
|
m.def("marching_cubes", &MarchingCubes);
|
||||||
|
|
||||||
// Pulsar.
|
// Pulsar.
|
||||||
|
// Pulsar not enabled on AMD.
|
||||||
#ifdef PULSAR_LOGGING_ENABLED
|
#ifdef PULSAR_LOGGING_ENABLED
|
||||||
c10::ShowLogInfoToStderr();
|
c10::ShowLogInfoToStderr();
|
||||||
#endif
|
#endif
|
||||||
@@ -148,10 +149,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|||||||
py::arg("gamma"),
|
py::arg("gamma"),
|
||||||
py::arg("max_depth"),
|
py::arg("max_depth"),
|
||||||
py::arg("min_depth") /* = 0.f*/,
|
py::arg("min_depth") /* = 0.f*/,
|
||||||
py::arg(
|
py::arg("bg_col") /* = std::nullopt not exposed properly in
|
||||||
"bg_col") /* = at::nullopt not exposed properly in pytorch 1.1. */
|
pytorch 1.1. */
|
||||||
,
|
,
|
||||||
py::arg("opacity") /* = at::nullopt ... */,
|
py::arg("opacity") /* = std::nullopt ... */,
|
||||||
py::arg("percent_allowed_difference") = 0.01f,
|
py::arg("percent_allowed_difference") = 0.01f,
|
||||||
py::arg("max_n_hits") = MAX_UINT,
|
py::arg("max_n_hits") = MAX_UINT,
|
||||||
py::arg("mode") = 0)
|
py::arg("mode") = 0)
|
||||||
|
|||||||
@@ -7,10 +7,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <torch/torch.h>
|
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <numeric>
|
|
||||||
#include <queue>
|
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include "iou_box3d/iou_utils.h"
|
#include "iou_box3d/iou_utils.h"
|
||||||
|
|
||||||
|
|||||||
@@ -461,10 +461,8 @@ __device__ inline std::tuple<float3, float3> ArgMaxVerts(
|
|||||||
__device__ inline bool IsCoplanarTriTri(
|
__device__ inline bool IsCoplanarTriTri(
|
||||||
const FaceVerts& tri1,
|
const FaceVerts& tri1,
|
||||||
const FaceVerts& tri2) {
|
const FaceVerts& tri2) {
|
||||||
const float3 tri1_ctr = FaceCenter({tri1.v0, tri1.v1, tri1.v2});
|
|
||||||
const float3 tri1_n = FaceNormal({tri1.v0, tri1.v1, tri1.v2});
|
const float3 tri1_n = FaceNormal({tri1.v0, tri1.v1, tri1.v2});
|
||||||
|
|
||||||
const float3 tri2_ctr = FaceCenter({tri2.v0, tri2.v1, tri2.v2});
|
|
||||||
const float3 tri2_n = FaceNormal({tri2.v0, tri2.v1, tri2.v2});
|
const float3 tri2_n = FaceNormal({tri2.v0, tri2.v1, tri2.v2});
|
||||||
|
|
||||||
// Check if parallel
|
// Check if parallel
|
||||||
@@ -500,7 +498,6 @@ __device__ inline bool IsCoplanarTriPlane(
|
|||||||
const FaceVerts& tri,
|
const FaceVerts& tri,
|
||||||
const FaceVerts& plane,
|
const FaceVerts& plane,
|
||||||
const float3& normal) {
|
const float3& normal) {
|
||||||
const float3 tri_ctr = FaceCenter({tri.v0, tri.v1, tri.v2});
|
|
||||||
const float3 nt = FaceNormal({tri.v0, tri.v1, tri.v2});
|
const float3 nt = FaceNormal({tri.v0, tri.v1, tri.v2});
|
||||||
|
|
||||||
// check if parallel
|
// check if parallel
|
||||||
@@ -728,7 +725,7 @@ __device__ inline int BoxIntersections(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Update the face_verts_out tris
|
// Update the face_verts_out tris
|
||||||
num_tris = offset;
|
num_tris = min(MAX_TRIS, offset);
|
||||||
for (int j = 0; j < num_tris; ++j) {
|
for (int j = 0; j < num_tris; ++j) {
|
||||||
face_verts_out[j] = tri_verts_updated[j];
|
face_verts_out[j] = tri_verts_updated[j];
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,9 +8,7 @@
|
|||||||
|
|
||||||
#include <torch/csrc/autograd/VariableTypeUtils.h>
|
#include <torch/csrc/autograd/VariableTypeUtils.h>
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <algorithm>
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <thread>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// In the x direction, the location {0, ..., grid_size_x - 1} correspond to
|
// In the x direction, the location {0, ..., grid_size_x - 1} correspond to
|
||||||
|
|||||||
@@ -36,11 +36,13 @@
|
|||||||
#pragma nv_diag_suppress 2951
|
#pragma nv_diag_suppress 2951
|
||||||
#pragma nv_diag_suppress 2967
|
#pragma nv_diag_suppress 2967
|
||||||
#else
|
#else
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
#pragma diag_suppress = attribute_not_allowed
|
#pragma diag_suppress = attribute_not_allowed
|
||||||
#pragma diag_suppress = 1866
|
#pragma diag_suppress = 1866
|
||||||
#pragma diag_suppress = 2941
|
#pragma diag_suppress = 2941
|
||||||
#pragma diag_suppress = 2951
|
#pragma diag_suppress = 2951
|
||||||
#pragma diag_suppress = 2967
|
#pragma diag_suppress = 2967
|
||||||
|
#endif //! USE_ROCM
|
||||||
#endif
|
#endif
|
||||||
#else // __CUDACC__
|
#else // __CUDACC__
|
||||||
#define INLINE inline
|
#define INLINE inline
|
||||||
@@ -56,7 +58,9 @@
|
|||||||
#pragma clang diagnostic pop
|
#pragma clang diagnostic pop
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
#include <vector_functions.h>
|
#include <vector_functions.h>
|
||||||
|
#endif //! USE_ROCM
|
||||||
#else
|
#else
|
||||||
#ifndef cudaStream_t
|
#ifndef cudaStream_t
|
||||||
typedef void* cudaStream_t;
|
typedef void* cudaStream_t;
|
||||||
|
|||||||
@@ -59,6 +59,11 @@ getLastCudaError(const char* errorMessage, const char* file, const int line) {
|
|||||||
#define SHARED __shared__
|
#define SHARED __shared__
|
||||||
#define ACTIVEMASK() __activemask()
|
#define ACTIVEMASK() __activemask()
|
||||||
#define BALLOT(mask, val) __ballot_sync((mask), val)
|
#define BALLOT(mask, val) __ballot_sync((mask), val)
|
||||||
|
|
||||||
|
/* TODO (ROCM-6.2): None of the WARP_* are used anywhere and ROCM-6.2 natively
|
||||||
|
* supports __shfl_*. Disabling until the move to ROCM-6.2.
|
||||||
|
*/
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
/**
|
/**
|
||||||
* Find the cumulative sum within a warp up to the current
|
* Find the cumulative sum within a warp up to the current
|
||||||
* thread lane, with each mask thread contributing base.
|
* thread lane, with each mask thread contributing base.
|
||||||
@@ -115,6 +120,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
|
|||||||
ret.z = WARP_SUM(group, mask, base.z);
|
ret.z = WARP_SUM(group, mask, base.z);
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
#endif //! USE_ROCM
|
||||||
|
|
||||||
// Floating point.
|
// Floating point.
|
||||||
// #define FMUL(a, b) __fmul_rn((a), (b))
|
// #define FMUL(a, b) __fmul_rn((a), (b))
|
||||||
@@ -142,6 +148,7 @@ INLINE DEVICE float3 WARP_SUM_FLOAT3(
|
|||||||
#define FMA(x, y, z) __fmaf_rn((x), (y), (z))
|
#define FMA(x, y, z) __fmaf_rn((x), (y), (z))
|
||||||
#define I2F(a) __int2float_rn(a)
|
#define I2F(a) __int2float_rn(a)
|
||||||
#define FRCP(x) __frcp_rn(x)
|
#define FRCP(x) __frcp_rn(x)
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__device__ static float atomicMax(float* address, float val) {
|
__device__ static float atomicMax(float* address, float val) {
|
||||||
int* address_as_i = (int*)address;
|
int* address_as_i = (int*)address;
|
||||||
int old = *address_as_i, assumed;
|
int old = *address_as_i, assumed;
|
||||||
@@ -166,6 +173,7 @@ __device__ static float atomicMin(float* address, float val) {
|
|||||||
} while (assumed != old);
|
} while (assumed != old);
|
||||||
return __int_as_float(old);
|
return __int_as_float(old);
|
||||||
}
|
}
|
||||||
|
#endif //! USE_ROCM
|
||||||
#define DMAX(a, b) FMAX(a, b)
|
#define DMAX(a, b) FMAX(a, b)
|
||||||
#define DMIN(a, b) FMIN(a, b)
|
#define DMIN(a, b) FMIN(a, b)
|
||||||
#define DSQRT(a) sqrt(a)
|
#define DSQRT(a) sqrt(a)
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
#include "./commands.h"
|
#include "./commands.h"
|
||||||
|
|
||||||
namespace pulsar {
|
namespace pulsar {
|
||||||
IHD CamGradInfo::CamGradInfo() {
|
IHD CamGradInfo::CamGradInfo(int x) {
|
||||||
cam_pos = make_float3(0.f, 0.f, 0.f);
|
cam_pos = make_float3(0.f, 0.f, 0.f);
|
||||||
pixel_0_0_center = make_float3(0.f, 0.f, 0.f);
|
pixel_0_0_center = make_float3(0.f, 0.f, 0.f);
|
||||||
pixel_dir_x = make_float3(0.f, 0.f, 0.f);
|
pixel_dir_x = make_float3(0.f, 0.f, 0.f);
|
||||||
|
|||||||
@@ -63,7 +63,7 @@ inline bool operator==(const CamInfo& a, const CamInfo& b) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct CamGradInfo {
|
struct CamGradInfo {
|
||||||
HOST DEVICE CamGradInfo();
|
HOST DEVICE CamGradInfo(int = 0);
|
||||||
float3 cam_pos;
|
float3 cam_pos;
|
||||||
float3 pixel_0_0_center;
|
float3 pixel_0_0_center;
|
||||||
float3 pixel_dir_x;
|
float3 pixel_dir_x;
|
||||||
|
|||||||
@@ -24,7 +24,7 @@
|
|||||||
// #pragma diag_suppress = 68
|
// #pragma diag_suppress = 68
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
// #pragma pop
|
// #pragma pop
|
||||||
#include "../cuda/commands.h"
|
#include "../gpu/commands.h"
|
||||||
#else
|
#else
|
||||||
#pragma clang diagnostic push
|
#pragma clang diagnostic push
|
||||||
#pragma clang diagnostic ignored "-Weverything"
|
#pragma clang diagnostic ignored "-Weverything"
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ IHD float3 outer_product_sum(const float3& a) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: put intrinsics here.
|
// TODO: put intrinsics here.
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
IHD float3 operator+(const float3& a, const float3& b) {
|
IHD float3 operator+(const float3& a, const float3& b) {
|
||||||
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
||||||
}
|
}
|
||||||
@@ -93,6 +94,7 @@ IHD float3 operator*(const float3& a, const float3& b) {
|
|||||||
IHD float3 operator*(const float& a, const float3& b) {
|
IHD float3 operator*(const float& a, const float3& b) {
|
||||||
return b * a;
|
return b * a;
|
||||||
}
|
}
|
||||||
|
#endif //! USE_ROCM
|
||||||
|
|
||||||
INLINE DEVICE float length(const float3& v) {
|
INLINE DEVICE float length(const float3& v) {
|
||||||
// TODO: benchmark what's faster.
|
// TODO: benchmark what's faster.
|
||||||
|
|||||||
@@ -283,9 +283,15 @@ GLOBAL void render(
|
|||||||
(percent_allowed_difference > 0.f &&
|
(percent_allowed_difference > 0.f &&
|
||||||
max_closest_possible_intersection > depth_threshold) ||
|
max_closest_possible_intersection > depth_threshold) ||
|
||||||
tracker.get_n_hits() >= max_n_hits;
|
tracker.get_n_hits() >= max_n_hits;
|
||||||
|
#if defined(__CUDACC__) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
unsigned long long warp_done = __ballot(done);
|
||||||
|
int warp_done_bit_cnt = __popcll(warp_done);
|
||||||
|
#else
|
||||||
uint warp_done = thread_warp.ballot(done);
|
uint warp_done = thread_warp.ballot(done);
|
||||||
|
int warp_done_bit_cnt = POPC(warp_done);
|
||||||
|
#endif //__CUDACC__ && __HIP_PLATFORM_AMD__
|
||||||
if (thread_warp.thread_rank() == 0)
|
if (thread_warp.thread_rank() == 0)
|
||||||
ATOMICADD_B(&n_pixels_done, POPC(warp_done));
|
ATOMICADD_B(&n_pixels_done, warp_done_bit_cnt);
|
||||||
// This sync is necessary to keep n_loaded until all threads are done with
|
// This sync is necessary to keep n_loaded until all threads are done with
|
||||||
// painting.
|
// painting.
|
||||||
thread_block.sync();
|
thread_block.sync();
|
||||||
|
|||||||
@@ -213,8 +213,8 @@ std::tuple<size_t, size_t, bool, torch::Tensor> Renderer::arg_check(
|
|||||||
const float& gamma,
|
const float& gamma,
|
||||||
const float& max_depth,
|
const float& max_depth,
|
||||||
float& min_depth,
|
float& min_depth,
|
||||||
const c10::optional<torch::Tensor>& bg_col,
|
const std::optional<torch::Tensor>& bg_col,
|
||||||
const c10::optional<torch::Tensor>& opacity,
|
const std::optional<torch::Tensor>& opacity,
|
||||||
const float& percent_allowed_difference,
|
const float& percent_allowed_difference,
|
||||||
const uint& max_n_hits,
|
const uint& max_n_hits,
|
||||||
const uint& mode) {
|
const uint& mode) {
|
||||||
@@ -668,8 +668,8 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
|
|||||||
const float& gamma,
|
const float& gamma,
|
||||||
const float& max_depth,
|
const float& max_depth,
|
||||||
float min_depth,
|
float min_depth,
|
||||||
const c10::optional<torch::Tensor>& bg_col,
|
const std::optional<torch::Tensor>& bg_col,
|
||||||
const c10::optional<torch::Tensor>& opacity,
|
const std::optional<torch::Tensor>& opacity,
|
||||||
const float& percent_allowed_difference,
|
const float& percent_allowed_difference,
|
||||||
const uint& max_n_hits,
|
const uint& max_n_hits,
|
||||||
const uint& mode) {
|
const uint& mode) {
|
||||||
@@ -888,14 +888,14 @@ std::tuple<torch::Tensor, torch::Tensor> Renderer::forward(
|
|||||||
};
|
};
|
||||||
|
|
||||||
std::tuple<
|
std::tuple<
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>>
|
std::optional<torch::Tensor>>
|
||||||
Renderer::backward(
|
Renderer::backward(
|
||||||
const torch::Tensor& grad_im,
|
const torch::Tensor& grad_im,
|
||||||
const torch::Tensor& image,
|
const torch::Tensor& image,
|
||||||
@@ -912,8 +912,8 @@ Renderer::backward(
|
|||||||
const float& gamma,
|
const float& gamma,
|
||||||
const float& max_depth,
|
const float& max_depth,
|
||||||
float min_depth,
|
float min_depth,
|
||||||
const c10::optional<torch::Tensor>& bg_col,
|
const std::optional<torch::Tensor>& bg_col,
|
||||||
const c10::optional<torch::Tensor>& opacity,
|
const std::optional<torch::Tensor>& opacity,
|
||||||
const float& percent_allowed_difference,
|
const float& percent_allowed_difference,
|
||||||
const uint& max_n_hits,
|
const uint& max_n_hits,
|
||||||
const uint& mode,
|
const uint& mode,
|
||||||
@@ -922,7 +922,7 @@ Renderer::backward(
|
|||||||
const bool& dif_rad,
|
const bool& dif_rad,
|
||||||
const bool& dif_cam,
|
const bool& dif_cam,
|
||||||
const bool& dif_opy,
|
const bool& dif_opy,
|
||||||
const at::optional<std::pair<uint, uint>>& dbg_pos) {
|
const std::optional<std::pair<uint, uint>>& dbg_pos) {
|
||||||
this->ensure_on_device(this->device_tracker.device());
|
this->ensure_on_device(this->device_tracker.device());
|
||||||
size_t batch_size;
|
size_t batch_size;
|
||||||
size_t n_points;
|
size_t n_points;
|
||||||
@@ -1045,14 +1045,14 @@ Renderer::backward(
|
|||||||
}
|
}
|
||||||
// Prepare the return value.
|
// Prepare the return value.
|
||||||
std::tuple<
|
std::tuple<
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>>
|
std::optional<torch::Tensor>>
|
||||||
ret;
|
ret;
|
||||||
if (mode == 1 || (!dif_pos && !dif_col && !dif_rad && !dif_cam && !dif_opy)) {
|
if (mode == 1 || (!dif_pos && !dif_col && !dif_rad && !dif_cam && !dif_opy)) {
|
||||||
return ret;
|
return ret;
|
||||||
|
|||||||
@@ -44,21 +44,21 @@ struct Renderer {
|
|||||||
const float& gamma,
|
const float& gamma,
|
||||||
const float& max_depth,
|
const float& max_depth,
|
||||||
float min_depth,
|
float min_depth,
|
||||||
const c10::optional<torch::Tensor>& bg_col,
|
const std::optional<torch::Tensor>& bg_col,
|
||||||
const c10::optional<torch::Tensor>& opacity,
|
const std::optional<torch::Tensor>& opacity,
|
||||||
const float& percent_allowed_difference,
|
const float& percent_allowed_difference,
|
||||||
const uint& max_n_hits,
|
const uint& max_n_hits,
|
||||||
const uint& mode);
|
const uint& mode);
|
||||||
|
|
||||||
std::tuple<
|
std::tuple<
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>,
|
std::optional<torch::Tensor>,
|
||||||
at::optional<torch::Tensor>>
|
std::optional<torch::Tensor>>
|
||||||
backward(
|
backward(
|
||||||
const torch::Tensor& grad_im,
|
const torch::Tensor& grad_im,
|
||||||
const torch::Tensor& image,
|
const torch::Tensor& image,
|
||||||
@@ -75,8 +75,8 @@ struct Renderer {
|
|||||||
const float& gamma,
|
const float& gamma,
|
||||||
const float& max_depth,
|
const float& max_depth,
|
||||||
float min_depth,
|
float min_depth,
|
||||||
const c10::optional<torch::Tensor>& bg_col,
|
const std::optional<torch::Tensor>& bg_col,
|
||||||
const c10::optional<torch::Tensor>& opacity,
|
const std::optional<torch::Tensor>& opacity,
|
||||||
const float& percent_allowed_difference,
|
const float& percent_allowed_difference,
|
||||||
const uint& max_n_hits,
|
const uint& max_n_hits,
|
||||||
const uint& mode,
|
const uint& mode,
|
||||||
@@ -85,7 +85,7 @@ struct Renderer {
|
|||||||
const bool& dif_rad,
|
const bool& dif_rad,
|
||||||
const bool& dif_cam,
|
const bool& dif_cam,
|
||||||
const bool& dif_opy,
|
const bool& dif_opy,
|
||||||
const at::optional<std::pair<uint, uint>>& dbg_pos);
|
const std::optional<std::pair<uint, uint>>& dbg_pos);
|
||||||
|
|
||||||
// Infrastructure.
|
// Infrastructure.
|
||||||
/**
|
/**
|
||||||
@@ -115,8 +115,8 @@ struct Renderer {
|
|||||||
const float& gamma,
|
const float& gamma,
|
||||||
const float& max_depth,
|
const float& max_depth,
|
||||||
float& min_depth,
|
float& min_depth,
|
||||||
const c10::optional<torch::Tensor>& bg_col,
|
const std::optional<torch::Tensor>& bg_col,
|
||||||
const c10::optional<torch::Tensor>& opacity,
|
const std::optional<torch::Tensor>& opacity,
|
||||||
const float& percent_allowed_difference,
|
const float& percent_allowed_difference,
|
||||||
const uint& max_n_hits,
|
const uint& max_n_hits,
|
||||||
const uint& mode);
|
const uint& mode);
|
||||||
|
|||||||
@@ -8,6 +8,7 @@
|
|||||||
|
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
#include <ATen/cuda/CUDAContext.h>
|
||||||
|
#include <c10/cuda/CUDAException.h>
|
||||||
#include <cuda_runtime_api.h>
|
#include <cuda_runtime_api.h>
|
||||||
#endif
|
#endif
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
@@ -33,13 +34,13 @@ torch::Tensor sphere_ids_from_result_info_nograd(
|
|||||||
.contiguous();
|
.contiguous();
|
||||||
if (forw_info.device().type() == c10::DeviceType::CUDA) {
|
if (forw_info.device().type() == c10::DeviceType::CUDA) {
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
cudaMemcpyAsync(
|
C10_CUDA_CHECK(cudaMemcpyAsync(
|
||||||
result.data_ptr(),
|
result.data_ptr(),
|
||||||
tmp.data_ptr(),
|
tmp.data_ptr(),
|
||||||
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
|
sizeof(uint32_t) * tmp.size(0) * tmp.size(1) * tmp.size(2) *
|
||||||
tmp.size(3),
|
tmp.size(3),
|
||||||
cudaMemcpyDeviceToDevice,
|
cudaMemcpyDeviceToDevice,
|
||||||
at::cuda::getCurrentCUDAStream());
|
at::cuda::getCurrentCUDAStream()));
|
||||||
#else
|
#else
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"Copy on CUDA device initiated but built "
|
"Copy on CUDA device initiated but built "
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
#ifdef WITH_CUDA
|
#ifdef WITH_CUDA
|
||||||
|
#include <c10/cuda/CUDAException.h>
|
||||||
#include <cuda_runtime_api.h>
|
#include <cuda_runtime_api.h>
|
||||||
|
|
||||||
namespace pulsar {
|
namespace pulsar {
|
||||||
@@ -17,7 +18,8 @@ void cudaDevToDev(
|
|||||||
const void* src,
|
const void* src,
|
||||||
const int& size,
|
const int& size,
|
||||||
const cudaStream_t& stream) {
|
const cudaStream_t& stream) {
|
||||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream);
|
C10_CUDA_CHECK(
|
||||||
|
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToDevice, stream));
|
||||||
}
|
}
|
||||||
|
|
||||||
void cudaDevToHost(
|
void cudaDevToHost(
|
||||||
@@ -25,7 +27,8 @@ void cudaDevToHost(
|
|||||||
const void* src,
|
const void* src,
|
||||||
const int& size,
|
const int& size,
|
||||||
const cudaStream_t& stream) {
|
const cudaStream_t& stream) {
|
||||||
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream);
|
C10_CUDA_CHECK(
|
||||||
|
cudaMemcpyAsync(trg, src, size, cudaMemcpyDeviceToHost, stream));
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace pytorch
|
} // namespace pytorch
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ __device__ void CheckPixelInsideFace(
|
|||||||
const bool zero_face_area =
|
const bool zero_face_area =
|
||||||
(face_area <= kEpsilon && face_area >= -1.0f * kEpsilon);
|
(face_area <= kEpsilon && face_area >= -1.0f * kEpsilon);
|
||||||
|
|
||||||
if (zmax < 0 || cull_backfaces && back_face || outside_bbox ||
|
if (zmax < 0 || (cull_backfaces && back_face) || outside_bbox ||
|
||||||
zero_face_area) {
|
zero_face_area) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,7 +9,6 @@
|
|||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <list>
|
#include <list>
|
||||||
#include <queue>
|
|
||||||
#include <thread>
|
#include <thread>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include "ATen/core/TensorAccessor.h"
|
#include "ATen/core/TensorAccessor.h"
|
||||||
|
|||||||
@@ -35,8 +35,6 @@ __global__ void FarthestPointSamplingKernel(
|
|||||||
__shared__ int64_t selected_store;
|
__shared__ int64_t selected_store;
|
||||||
|
|
||||||
// Get constants
|
// Get constants
|
||||||
const int64_t N = points.size(0);
|
|
||||||
const int64_t P = points.size(1);
|
|
||||||
const int64_t D = points.size(2);
|
const int64_t D = points.size(2);
|
||||||
|
|
||||||
// Get batch index and thread index
|
// Get batch index and thread index
|
||||||
|
|||||||
@@ -18,6 +18,8 @@ const auto vEpsilon = 1e-8;
|
|||||||
|
|
||||||
// Common functions and operators for float2.
|
// Common functions and operators for float2.
|
||||||
|
|
||||||
|
// Complex arithmetic is already defined for AMD.
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__device__ inline float2 operator-(const float2& a, const float2& b) {
|
__device__ inline float2 operator-(const float2& a, const float2& b) {
|
||||||
return make_float2(a.x - b.x, a.y - b.y);
|
return make_float2(a.x - b.x, a.y - b.y);
|
||||||
}
|
}
|
||||||
@@ -41,6 +43,7 @@ __device__ inline float2 operator*(const float2& a, const float2& b) {
|
|||||||
__device__ inline float2 operator*(const float a, const float2& b) {
|
__device__ inline float2 operator*(const float a, const float2& b) {
|
||||||
return make_float2(a * b.x, a * b.y);
|
return make_float2(a * b.x, a * b.y);
|
||||||
}
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
__device__ inline float FloatMin3(const float a, const float b, const float c) {
|
__device__ inline float FloatMin3(const float a, const float b, const float c) {
|
||||||
return fminf(a, fminf(b, c));
|
return fminf(a, fminf(b, c));
|
||||||
|
|||||||
@@ -376,8 +376,6 @@ PointLineDistanceBackward(
|
|||||||
float tt = t_top / t_bot;
|
float tt = t_top / t_bot;
|
||||||
tt = __saturatef(tt);
|
tt = __saturatef(tt);
|
||||||
const float2 p_proj = (1.0f - tt) * v0 + tt * v1;
|
const float2 p_proj = (1.0f - tt) * v0 + tt * v1;
|
||||||
const float2 d = p - p_proj;
|
|
||||||
const float dist = sqrt(dot(d, d));
|
|
||||||
|
|
||||||
const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
|
const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
|
||||||
const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
|
const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
|
||||||
|
|||||||
@@ -23,37 +23,51 @@ WarpReduceMin(scalar_t* min_dists, int64_t* min_idxs, const size_t tid) {
|
|||||||
min_idxs[tid] = min_idxs[tid + 32];
|
min_idxs[tid] = min_idxs[tid + 32];
|
||||||
min_dists[tid] = min_dists[tid + 32];
|
min_dists[tid] = min_dists[tid + 32];
|
||||||
}
|
}
|
||||||
|
// AMD does not use explicit syncwarp and instead automatically inserts memory
|
||||||
|
// fences during compilation.
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
// s = 16
|
// s = 16
|
||||||
if (min_dists[tid] > min_dists[tid + 16]) {
|
if (min_dists[tid] > min_dists[tid + 16]) {
|
||||||
min_idxs[tid] = min_idxs[tid + 16];
|
min_idxs[tid] = min_idxs[tid + 16];
|
||||||
min_dists[tid] = min_dists[tid + 16];
|
min_dists[tid] = min_dists[tid + 16];
|
||||||
}
|
}
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
// s = 8
|
// s = 8
|
||||||
if (min_dists[tid] > min_dists[tid + 8]) {
|
if (min_dists[tid] > min_dists[tid + 8]) {
|
||||||
min_idxs[tid] = min_idxs[tid + 8];
|
min_idxs[tid] = min_idxs[tid + 8];
|
||||||
min_dists[tid] = min_dists[tid + 8];
|
min_dists[tid] = min_dists[tid + 8];
|
||||||
}
|
}
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
// s = 4
|
// s = 4
|
||||||
if (min_dists[tid] > min_dists[tid + 4]) {
|
if (min_dists[tid] > min_dists[tid + 4]) {
|
||||||
min_idxs[tid] = min_idxs[tid + 4];
|
min_idxs[tid] = min_idxs[tid + 4];
|
||||||
min_dists[tid] = min_dists[tid + 4];
|
min_dists[tid] = min_dists[tid + 4];
|
||||||
}
|
}
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
// s = 2
|
// s = 2
|
||||||
if (min_dists[tid] > min_dists[tid + 2]) {
|
if (min_dists[tid] > min_dists[tid + 2]) {
|
||||||
min_idxs[tid] = min_idxs[tid + 2];
|
min_idxs[tid] = min_idxs[tid + 2];
|
||||||
min_dists[tid] = min_dists[tid + 2];
|
min_dists[tid] = min_dists[tid + 2];
|
||||||
}
|
}
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
// s = 1
|
// s = 1
|
||||||
if (min_dists[tid] > min_dists[tid + 1]) {
|
if (min_dists[tid] > min_dists[tid + 1]) {
|
||||||
min_idxs[tid] = min_idxs[tid + 1];
|
min_idxs[tid] = min_idxs[tid + 1];
|
||||||
min_dists[tid] = min_dists[tid + 1];
|
min_dists[tid] = min_dists[tid + 1];
|
||||||
}
|
}
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename scalar_t>
|
template <typename scalar_t>
|
||||||
@@ -65,30 +79,42 @@ __device__ void WarpReduceMax(
|
|||||||
dists[tid] = dists[tid + 32];
|
dists[tid] = dists[tid + 32];
|
||||||
dists_idx[tid] = dists_idx[tid + 32];
|
dists_idx[tid] = dists_idx[tid + 32];
|
||||||
}
|
}
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
if (dists[tid] < dists[tid + 16]) {
|
if (dists[tid] < dists[tid + 16]) {
|
||||||
dists[tid] = dists[tid + 16];
|
dists[tid] = dists[tid + 16];
|
||||||
dists_idx[tid] = dists_idx[tid + 16];
|
dists_idx[tid] = dists_idx[tid + 16];
|
||||||
}
|
}
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
if (dists[tid] < dists[tid + 8]) {
|
if (dists[tid] < dists[tid + 8]) {
|
||||||
dists[tid] = dists[tid + 8];
|
dists[tid] = dists[tid + 8];
|
||||||
dists_idx[tid] = dists_idx[tid + 8];
|
dists_idx[tid] = dists_idx[tid + 8];
|
||||||
}
|
}
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
if (dists[tid] < dists[tid + 4]) {
|
if (dists[tid] < dists[tid + 4]) {
|
||||||
dists[tid] = dists[tid + 4];
|
dists[tid] = dists[tid + 4];
|
||||||
dists_idx[tid] = dists_idx[tid + 4];
|
dists_idx[tid] = dists_idx[tid + 4];
|
||||||
}
|
}
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
if (dists[tid] < dists[tid + 2]) {
|
if (dists[tid] < dists[tid + 2]) {
|
||||||
dists[tid] = dists[tid + 2];
|
dists[tid] = dists[tid + 2];
|
||||||
dists_idx[tid] = dists_idx[tid + 2];
|
dists_idx[tid] = dists_idx[tid + 2];
|
||||||
}
|
}
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
if (dists[tid] < dists[tid + 1]) {
|
if (dists[tid] < dists[tid + 1]) {
|
||||||
dists[tid] = dists[tid + 1];
|
dists[tid] = dists[tid + 1];
|
||||||
dists_idx[tid] = dists_idx[tid + 1];
|
dists_idx[tid] = dists_idx[tid + 1];
|
||||||
}
|
}
|
||||||
|
#if !defined(USE_ROCM)
|
||||||
__syncwarp();
|
__syncwarp();
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -83,7 +83,7 @@ class ShapeNetCore(ShapeNetBase): # pragma: no cover
|
|||||||
):
|
):
|
||||||
synset_set.add(synset)
|
synset_set.add(synset)
|
||||||
elif (synset in self.synset_inv.keys()) and (
|
elif (synset in self.synset_inv.keys()) and (
|
||||||
(path.isdir(path.join(data_dir, self.synset_inv[synset])))
|
path.isdir(path.join(data_dir, self.synset_inv[synset]))
|
||||||
):
|
):
|
||||||
synset_set.add(self.synset_inv[synset])
|
synset_set.add(self.synset_inv[synset])
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ def collate_batched_meshes(batch: List[Dict]): # pragma: no cover
|
|||||||
|
|
||||||
collated_dict["mesh"] = None
|
collated_dict["mesh"] = None
|
||||||
if {"verts", "faces"}.issubset(collated_dict.keys()):
|
if {"verts", "faces"}.issubset(collated_dict.keys()):
|
||||||
|
|
||||||
textures = None
|
textures = None
|
||||||
if "textures" in collated_dict:
|
if "textures" in collated_dict:
|
||||||
textures = TexturesAtlas(atlas=collated_dict["textures"])
|
textures = TexturesAtlas(atlas=collated_dict["textures"])
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ class DataSourceBase(ReplaceableBase):
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
class ImplicitronDataSource(DataSourceBase):
|
||||||
"""
|
"""
|
||||||
Represents the data used in Implicitron. This is the only implementation
|
Represents the data used in Implicitron. This is the only implementation
|
||||||
of DataSourceBase provided.
|
of DataSourceBase provided.
|
||||||
@@ -52,8 +52,11 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
|||||||
data_loader_map_provider_class_type: identifies type for data_loader_map_provider.
|
data_loader_map_provider_class_type: identifies type for data_loader_map_provider.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `dataset_map_provider` is never initialized.
|
||||||
dataset_map_provider: DatasetMapProviderBase
|
dataset_map_provider: DatasetMapProviderBase
|
||||||
|
# pyre-fixme[13]: Attribute `dataset_map_provider_class_type` is never initialized.
|
||||||
dataset_map_provider_class_type: str
|
dataset_map_provider_class_type: str
|
||||||
|
# pyre-fixme[13]: Attribute `data_loader_map_provider` is never initialized.
|
||||||
data_loader_map_provider: DataLoaderMapProviderBase
|
data_loader_map_provider: DataLoaderMapProviderBase
|
||||||
data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
|
data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from typing import (
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset import types
|
from pytorch3d.implicitron.dataset import orm_types, types
|
||||||
from pytorch3d.implicitron.dataset.utils import (
|
from pytorch3d.implicitron.dataset.utils import (
|
||||||
adjust_camera_to_bbox_crop_,
|
adjust_camera_to_bbox_crop_,
|
||||||
adjust_camera_to_image_scale_,
|
adjust_camera_to_image_scale_,
|
||||||
@@ -48,8 +48,12 @@ from pytorch3d.implicitron.dataset.utils import (
|
|||||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||||
|
from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
|
||||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||||
|
|
||||||
|
FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
|
||||||
|
SequenceAnnotationT = types.SequenceAnnotation | orm_types.SqlSequenceAnnotation
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FrameData(Mapping[str, Any]):
|
class FrameData(Mapping[str, Any]):
|
||||||
@@ -122,9 +126,9 @@ class FrameData(Mapping[str, Any]):
|
|||||||
meta: A dict for storing additional frame information.
|
meta: A dict for storing additional frame information.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
frame_number: Optional[torch.LongTensor]
|
frame_number: Optional[torch.LongTensor] = None
|
||||||
sequence_name: Union[str, List[str]]
|
sequence_name: Union[str, List[str]] = ""
|
||||||
sequence_category: Union[str, List[str]]
|
sequence_category: Union[str, List[str]] = ""
|
||||||
frame_timestamp: Optional[torch.Tensor] = None
|
frame_timestamp: Optional[torch.Tensor] = None
|
||||||
image_size_hw: Optional[torch.LongTensor] = None
|
image_size_hw: Optional[torch.LongTensor] = None
|
||||||
effective_image_size_hw: Optional[torch.LongTensor] = None
|
effective_image_size_hw: Optional[torch.LongTensor] = None
|
||||||
@@ -155,7 +159,7 @@ class FrameData(Mapping[str, Any]):
|
|||||||
new_params = {}
|
new_params = {}
|
||||||
for field_name in iter(self):
|
for field_name in iter(self):
|
||||||
value = getattr(self, field_name)
|
value = getattr(self, field_name)
|
||||||
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
|
||||||
new_params[field_name] = value.to(*args, **kwargs)
|
new_params[field_name] = value.to(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
new_params[field_name] = value
|
new_params[field_name] = value
|
||||||
@@ -276,6 +280,7 @@ class FrameData(Mapping[str, Any]):
|
|||||||
image_size_hw=tuple(self.effective_image_size_hw), # pyre-ignore
|
image_size_hw=tuple(self.effective_image_size_hw), # pyre-ignore
|
||||||
)
|
)
|
||||||
crop_bbox_xywh = bbox_xyxy_to_xywh(clamp_bbox_xyxy)
|
crop_bbox_xywh = bbox_xyxy_to_xywh(clamp_bbox_xyxy)
|
||||||
|
self.crop_bbox_xywh = crop_bbox_xywh
|
||||||
|
|
||||||
if self.fg_probability is not None:
|
if self.fg_probability is not None:
|
||||||
self.fg_probability = crop_around_box(
|
self.fg_probability = crop_around_box(
|
||||||
@@ -416,7 +421,6 @@ class FrameData(Mapping[str, Any]):
|
|||||||
for f in fields(elem):
|
for f in fields(elem):
|
||||||
if not f.init:
|
if not f.init:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
list_values = override_fields.get(
|
list_values = override_fields.get(
|
||||||
f.name, [getattr(d, f.name) for d in batch]
|
f.name, [getattr(d, f.name) for d in batch]
|
||||||
)
|
)
|
||||||
@@ -425,7 +429,7 @@ class FrameData(Mapping[str, Any]):
|
|||||||
if all(list_value is not None for list_value in list_values)
|
if all(list_value is not None for list_value in list_values)
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
return cls(**collated)
|
return type(elem)(**collated)
|
||||||
|
|
||||||
elif isinstance(elem, Pointclouds):
|
elif isinstance(elem, Pointclouds):
|
||||||
return join_pointclouds_as_batch(batch)
|
return join_pointclouds_as_batch(batch)
|
||||||
@@ -433,6 +437,8 @@ class FrameData(Mapping[str, Any]):
|
|||||||
elif isinstance(elem, CamerasBase):
|
elif isinstance(elem, CamerasBase):
|
||||||
# TODO: don't store K; enforce working in NDC space
|
# TODO: don't store K; enforce working in NDC space
|
||||||
return join_cameras_as_batch(batch)
|
return join_cameras_as_batch(batch)
|
||||||
|
elif isinstance(elem, Meshes):
|
||||||
|
return join_meshes_as_batch(batch)
|
||||||
else:
|
else:
|
||||||
return torch.utils.data.dataloader.default_collate(batch)
|
return torch.utils.data.dataloader.default_collate(batch)
|
||||||
|
|
||||||
@@ -453,8 +459,8 @@ class FrameDataBuilderBase(ReplaceableBase, Generic[FrameDataSubtype], ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
frame_annotation: types.FrameAnnotation,
|
frame_annotation: FrameAnnotationT,
|
||||||
sequence_annotation: types.SequenceAnnotation,
|
sequence_annotation: SequenceAnnotationT,
|
||||||
*,
|
*,
|
||||||
load_blobs: bool = True,
|
load_blobs: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -540,8 +546,8 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
|
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
frame_annotation: types.FrameAnnotation,
|
frame_annotation: FrameAnnotationT,
|
||||||
sequence_annotation: types.SequenceAnnotation,
|
sequence_annotation: SequenceAnnotationT,
|
||||||
*,
|
*,
|
||||||
load_blobs: bool = True,
|
load_blobs: bool = True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
@@ -585,58 +591,81 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
fg_mask_np: Optional[np.ndarray] = None
|
dataset_root = self.dataset_root
|
||||||
mask_annotation = frame_annotation.mask
|
mask_annotation = frame_annotation.mask
|
||||||
if mask_annotation is not None:
|
depth_annotation = frame_annotation.depth
|
||||||
if load_blobs and self.load_masks:
|
image_path: str | None = None
|
||||||
fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
|
mask_path: str | None = None
|
||||||
|
depth_path: str | None = None
|
||||||
|
pcl_path: str | None = None
|
||||||
|
if dataset_root is not None: # set all paths even if we won’t load blobs
|
||||||
|
if frame_annotation.image.path is not None:
|
||||||
|
image_path = os.path.join(dataset_root, frame_annotation.image.path)
|
||||||
|
frame_data.image_path = image_path
|
||||||
|
|
||||||
|
if mask_annotation is not None and mask_annotation.path:
|
||||||
|
mask_path = os.path.join(dataset_root, mask_annotation.path)
|
||||||
frame_data.mask_path = mask_path
|
frame_data.mask_path = mask_path
|
||||||
|
|
||||||
|
if depth_annotation is not None and depth_annotation.path is not None:
|
||||||
|
depth_path = os.path.join(dataset_root, depth_annotation.path)
|
||||||
|
frame_data.depth_path = depth_path
|
||||||
|
|
||||||
|
if point_cloud is not None:
|
||||||
|
pcl_path = os.path.join(dataset_root, point_cloud.path)
|
||||||
|
frame_data.sequence_point_cloud_path = pcl_path
|
||||||
|
|
||||||
|
fg_mask_np: np.ndarray | None = None
|
||||||
|
bbox_xywh: tuple[float, float, float, float] | None = None
|
||||||
|
|
||||||
|
if mask_annotation is not None:
|
||||||
|
if load_blobs and self.load_masks and mask_path:
|
||||||
|
fg_mask_np = self._load_fg_probability(frame_annotation, mask_path)
|
||||||
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
|
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
|
||||||
|
|
||||||
bbox_xywh = mask_annotation.bounding_box_xywh
|
bbox_xywh = mask_annotation.bounding_box_xywh
|
||||||
if bbox_xywh is None and fg_mask_np is not None:
|
|
||||||
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
|
||||||
|
|
||||||
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
|
|
||||||
|
|
||||||
if frame_annotation.image is not None:
|
if frame_annotation.image is not None:
|
||||||
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
|
image_size_hw = safe_as_tensor(frame_annotation.image.size, torch.long)
|
||||||
frame_data.image_size_hw = image_size_hw # original image size
|
frame_data.image_size_hw = image_size_hw # original image size
|
||||||
# image size after crop/resize
|
# image size after crop/resize
|
||||||
frame_data.effective_image_size_hw = image_size_hw
|
frame_data.effective_image_size_hw = image_size_hw
|
||||||
image_path = None
|
|
||||||
dataset_root = self.dataset_root
|
|
||||||
if frame_annotation.image.path is not None and dataset_root is not None:
|
|
||||||
image_path = os.path.join(dataset_root, frame_annotation.image.path)
|
|
||||||
frame_data.image_path = image_path
|
|
||||||
|
|
||||||
if load_blobs and self.load_images:
|
if load_blobs and self.load_images:
|
||||||
if image_path is None:
|
if image_path is None:
|
||||||
raise ValueError("Image path is required to load images.")
|
raise ValueError("Image path is required to load images.")
|
||||||
|
|
||||||
image_np = load_image(self._local_path(image_path))
|
no_mask = fg_mask_np is None # didn’t read the mask file
|
||||||
|
image_np = load_image(
|
||||||
|
self._local_path(image_path), try_read_alpha=no_mask
|
||||||
|
)
|
||||||
|
if image_np.shape[0] == 4: # RGBA image
|
||||||
|
if no_mask:
|
||||||
|
fg_mask_np = image_np[3:]
|
||||||
|
frame_data.fg_probability = safe_as_tensor(
|
||||||
|
fg_mask_np, torch.float
|
||||||
|
)
|
||||||
|
|
||||||
|
image_np = image_np[:3]
|
||||||
|
|
||||||
frame_data.image_rgb = self._postprocess_image(
|
frame_data.image_rgb = self._postprocess_image(
|
||||||
image_np, frame_annotation.image.size, frame_data.fg_probability
|
image_np, frame_annotation.image.size, frame_data.fg_probability
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if bbox_xywh is None and fg_mask_np is not None:
|
||||||
load_blobs
|
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
|
||||||
and self.load_depths
|
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
|
||||||
and frame_annotation.depth is not None
|
|
||||||
and frame_annotation.depth.path is not None
|
if load_blobs and self.load_depths and depth_path is not None:
|
||||||
):
|
frame_data.depth_map, frame_data.depth_mask = self._load_mask_depth(
|
||||||
(
|
frame_annotation, depth_path, fg_mask_np
|
||||||
frame_data.depth_map,
|
)
|
||||||
frame_data.depth_path,
|
|
||||||
frame_data.depth_mask,
|
|
||||||
) = self._load_mask_depth(frame_annotation, fg_mask_np)
|
|
||||||
|
|
||||||
if load_blobs and self.load_point_clouds and point_cloud is not None:
|
if load_blobs and self.load_point_clouds and point_cloud is not None:
|
||||||
pcl_path = self._fix_point_cloud_path(point_cloud.path)
|
assert pcl_path is not None
|
||||||
frame_data.sequence_point_cloud = load_pointcloud(
|
frame_data.sequence_point_cloud = load_pointcloud(
|
||||||
self._local_path(pcl_path), max_points=self.max_points
|
self._local_path(pcl_path), max_points=self.max_points
|
||||||
)
|
)
|
||||||
frame_data.sequence_point_cloud_path = pcl_path
|
|
||||||
|
|
||||||
if frame_annotation.viewpoint is not None:
|
if frame_annotation.viewpoint is not None:
|
||||||
frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
|
frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
|
||||||
@@ -652,18 +681,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
|
|
||||||
return frame_data
|
return frame_data
|
||||||
|
|
||||||
def _load_fg_probability(
|
def _load_fg_probability(self, entry: FrameAnnotationT, path: str) -> np.ndarray:
|
||||||
self, entry: types.FrameAnnotation
|
fg_probability = load_mask(self._local_path(path))
|
||||||
) -> Tuple[np.ndarray, str]:
|
|
||||||
assert self.dataset_root is not None and entry.mask is not None
|
|
||||||
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
|
||||||
fg_probability = load_mask(self._local_path(full_path))
|
|
||||||
if fg_probability.shape[-2:] != entry.image.size:
|
if fg_probability.shape[-2:] != entry.image.size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
|
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
|
||||||
)
|
)
|
||||||
|
|
||||||
return fg_probability, full_path
|
return fg_probability
|
||||||
|
|
||||||
def _postprocess_image(
|
def _postprocess_image(
|
||||||
self,
|
self,
|
||||||
@@ -684,14 +709,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
|
|
||||||
def _load_mask_depth(
|
def _load_mask_depth(
|
||||||
self,
|
self,
|
||||||
entry: types.FrameAnnotation,
|
entry: FrameAnnotationT,
|
||||||
|
path: str,
|
||||||
fg_mask: Optional[np.ndarray],
|
fg_mask: Optional[np.ndarray],
|
||||||
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
entry_depth = entry.depth
|
entry_depth = entry.depth
|
||||||
dataset_root = self.dataset_root
|
dataset_root = self.dataset_root
|
||||||
assert dataset_root is not None
|
assert dataset_root is not None
|
||||||
assert entry_depth is not None and entry_depth.path is not None
|
assert entry_depth is not None
|
||||||
path = os.path.join(dataset_root, entry_depth.path)
|
|
||||||
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
|
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
|
||||||
|
|
||||||
if self.mask_depths:
|
if self.mask_depths:
|
||||||
@@ -705,11 +730,11 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
else:
|
else:
|
||||||
depth_mask = (depth_map > 0.0).astype(np.float32)
|
depth_mask = (depth_map > 0.0).astype(np.float32)
|
||||||
|
|
||||||
return torch.tensor(depth_map), path, torch.tensor(depth_mask)
|
return torch.tensor(depth_map), torch.tensor(depth_mask)
|
||||||
|
|
||||||
def _get_pytorch3d_camera(
|
def _get_pytorch3d_camera(
|
||||||
self,
|
self,
|
||||||
entry: types.FrameAnnotation,
|
entry: FrameAnnotationT,
|
||||||
) -> PerspectiveCameras:
|
) -> PerspectiveCameras:
|
||||||
entry_viewpoint = entry.viewpoint
|
entry_viewpoint = entry.viewpoint
|
||||||
assert entry_viewpoint is not None
|
assert entry_viewpoint is not None
|
||||||
@@ -738,19 +763,6 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
|||||||
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
|
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _fix_point_cloud_path(self, path: str) -> str:
|
|
||||||
"""
|
|
||||||
Fix up a point cloud path from the dataset.
|
|
||||||
Some files in Co3Dv2 have an accidental absolute path stored.
|
|
||||||
"""
|
|
||||||
unwanted_prefix = (
|
|
||||||
"/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
|
|
||||||
)
|
|
||||||
if path.startswith(unwanted_prefix):
|
|
||||||
path = path[len(unwanted_prefix) :]
|
|
||||||
assert self.dataset_root is not None
|
|
||||||
return os.path.join(self.dataset_root, path)
|
|
||||||
|
|
||||||
def _local_path(self, path: str) -> str:
|
def _local_path(self, path: str) -> str:
|
||||||
if self.path_manager is None:
|
if self.path_manager is None:
|
||||||
return path
|
return path
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ _NEED_CONTROL: Tuple[str, ...] = (
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
class JsonIndexDatasetMapProvider(DatasetMapProviderBase):
|
||||||
"""
|
"""
|
||||||
Generates the training / validation and testing dataset objects for
|
Generates the training / validation and testing dataset objects for
|
||||||
a dataset laid out on disk like Co3D, with annotations in json files.
|
a dataset laid out on disk like Co3D, with annotations in json files.
|
||||||
@@ -95,6 +95,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `category` is never initialized.
|
||||||
category: str
|
category: str
|
||||||
task_str: str = "singlesequence"
|
task_str: str = "singlesequence"
|
||||||
dataset_root: str = _CO3D_DATASET_ROOT
|
dataset_root: str = _CO3D_DATASET_ROOT
|
||||||
@@ -104,8 +105,10 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
test_restrict_sequence_id: int = -1
|
test_restrict_sequence_id: int = -1
|
||||||
assert_single_seq: bool = False
|
assert_single_seq: bool = False
|
||||||
only_test_set: bool = False
|
only_test_set: bool = False
|
||||||
|
# pyre-fixme[13]: Attribute `dataset` is never initialized.
|
||||||
dataset: JsonIndexDataset
|
dataset: JsonIndexDataset
|
||||||
dataset_class_type: str = "JsonIndexDataset"
|
dataset_class_type: str = "JsonIndexDataset"
|
||||||
|
# pyre-fixme[13]: Attribute `path_manager_factory` is never initialized.
|
||||||
path_manager_factory: PathManagerFactory
|
path_manager_factory: PathManagerFactory
|
||||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):
|
||||||
"""
|
"""
|
||||||
Generates the training, validation, and testing dataset objects for
|
Generates the training, validation, and testing dataset objects for
|
||||||
a dataset laid out on disk like CO3Dv2, with annotations in gzipped json files.
|
a dataset laid out on disk like CO3Dv2, with annotations in gzipped json files.
|
||||||
@@ -171,7 +171,9 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `category` is never initialized.
|
||||||
category: str
|
category: str
|
||||||
|
# pyre-fixme[13]: Attribute `subset_name` is never initialized.
|
||||||
subset_name: str
|
subset_name: str
|
||||||
dataset_root: str = _CO3DV2_DATASET_ROOT
|
dataset_root: str = _CO3DV2_DATASET_ROOT
|
||||||
|
|
||||||
@@ -183,8 +185,10 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
n_known_frames_for_test: int = 0
|
n_known_frames_for_test: int = 0
|
||||||
|
|
||||||
dataset_class_type: str = "JsonIndexDataset"
|
dataset_class_type: str = "JsonIndexDataset"
|
||||||
|
# pyre-fixme[13]: Attribute `dataset` is never initialized.
|
||||||
dataset: JsonIndexDataset
|
dataset: JsonIndexDataset
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `path_manager_factory` is never initialized.
|
||||||
path_manager_factory: PathManagerFactory
|
path_manager_factory: PathManagerFactory
|
||||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||||
|
|
||||||
@@ -218,7 +222,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
self.dataset_map = dataset_map
|
self.dataset_map = dataset_map
|
||||||
|
|
||||||
def _load_category(self, category: str) -> DatasetMap:
|
def _load_category(self, category: str) -> DatasetMap:
|
||||||
|
|
||||||
frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
|
frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
|
||||||
sequence_file = os.path.join(
|
sequence_file = os.path.join(
|
||||||
self.dataset_root, category, "sequence_annotations.jgz"
|
self.dataset_root, category, "sequence_annotations.jgz"
|
||||||
|
|||||||
@@ -75,7 +75,6 @@ def _minify(basedir, path_manager, factors=(), resolutions=()):
|
|||||||
def _load_data(
|
def _load_data(
|
||||||
basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
|
basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
|
||||||
):
|
):
|
||||||
|
|
||||||
poses_arr = np.load(
|
poses_arr = np.load(
|
||||||
_local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
|
_local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
|
||||||
)
|
)
|
||||||
@@ -164,7 +163,6 @@ def ptstocam(pts, c2w):
|
|||||||
|
|
||||||
|
|
||||||
def poses_avg(poses):
|
def poses_avg(poses):
|
||||||
|
|
||||||
hwf = poses[0, :3, -1:]
|
hwf = poses[0, :3, -1:]
|
||||||
|
|
||||||
center = poses[:, :3, 3].mean(0)
|
center = poses[:, :3, 3].mean(0)
|
||||||
@@ -192,7 +190,6 @@ def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
|
|||||||
|
|
||||||
|
|
||||||
def recenter_poses(poses):
|
def recenter_poses(poses):
|
||||||
|
|
||||||
poses_ = poses + 0
|
poses_ = poses + 0
|
||||||
bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
|
bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
|
||||||
c2w = poses_avg(poses)
|
c2w = poses_avg(poses)
|
||||||
@@ -256,7 +253,6 @@ def spherify_poses(poses, bds):
|
|||||||
new_poses = []
|
new_poses = []
|
||||||
|
|
||||||
for th in np.linspace(0.0, 2.0 * np.pi, 120):
|
for th in np.linspace(0.0, 2.0 * np.pi, 120):
|
||||||
|
|
||||||
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
|
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
|
||||||
up = np.array([0, 0, -1.0])
|
up = np.array([0, 0, -1.0])
|
||||||
|
|
||||||
@@ -311,7 +307,6 @@ def load_llff_data(
|
|||||||
path_zflat=False,
|
path_zflat=False,
|
||||||
path_manager=None,
|
path_manager=None,
|
||||||
):
|
):
|
||||||
|
|
||||||
poses, bds, imgs = _load_data(
|
poses, bds, imgs = _load_data(
|
||||||
basedir, factor=factor, path_manager=path_manager
|
basedir, factor=factor, path_manager=path_manager
|
||||||
) # factor=8 downsamples original imgs by 8x
|
) # factor=8 downsamples original imgs by 8x
|
||||||
|
|||||||
@@ -4,6 +4,8 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# pyre-unsafe
|
||||||
|
|
||||||
# This functionality requires SQLAlchemy 2.0 or later.
|
# This functionality requires SQLAlchemy 2.0 or later.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from .utils import DATASET_TYPE_KNOWN
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
class RenderedMeshDatasetMapProvider(DatasetMapProviderBase):
|
||||||
"""
|
"""
|
||||||
A simple single-scene dataset based on PyTorch3D renders of a mesh.
|
A simple single-scene dataset based on PyTorch3D renders of a mesh.
|
||||||
Provides `num_views` renders of the mesh as train, with no val
|
Provides `num_views` renders of the mesh as train, with no val
|
||||||
@@ -76,6 +76,7 @@ class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13
|
|||||||
resolution: int = 128
|
resolution: int = 128
|
||||||
use_point_light: bool = True
|
use_point_light: bool = True
|
||||||
gpu_idx: Optional[int] = 0
|
gpu_idx: Optional[int] = 0
|
||||||
|
# pyre-fixme[13]: Attribute `path_manager_factory` is never initialized.
|
||||||
path_manager_factory: PathManagerFactory
|
path_manager_factory: PathManagerFactory
|
||||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||||
|
|
||||||
|
|||||||
@@ -83,7 +83,6 @@ class SingleSceneDataset(DatasetBase, Configurable):
|
|||||||
return self.eval_batches
|
return self.eval_batches
|
||||||
|
|
||||||
|
|
||||||
# pyre-fixme[13]: Uninitialized attribute
|
|
||||||
class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
||||||
"""
|
"""
|
||||||
Base for provider of data for one scene from LLFF or blender datasets.
|
Base for provider of data for one scene from LLFF or blender datasets.
|
||||||
@@ -100,8 +99,11 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
|||||||
testing frame.
|
testing frame.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `base_dir` is never initialized.
|
||||||
base_dir: str
|
base_dir: str
|
||||||
|
# pyre-fixme[13]: Attribute `object_name` is never initialized.
|
||||||
object_name: str
|
object_name: str
|
||||||
|
# pyre-fixme[13]: Attribute `path_manager_factory` is never initialized.
|
||||||
path_manager_factory: PathManagerFactory
|
path_manager_factory: PathManagerFactory
|
||||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||||
n_known_frames_for_test: Optional[int] = None
|
n_known_frames_for_test: Optional[int] = None
|
||||||
|
|||||||
@@ -4,11 +4,15 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# pyre-unsafe
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
|
||||||
|
import urllib
|
||||||
|
from dataclasses import dataclass, Field, field
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
@@ -29,17 +33,18 @@ import sqlalchemy as sa
|
|||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.frame_data import ( # noqa
|
from pytorch3d.implicitron.dataset.frame_data import (
|
||||||
FrameData,
|
FrameData,
|
||||||
FrameDataBuilder,
|
FrameDataBuilder, # noqa
|
||||||
FrameDataBuilderBase,
|
FrameDataBuilderBase,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
registry,
|
registry,
|
||||||
ReplaceableBase,
|
ReplaceableBase,
|
||||||
run_auto_creation,
|
run_auto_creation,
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import scoped_session, Session, sessionmaker
|
||||||
|
|
||||||
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
|
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
|
||||||
|
|
||||||
@@ -51,7 +56,7 @@ _SET_LISTS_TABLE: str = "set_lists"
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||||
"""
|
"""
|
||||||
A dataset with annotations stored as SQLite tables. This is an index-based dataset.
|
A dataset with annotations stored as SQLite tables. This is an index-based dataset.
|
||||||
The length is returned after all sequence and frame filters are applied (see param
|
The length is returned after all sequence and frame filters are applied (see param
|
||||||
@@ -88,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
engine verbatim. Don’t expose it to end users of your application!
|
engine verbatim. Don’t expose it to end users of your application!
|
||||||
pick_categories: Restrict the dataset to the given list of categories.
|
pick_categories: Restrict the dataset to the given list of categories.
|
||||||
pick_sequences: A Sequence of sequence names to restrict the dataset to.
|
pick_sequences: A Sequence of sequence names to restrict the dataset to.
|
||||||
|
pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
|
||||||
exclude_sequences: A Sequence of the names of the sequences to exclude.
|
exclude_sequences: A Sequence of the names of the sequences to exclude.
|
||||||
limit_sequences_per_category_to: Limit the dataset to the first up to N
|
limit_sequences_per_category_to: Limit the dataset to the first up to N
|
||||||
sequences within each category (applies after all other sequence filters
|
sequences within each category (applies after all other sequence filters
|
||||||
@@ -102,9 +108,16 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
more frames than that; applied after other frame-level filters.
|
more frames than that; applied after other frame-level filters.
|
||||||
seed: The seed of the random generator sampling `n_frames_per_sequence`
|
seed: The seed of the random generator sampling `n_frames_per_sequence`
|
||||||
random frames per sequence.
|
random frames per sequence.
|
||||||
|
preload_metadata: If True, the metadata is preloaded into memory.
|
||||||
|
precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
|
||||||
|
scoped_session: If True, allows different parts of the code to share
|
||||||
|
a global session to access the database.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
|
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
|
||||||
|
sequence_annotations_type: ClassVar[Type[SqlSequenceAnnotation]] = (
|
||||||
|
SqlSequenceAnnotation
|
||||||
|
)
|
||||||
|
|
||||||
sqlite_metadata_file: str = ""
|
sqlite_metadata_file: str = ""
|
||||||
dataset_root: Optional[str] = None
|
dataset_root: Optional[str] = None
|
||||||
@@ -117,6 +130,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
pick_categories: Tuple[str, ...] = ()
|
pick_categories: Tuple[str, ...] = ()
|
||||||
|
|
||||||
pick_sequences: Tuple[str, ...] = ()
|
pick_sequences: Tuple[str, ...] = ()
|
||||||
|
pick_sequences_sql_clause: Optional[str] = None
|
||||||
exclude_sequences: Tuple[str, ...] = ()
|
exclude_sequences: Tuple[str, ...] = ()
|
||||||
limit_sequences_per_category_to: int = 0
|
limit_sequences_per_category_to: int = 0
|
||||||
limit_sequences_to: int = 0
|
limit_sequences_to: int = 0
|
||||||
@@ -124,12 +138,22 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
n_frames_per_sequence: int = -1
|
n_frames_per_sequence: int = -1
|
||||||
seed: int = 0
|
seed: int = 0
|
||||||
remove_empty_masks_poll_whole_table_threshold: int = 300_000
|
remove_empty_masks_poll_whole_table_threshold: int = 300_000
|
||||||
|
preload_metadata: bool = False
|
||||||
|
precompute_seq_to_idx: bool = False
|
||||||
# we set it manually in the constructor
|
# we set it manually in the constructor
|
||||||
# _index: pd.DataFrame = field(init=False)
|
_index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
|
||||||
|
_sql_engine: sa.engine.Engine = field(
|
||||||
|
init=False, metadata={"omegaconf_ignore": True}
|
||||||
|
)
|
||||||
|
eval_batches: Optional[List[Any]] = field(
|
||||||
|
init=False, metadata={"omegaconf_ignore": True}
|
||||||
|
)
|
||||||
|
|
||||||
frame_data_builder: FrameDataBuilderBase
|
frame_data_builder: FrameDataBuilderBase # pyre-ignore[13]
|
||||||
frame_data_builder_class_type: str = "FrameDataBuilder"
|
frame_data_builder_class_type: str = "FrameDataBuilder"
|
||||||
|
|
||||||
|
scoped_session: bool = False
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if sa.__version__ < "2.0":
|
if sa.__version__ < "2.0":
|
||||||
raise ImportError("This class requires SQL Alchemy 2.0 or later")
|
raise ImportError("This class requires SQL Alchemy 2.0 or later")
|
||||||
@@ -138,19 +162,28 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
raise ValueError("sqlite_metadata_file must be set")
|
raise ValueError("sqlite_metadata_file must be set")
|
||||||
|
|
||||||
if self.dataset_root:
|
if self.dataset_root:
|
||||||
frame_builder_type = self.frame_data_builder_class_type
|
frame_args = f"frame_data_builder_{self.frame_data_builder_class_type}_args"
|
||||||
getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
|
getattr(self, frame_args)["dataset_root"] = self.dataset_root
|
||||||
"dataset_root"
|
getattr(self, frame_args)["path_manager"] = self.path_manager
|
||||||
] = self.dataset_root
|
|
||||||
|
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
self.frame_data_builder.path_manager = self.path_manager
|
|
||||||
|
|
||||||
# pyre-ignore # NOTE: sqlite-specific args (read-only mode).
|
if self.path_manager is not None:
|
||||||
|
self.sqlite_metadata_file = self.path_manager.get_local_path(
|
||||||
|
self.sqlite_metadata_file
|
||||||
|
)
|
||||||
|
self.subset_lists_file = self.path_manager.get_local_path(
|
||||||
|
self.subset_lists_file
|
||||||
|
)
|
||||||
|
|
||||||
|
# NOTE: sqlite-specific args (read-only mode).
|
||||||
self._sql_engine = sa.create_engine(
|
self._sql_engine = sa.create_engine(
|
||||||
f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
|
f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.preload_metadata:
|
||||||
|
self._sql_engine = self._preload_database(self._sql_engine)
|
||||||
|
|
||||||
sequences = self._get_filtered_sequences_if_any()
|
sequences = self._get_filtered_sequences_if_any()
|
||||||
|
|
||||||
if self.subsets:
|
if self.subsets:
|
||||||
@@ -166,16 +199,29 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
if len(index) == 0:
|
if len(index) == 0:
|
||||||
raise ValueError(f"There are no frames in the subsets: {self.subsets}!")
|
raise ValueError(f"There are no frames in the subsets: {self.subsets}!")
|
||||||
|
|
||||||
self._index = index.set_index(["sequence_name", "frame_number"]) # pyre-ignore
|
self._index = index.set_index(["sequence_name", "frame_number"])
|
||||||
|
|
||||||
self.eval_batches = None # pyre-ignore
|
self.eval_batches = None
|
||||||
if self.eval_batches_file:
|
if self.eval_batches_file:
|
||||||
self.eval_batches = self._load_filter_eval_batches()
|
self.eval_batches = self._load_filter_eval_batches()
|
||||||
|
|
||||||
logger.info(str(self))
|
logger.info(str(self))
|
||||||
|
|
||||||
|
if self.scoped_session:
|
||||||
|
self._session_factory = sessionmaker(bind=self._sql_engine) # pyre-ignore
|
||||||
|
|
||||||
|
if self.precompute_seq_to_idx:
|
||||||
|
# This is deprecated and will be removed in the future.
|
||||||
|
# After we backport https://github.com/facebookresearch/uco3d/pull/3
|
||||||
|
logger.warning(
|
||||||
|
"Using precompute_seq_to_idx is deprecated and will be removed in the future."
|
||||||
|
)
|
||||||
|
self._index["rowid"] = np.arange(len(self._index))
|
||||||
|
groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
|
||||||
|
self._seq_to_indices = dict(groupby.apply(list)) # pyre-ignore
|
||||||
|
del self._index["rowid"]
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
# pyre-ignore[16]
|
|
||||||
return len(self._index)
|
return len(self._index)
|
||||||
|
|
||||||
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
|
def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
|
||||||
@@ -232,12 +278,18 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
self.frame_annotations_type.frame_number
|
self.frame_annotations_type.frame_number
|
||||||
== int(frame), # cast from np.int64
|
== int(frame), # cast from np.int64
|
||||||
)
|
)
|
||||||
seq_stmt = sa.select(SqlSequenceAnnotation).where(
|
seq_stmt = sa.select(self.sequence_annotations_type).where(
|
||||||
SqlSequenceAnnotation.sequence_name == seq
|
self.sequence_annotations_type.sequence_name == seq
|
||||||
)
|
)
|
||||||
with Session(self._sql_engine) as session:
|
if self.scoped_session:
|
||||||
entry = session.scalars(stmt).one()
|
# pyre-ignore
|
||||||
seq_metadata = session.scalars(seq_stmt).one()
|
with scoped_session(self._session_factory)() as session:
|
||||||
|
entry = session.scalars(stmt).one()
|
||||||
|
seq_metadata = session.scalars(seq_stmt).one()
|
||||||
|
else:
|
||||||
|
with Session(self._sql_engine) as session:
|
||||||
|
entry = session.scalars(stmt).one()
|
||||||
|
seq_metadata = session.scalars(seq_stmt).one()
|
||||||
|
|
||||||
assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
|
assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
|
||||||
|
|
||||||
@@ -250,7 +302,6 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
return frame_data
|
return frame_data
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
# pyre-ignore[16]
|
|
||||||
return f"SqlIndexDataset #frames={len(self._index)}"
|
return f"SqlIndexDataset #frames={len(self._index)}"
|
||||||
|
|
||||||
def sequence_names(self) -> Iterable[str]:
|
def sequence_names(self) -> Iterable[str]:
|
||||||
@@ -260,9 +311,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
# override
|
# override
|
||||||
def category_to_sequence_names(self) -> Dict[str, List[str]]:
|
def category_to_sequence_names(self) -> Dict[str, List[str]]:
|
||||||
stmt = sa.select(
|
stmt = sa.select(
|
||||||
SqlSequenceAnnotation.category, SqlSequenceAnnotation.sequence_name
|
self.sequence_annotations_type.category,
|
||||||
|
self.sequence_annotations_type.sequence_name,
|
||||||
).where( # we limit results to sequences that have frames after all filters
|
).where( # we limit results to sequences that have frames after all filters
|
||||||
SqlSequenceAnnotation.sequence_name.in_(self.sequence_names())
|
self.sequence_annotations_type.sequence_name.in_(self.sequence_names())
|
||||||
)
|
)
|
||||||
with self._sql_engine.connect() as connection:
|
with self._sql_engine.connect() as connection:
|
||||||
cat_to_seqs = pd.read_sql(stmt, connection)
|
cat_to_seqs = pd.read_sql(stmt, connection)
|
||||||
@@ -335,17 +387,31 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
rows = self._index.index.get_loc(seq_name)
|
rows = self._index.index.get_loc(seq_name)
|
||||||
if isinstance(rows, slice):
|
if isinstance(rows, slice):
|
||||||
assert rows.stop is not None, "Unexpected result from pandas"
|
assert rows.stop is not None, "Unexpected result from pandas"
|
||||||
rows = range(rows.start or 0, rows.stop, rows.step or 1)
|
rows_seq = range(rows.start or 0, rows.stop, rows.step or 1)
|
||||||
else:
|
else:
|
||||||
rows = np.where(rows)[0]
|
rows_seq = list(np.where(rows)[0])
|
||||||
|
|
||||||
index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
|
index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
|
||||||
rows, seq_name, subset_filter
|
rows_seq, seq_name, subset_filter
|
||||||
)
|
)
|
||||||
index_slice["idx"] = idx
|
index_slice["idx"] = idx
|
||||||
|
|
||||||
yield from index_slice.itertuples(index=False)
|
yield from index_slice.itertuples(index=False)
|
||||||
|
|
||||||
|
# override
|
||||||
|
def sequence_indices_in_order(
|
||||||
|
self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
|
||||||
|
) -> Iterator[int]:
|
||||||
|
"""Same as `sequence_frames_in_order` but returns the iterator over
|
||||||
|
only dataset indices.
|
||||||
|
"""
|
||||||
|
if self.precompute_seq_to_idx and subset_filter is None:
|
||||||
|
# pyre-ignore
|
||||||
|
yield from self._seq_to_indices[seq_name]
|
||||||
|
else:
|
||||||
|
for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
|
||||||
|
yield idx
|
||||||
|
|
||||||
# override
|
# override
|
||||||
def get_eval_batches(self) -> Optional[List[Any]]:
|
def get_eval_batches(self) -> Optional[List[Any]]:
|
||||||
"""
|
"""
|
||||||
@@ -379,11 +445,35 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
or self.limit_sequences_to > 0
|
or self.limit_sequences_to > 0
|
||||||
or self.limit_sequences_per_category_to > 0
|
or self.limit_sequences_per_category_to > 0
|
||||||
or len(self.pick_sequences) > 0
|
or len(self.pick_sequences) > 0
|
||||||
|
or self.pick_sequences_sql_clause is not None
|
||||||
or len(self.exclude_sequences) > 0
|
or len(self.exclude_sequences) > 0
|
||||||
or len(self.pick_categories) > 0
|
or len(self.pick_categories) > 0
|
||||||
or self.n_frames_per_sequence > 0
|
or self.n_frames_per_sequence > 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _preload_database(
|
||||||
|
self, source_engine: sa.engine.base.Engine
|
||||||
|
) -> sa.engine.base.Engine:
|
||||||
|
destination_engine = sa.create_engine("sqlite:///:memory:")
|
||||||
|
metadata = sa.MetaData()
|
||||||
|
metadata.reflect(bind=source_engine)
|
||||||
|
metadata.create_all(bind=destination_engine)
|
||||||
|
|
||||||
|
with source_engine.connect() as source_conn:
|
||||||
|
with destination_engine.connect() as destination_conn:
|
||||||
|
for table_obj in metadata.tables.values():
|
||||||
|
# Select all rows from the source table
|
||||||
|
source_rows = source_conn.execute(table_obj.select())
|
||||||
|
|
||||||
|
# Insert rows into the destination table
|
||||||
|
for row in source_rows:
|
||||||
|
destination_conn.execute(table_obj.insert().values(row))
|
||||||
|
|
||||||
|
# Commit the changes for each table
|
||||||
|
destination_conn.commit()
|
||||||
|
|
||||||
|
return destination_engine
|
||||||
|
|
||||||
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
|
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
|
||||||
# maximum possible filter (if limit_sequences_per_category_to == 0):
|
# maximum possible filter (if limit_sequences_per_category_to == 0):
|
||||||
# WHERE category IN 'self.pick_categories'
|
# WHERE category IN 'self.pick_categories'
|
||||||
@@ -396,19 +486,22 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
*self._get_pick_filters(),
|
*self._get_pick_filters(),
|
||||||
*self._get_exclude_filters(),
|
*self._get_exclude_filters(),
|
||||||
]
|
]
|
||||||
|
if self.pick_sequences_sql_clause:
|
||||||
|
print("Applying the custom SQL clause.")
|
||||||
|
where_conditions.append(sa.text(self.pick_sequences_sql_clause))
|
||||||
|
|
||||||
def add_where(stmt):
|
def add_where(stmt):
|
||||||
return stmt.where(*where_conditions) if where_conditions else stmt
|
return stmt.where(*where_conditions) if where_conditions else stmt
|
||||||
|
|
||||||
if self.limit_sequences_per_category_to <= 0:
|
if self.limit_sequences_per_category_to <= 0:
|
||||||
stmt = add_where(sa.select(SqlSequenceAnnotation.sequence_name))
|
stmt = add_where(sa.select(self.sequence_annotations_type.sequence_name))
|
||||||
else:
|
else:
|
||||||
subquery = sa.select(
|
subquery = sa.select(
|
||||||
SqlSequenceAnnotation.sequence_name,
|
self.sequence_annotations_type.sequence_name,
|
||||||
sa.func.row_number()
|
sa.func.row_number()
|
||||||
.over(
|
.over(
|
||||||
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
|
order_by=sa.text("ROWID"), # NOTE: ROWID is SQLite-specific
|
||||||
partition_by=SqlSequenceAnnotation.category,
|
partition_by=self.sequence_annotations_type.category,
|
||||||
)
|
)
|
||||||
.label("row_number"),
|
.label("row_number"),
|
||||||
)
|
)
|
||||||
@@ -444,31 +537,34 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
logger.info(f"Limiting dataset to categories: {self.pick_categories}")
|
logger.info(f"Limiting dataset to categories: {self.pick_categories}")
|
||||||
return [SqlSequenceAnnotation.category.in_(self.pick_categories)]
|
return [self.sequence_annotations_type.category.in_(self.pick_categories)]
|
||||||
|
|
||||||
def _get_pick_filters(self) -> List[sa.ColumnElement]:
|
def _get_pick_filters(self) -> List[sa.ColumnElement]:
|
||||||
if not self.pick_sequences:
|
if not self.pick_sequences:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
|
logger.info(f"Limiting dataset to sequences: {self.pick_sequences}")
|
||||||
return [SqlSequenceAnnotation.sequence_name.in_(self.pick_sequences)]
|
return [self.sequence_annotations_type.sequence_name.in_(self.pick_sequences)]
|
||||||
|
|
||||||
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
|
def _get_exclude_filters(self) -> List[sa.ColumnOperators]:
|
||||||
if not self.exclude_sequences:
|
if not self.exclude_sequences:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
|
logger.info(f"Removing sequences from the dataset: {self.exclude_sequences}")
|
||||||
return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
|
return [
|
||||||
|
self.sequence_annotations_type.sequence_name.notin_(self.exclude_sequences)
|
||||||
|
]
|
||||||
|
|
||||||
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
|
def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
|
||||||
assert self.subsets is not None
|
subsets = self.subsets
|
||||||
|
assert subsets is not None
|
||||||
with open(subset_lists_path, "r") as f:
|
with open(subset_lists_path, "r") as f:
|
||||||
subset_to_seq_frame = json.load(f)
|
subset_to_seq_frame = json.load(f)
|
||||||
|
|
||||||
seq_frame_list = sum(
|
seq_frame_list = sum(
|
||||||
(
|
(
|
||||||
[(*row, subset) for row in subset_to_seq_frame[subset]]
|
[(*row, subset) for row in subset_to_seq_frame[subset]]
|
||||||
for subset in self.subsets
|
for subset in subsets
|
||||||
),
|
),
|
||||||
[],
|
[],
|
||||||
)
|
)
|
||||||
@@ -522,7 +618,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
stmt = sa.select(
|
stmt = sa.select(
|
||||||
self.frame_annotations_type.sequence_name,
|
self.frame_annotations_type.sequence_name,
|
||||||
self.frame_annotations_type.frame_number,
|
self.frame_annotations_type.frame_number,
|
||||||
).where(self.frame_annotations_type._mask_mass == 0)
|
).where(self.frame_annotations_type._mask_mass == 0) # pyre-ignore[16]
|
||||||
with Session(self._sql_engine) as session:
|
with Session(self._sql_engine) as session:
|
||||||
to_remove = session.execute(stmt).all()
|
to_remove = session.execute(stmt).all()
|
||||||
|
|
||||||
@@ -586,7 +682,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
stmt = sa.select(
|
stmt = sa.select(
|
||||||
self.frame_annotations_type.sequence_name,
|
self.frame_annotations_type.sequence_name,
|
||||||
self.frame_annotations_type.frame_number,
|
self.frame_annotations_type.frame_number,
|
||||||
self.frame_annotations_type._image_path,
|
self.frame_annotations_type._image_path, # pyre-ignore[16]
|
||||||
sa.null().label("subset"),
|
sa.null().label("subset"),
|
||||||
)
|
)
|
||||||
where_conditions = []
|
where_conditions = []
|
||||||
@@ -600,7 +696,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
logger.info(" excluding samples with empty masks")
|
logger.info(" excluding samples with empty masks")
|
||||||
where_conditions.append(
|
where_conditions.append(
|
||||||
sa.or_(
|
sa.or_(
|
||||||
self.frame_annotations_type._mask_mass.is_(None),
|
self.frame_annotations_type._mask_mass.is_(None), # pyre-ignore[16]
|
||||||
self.frame_annotations_type._mask_mass != 0,
|
self.frame_annotations_type._mask_mass != 0,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -634,7 +730,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
assert self.eval_batches_file
|
assert self.eval_batches_file
|
||||||
logger.info(f"Loading eval batches from {self.eval_batches_file}")
|
logger.info(f"Loading eval batches from {self.eval_batches_file}")
|
||||||
|
|
||||||
if not os.path.isfile(self.eval_batches_file):
|
if (
|
||||||
|
self.path_manager and not self.path_manager.isfile(self.eval_batches_file)
|
||||||
|
) or (not self.path_manager and not os.path.isfile(self.eval_batches_file)):
|
||||||
# The batch indices file does not exist.
|
# The batch indices file does not exist.
|
||||||
# Most probably the user has not specified the root folder.
|
# Most probably the user has not specified the root folder.
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -642,7 +740,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
+ "Please specify a correct dataset_root folder."
|
+ "Please specify a correct dataset_root folder."
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(self.eval_batches_file, "r") as f:
|
eval_batches_file = self._local_path(self.eval_batches_file)
|
||||||
|
with open(eval_batches_file, "r") as f:
|
||||||
eval_batches = json.load(f)
|
eval_batches = json.load(f)
|
||||||
|
|
||||||
# limit the dataset to sequences to allow multiple evaluations in one file
|
# limit the dataset to sequences to allow multiple evaluations in one file
|
||||||
@@ -726,9 +825,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
self.frame_annotations_type.sequence_name == seq_name,
|
self.frame_annotations_type.sequence_name == seq_name,
|
||||||
self.frame_annotations_type.frame_number.in_(frames),
|
self.frame_annotations_type.frame_number.in_(frames),
|
||||||
)
|
)
|
||||||
|
frame_no_ts = None
|
||||||
|
|
||||||
with self._sql_engine.connect() as connection:
|
if self.scoped_session:
|
||||||
frame_no_ts = pd.read_sql_query(stmt, connection)
|
stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
|
||||||
|
with scoped_session(self._session_factory)() as session: # pyre-ignore
|
||||||
|
frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
|
||||||
|
else:
|
||||||
|
with self._sql_engine.connect() as connection:
|
||||||
|
frame_no_ts = pd.read_sql_query(stmt, connection)
|
||||||
|
|
||||||
if len(frame_no_ts) != len(index_slice):
|
if len(frame_no_ts) != len(index_slice):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -758,11 +863,18 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
prefixes=["TEMP"], # NOTE SQLite specific!
|
prefixes=["TEMP"], # NOTE SQLite specific!
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def pre_expand(cls) -> None:
|
||||||
|
# remove dataclass annotations that are not meant to be init params
|
||||||
|
# because they cause troubles for OmegaConf
|
||||||
|
for attr, attr_value in list(cls.__dict__.items()): # need to copy as we mutate
|
||||||
|
if isinstance(attr_value, Field) and attr_value.metadata.get(
|
||||||
|
"omegaconf_ignore", False
|
||||||
|
):
|
||||||
|
delattr(cls, attr)
|
||||||
|
del cls.__annotations__[attr]
|
||||||
|
|
||||||
|
|
||||||
def _seq_name_to_seed(seq_name) -> int:
|
def _seq_name_to_seed(seq_name) -> int:
|
||||||
"""Generates numbers in [0, 2 ** 28)"""
|
"""Generates numbers in [0, 2 ** 28)"""
|
||||||
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
|
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
|
||||||
|
|
||||||
|
|
||||||
def _safe_as_tensor(data, dtype):
|
|
||||||
return torch.tensor(data, dtype=dtype) if data is not None else None
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
# pyre-unsafe
|
||||||
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@@ -43,7 +45,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
|
||||||
"""
|
"""
|
||||||
Generates the training, validation, and testing dataset objects for
|
Generates the training, validation, and testing dataset objects for
|
||||||
a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
|
a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
|
||||||
@@ -193,9 +195,9 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
|
|
||||||
# this is a mould that is never constructed, used to build self._dataset_map values
|
# this is a mould that is never constructed, used to build self._dataset_map values
|
||||||
dataset_class_type: str = "SqlIndexDataset"
|
dataset_class_type: str = "SqlIndexDataset"
|
||||||
dataset: SqlIndexDataset
|
dataset: SqlIndexDataset # pyre-ignore [13]
|
||||||
|
|
||||||
path_manager_factory: PathManagerFactory
|
path_manager_factory: PathManagerFactory # pyre-ignore [13]
|
||||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -282,8 +284,14 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
logger.info(f"Val dataset: {str(val_dataset)}")
|
logger.info(f"Val dataset: {str(val_dataset)}")
|
||||||
|
|
||||||
logger.debug("Extracting test dataset.")
|
logger.debug("Extracting test dataset.")
|
||||||
eval_batches_file = self._get_lists_file("eval_batches")
|
if self.eval_batches_path is None:
|
||||||
del common_dataset_kwargs["eval_batches_file"]
|
eval_batches_file = None
|
||||||
|
else:
|
||||||
|
eval_batches_file = self._get_lists_file("eval_batches")
|
||||||
|
|
||||||
|
if "eval_batches_file" in common_dataset_kwargs:
|
||||||
|
common_dataset_kwargs.pop("eval_batches_file", None)
|
||||||
|
|
||||||
test_dataset = dataset_type(
|
test_dataset = dataset_type(
|
||||||
**common_dataset_kwargs,
|
**common_dataset_kwargs,
|
||||||
subsets=self._get_subsets(self.test_subsets, True),
|
subsets=self._get_subsets(self.test_subsets, True),
|
||||||
|
|||||||
@@ -87,6 +87,15 @@ def is_train_frame(
|
|||||||
def get_bbox_from_mask(
|
def get_bbox_from_mask(
|
||||||
mask: np.ndarray, thr: float, decrease_quant: float = 0.05
|
mask: np.ndarray, thr: float, decrease_quant: float = 0.05
|
||||||
) -> Tuple[int, int, int, int]:
|
) -> Tuple[int, int, int, int]:
|
||||||
|
# these corner cases need to be handled in order to avoid an infinite loop
|
||||||
|
if mask.size == 0:
|
||||||
|
warnings.warn("Empty mask is provided for bbox extraction.", stacklevel=1)
|
||||||
|
return 0, 0, 1, 1
|
||||||
|
|
||||||
|
if not mask.min() >= 0.0:
|
||||||
|
warnings.warn("Negative values in the mask for bbox extraction.", stacklevel=1)
|
||||||
|
mask = mask.clip(min=0.0)
|
||||||
|
|
||||||
# bbox in xywh
|
# bbox in xywh
|
||||||
masks_for_box = np.zeros_like(mask)
|
masks_for_box = np.zeros_like(mask)
|
||||||
while masks_for_box.sum() <= 1.0:
|
while masks_for_box.sum() <= 1.0:
|
||||||
@@ -134,7 +143,15 @@ T = TypeVar("T", bound=torch.Tensor)
|
|||||||
def bbox_xyxy_to_xywh(xyxy: T) -> T:
|
def bbox_xyxy_to_xywh(xyxy: T) -> T:
|
||||||
wh = xyxy[2:] - xyxy[:2]
|
wh = xyxy[2:] - xyxy[:2]
|
||||||
xywh = torch.cat([xyxy[:2], wh])
|
xywh = torch.cat([xyxy[:2], wh])
|
||||||
return xywh # pyre-ignore
|
return xywh # pyre-ignore[7]
|
||||||
|
|
||||||
|
|
||||||
|
def bbox_xywh_to_xyxy(xywh: T, clamp_size: float | int | None = None) -> T:
|
||||||
|
wh = xywh[2:]
|
||||||
|
if clamp_size is not None:
|
||||||
|
wh = wh.clamp(min=clamp_size)
|
||||||
|
xyxy = torch.cat([xywh[:2], xywh[:2] + wh])
|
||||||
|
return xyxy # pyre-ignore[7]
|
||||||
|
|
||||||
|
|
||||||
def get_clamp_bbox(
|
def get_clamp_bbox(
|
||||||
@@ -180,16 +197,6 @@ def rescale_bbox(
|
|||||||
return bbox * rel_size
|
return bbox * rel_size
|
||||||
|
|
||||||
|
|
||||||
def bbox_xywh_to_xyxy(
|
|
||||||
xywh: torch.Tensor, clamp_size: Optional[int] = None
|
|
||||||
) -> torch.Tensor:
|
|
||||||
xyxy = xywh.clone()
|
|
||||||
if clamp_size is not None:
|
|
||||||
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
|
|
||||||
xyxy[2:] += xyxy[:2]
|
|
||||||
return xyxy
|
|
||||||
|
|
||||||
|
|
||||||
def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]:
|
def get_1d_bounds(arr: np.ndarray) -> Tuple[int, int]:
|
||||||
nz = np.flatnonzero(arr)
|
nz = np.flatnonzero(arr)
|
||||||
return nz[0], nz[-1] + 1
|
return nz[0], nz[-1] + 1
|
||||||
@@ -201,18 +208,24 @@ def resize_image(
|
|||||||
image_width: Optional[int],
|
image_width: Optional[int],
|
||||||
mode: str = "bilinear",
|
mode: str = "bilinear",
|
||||||
) -> Tuple[torch.Tensor, float, torch.Tensor]:
|
) -> Tuple[torch.Tensor, float, torch.Tensor]:
|
||||||
|
|
||||||
if isinstance(image, np.ndarray):
|
if isinstance(image, np.ndarray):
|
||||||
image = torch.from_numpy(image)
|
image = torch.from_numpy(image)
|
||||||
|
|
||||||
if image_height is None or image_width is None:
|
if (
|
||||||
|
image_height is None
|
||||||
|
or image_width is None
|
||||||
|
or image.shape[-2] == 0
|
||||||
|
or image.shape[-1] == 0
|
||||||
|
):
|
||||||
# skip the resizing
|
# skip the resizing
|
||||||
return image, 1.0, torch.ones_like(image[:1])
|
return image, 1.0, torch.ones_like(image[:1])
|
||||||
|
|
||||||
# takes numpy array or tensor, returns pytorch tensor
|
# takes numpy array or tensor, returns pytorch tensor
|
||||||
minscale = min(
|
minscale = min(
|
||||||
image_height / image.shape[-2],
|
image_height / image.shape[-2],
|
||||||
image_width / image.shape[-1],
|
image_width / image.shape[-1],
|
||||||
)
|
)
|
||||||
|
|
||||||
imre = torch.nn.functional.interpolate(
|
imre = torch.nn.functional.interpolate(
|
||||||
image[None],
|
image[None],
|
||||||
scale_factor=minscale,
|
scale_factor=minscale,
|
||||||
@@ -220,6 +233,7 @@ def resize_image(
|
|||||||
align_corners=False if mode == "bilinear" else None,
|
align_corners=False if mode == "bilinear" else None,
|
||||||
recompute_scale_factor=True,
|
recompute_scale_factor=True,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
imre_ = torch.zeros(image.shape[0], image_height, image_width)
|
imre_ = torch.zeros(image.shape[0], image_height, image_width)
|
||||||
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
|
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
|
||||||
mask = torch.zeros(1, image_height, image_width)
|
mask = torch.zeros(1, image_height, image_width)
|
||||||
@@ -232,9 +246,21 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
|
|||||||
return im.astype(np.float32) / 255.0
|
return im.astype(np.float32) / 255.0
|
||||||
|
|
||||||
|
|
||||||
def load_image(path: str) -> np.ndarray:
|
def load_image(
|
||||||
|
path: str, try_read_alpha: bool = False, pil_format: str = "RGB"
|
||||||
|
) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
Load an image from a path and return it as a numpy array.
|
||||||
|
If try_read_alpha is True, the image is read as RGBA and the alpha channel is
|
||||||
|
returned as the fourth channel.
|
||||||
|
Otherwise, the image is read as RGB and a three-channel image is returned.
|
||||||
|
"""
|
||||||
with Image.open(path) as pil_im:
|
with Image.open(path) as pil_im:
|
||||||
im = np.array(pil_im.convert("RGB"))
|
# Check if the image has an alpha channel
|
||||||
|
if try_read_alpha and pil_im.mode == "RGBA":
|
||||||
|
im = np.array(pil_im)
|
||||||
|
else:
|
||||||
|
im = np.array(pil_im.convert(pil_format))
|
||||||
|
|
||||||
return transpose_normalize_image(im)
|
return transpose_normalize_image(im)
|
||||||
|
|
||||||
@@ -329,6 +355,7 @@ def adjust_camera_to_bbox_crop_(
|
|||||||
|
|
||||||
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
||||||
camera.focal_length[0],
|
camera.focal_length[0],
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||||
camera.principal_point[0],
|
camera.principal_point[0],
|
||||||
image_size_wh,
|
image_size_wh,
|
||||||
)
|
)
|
||||||
@@ -341,6 +368,7 @@ def adjust_camera_to_bbox_crop_(
|
|||||||
)
|
)
|
||||||
|
|
||||||
camera.focal_length = focal_length[None]
|
camera.focal_length = focal_length[None]
|
||||||
|
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
|
||||||
camera.principal_point = principal_point_cropped[None]
|
camera.principal_point = principal_point_cropped[None]
|
||||||
|
|
||||||
|
|
||||||
@@ -348,9 +376,11 @@ def adjust_camera_to_image_scale_(
|
|||||||
camera: PerspectiveCameras,
|
camera: PerspectiveCameras,
|
||||||
original_size_wh: torch.Tensor,
|
original_size_wh: torch.Tensor,
|
||||||
new_size_wh: torch.LongTensor,
|
new_size_wh: torch.LongTensor,
|
||||||
|
# pyre-fixme[7]: Expected `PerspectiveCameras` but got implicit return value of `None`.
|
||||||
) -> PerspectiveCameras:
|
) -> PerspectiveCameras:
|
||||||
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
focal_length_px, principal_point_px = _convert_ndc_to_pixels(
|
||||||
camera.focal_length[0],
|
camera.focal_length[0],
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, slice[Any, A...
|
||||||
camera.principal_point[0],
|
camera.principal_point[0],
|
||||||
original_size_wh,
|
original_size_wh,
|
||||||
)
|
)
|
||||||
@@ -367,7 +397,8 @@ def adjust_camera_to_image_scale_(
|
|||||||
image_size_wh_output,
|
image_size_wh_output,
|
||||||
)
|
)
|
||||||
camera.focal_length = focal_length_scaled[None]
|
camera.focal_length = focal_length_scaled[None]
|
||||||
camera.principal_point = principal_point_scaled[None] # pyre-ignore
|
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
|
||||||
|
camera.principal_point = principal_point_scaled[None] # pyre-ignore[16]
|
||||||
|
|
||||||
|
|
||||||
# NOTE this cache is per-worker; they are implemented as processes.
|
# NOTE this cache is per-worker; they are implemented as processes.
|
||||||
|
|||||||
@@ -299,7 +299,6 @@ def eval_batch(
|
|||||||
)
|
)
|
||||||
|
|
||||||
for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("_masked", "_fg")):
|
for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("_masked", "_fg")):
|
||||||
|
|
||||||
loss_mask_now = mask_crop * loss_fg_mask
|
loss_mask_now = mask_crop * loss_fg_mask
|
||||||
|
|
||||||
for rgb_metric_name, rgb_metric_fun in zip(
|
for rgb_metric_name, rgb_metric_fun in zip(
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
|||||||
self.layers = torch.nn.ModuleList()
|
self.layers = torch.nn.ModuleList()
|
||||||
self.proj_layers = torch.nn.ModuleList()
|
self.proj_layers = torch.nn.ModuleList()
|
||||||
for stage in range(self.max_stage):
|
for stage in range(self.max_stage):
|
||||||
stage_name = f"layer{stage+1}"
|
stage_name = f"layer{stage + 1}"
|
||||||
feature_name = self._get_resnet_stage_feature_name(stage)
|
feature_name = self._get_resnet_stage_feature_name(stage)
|
||||||
if (stage + 1) in self.stages:
|
if (stage + 1) in self.stages:
|
||||||
if (
|
if (
|
||||||
@@ -139,12 +139,18 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
|||||||
self.stages = set(self.stages) # convert to set for faster "in"
|
self.stages = set(self.stages) # convert to set for faster "in"
|
||||||
|
|
||||||
def _get_resnet_stage_feature_name(self, stage) -> str:
|
def _get_resnet_stage_feature_name(self, stage) -> str:
|
||||||
return f"res_layer_{stage+1}"
|
return f"res_layer_{stage + 1}"
|
||||||
|
|
||||||
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
|
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
|
||||||
|
# pyre-fixme[58]: `-` is not supported for operand types `Tensor` and
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
|
# pyre-fixme[58]: `/` is not supported for operand types `Tensor` and
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
return (img - self._resnet_mean) / self._resnet_std
|
return (img - self._resnet_mean) / self._resnet_std
|
||||||
|
|
||||||
def get_feat_dims(self) -> int:
|
def get_feat_dims(self) -> int:
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase) -> Tensor, Tensor, Module]` is
|
||||||
|
# not a function.
|
||||||
return sum(self._feat_dim.values())
|
return sum(self._feat_dim.values())
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -183,7 +189,12 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
|
|||||||
else:
|
else:
|
||||||
imgs_normed = imgs_resized
|
imgs_normed = imgs_resized
|
||||||
# is not a function.
|
# is not a function.
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
feats = self.stem(imgs_normed)
|
feats = self.stem(imgs_normed)
|
||||||
|
# pyre-fixme[6]: For 1st argument expected `Iterable[_T1]` but got
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
|
# pyre-fixme[6]: For 2nd argument expected `Iterable[_T2]` but got
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
|
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
|
||||||
feats = layer(feats)
|
feats = layer(feats)
|
||||||
# just a sanity check below
|
# just a sanity check below
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
class GenericModel(ImplicitronModelBase):
|
||||||
"""
|
"""
|
||||||
GenericModel is a wrapper for the neural implicit
|
GenericModel is a wrapper for the neural implicit
|
||||||
rendering and reconstruction pipeline which consists
|
rendering and reconstruction pipeline which consists
|
||||||
@@ -226,34 +226,42 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
|
|
||||||
# ---- global encoder settings
|
# ---- global encoder settings
|
||||||
global_encoder_class_type: Optional[str] = None
|
global_encoder_class_type: Optional[str] = None
|
||||||
|
# pyre-fixme[13]: Attribute `global_encoder` is never initialized.
|
||||||
global_encoder: Optional[GlobalEncoderBase]
|
global_encoder: Optional[GlobalEncoderBase]
|
||||||
|
|
||||||
# ---- raysampler
|
# ---- raysampler
|
||||||
raysampler_class_type: str = "AdaptiveRaySampler"
|
raysampler_class_type: str = "AdaptiveRaySampler"
|
||||||
|
# pyre-fixme[13]: Attribute `raysampler` is never initialized.
|
||||||
raysampler: RaySamplerBase
|
raysampler: RaySamplerBase
|
||||||
|
|
||||||
# ---- renderer configs
|
# ---- renderer configs
|
||||||
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
|
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
|
||||||
|
# pyre-fixme[13]: Attribute `renderer` is never initialized.
|
||||||
renderer: BaseRenderer
|
renderer: BaseRenderer
|
||||||
|
|
||||||
# ---- image feature extractor settings
|
# ---- image feature extractor settings
|
||||||
# (This is only created if view_pooler is enabled)
|
# (This is only created if view_pooler is enabled)
|
||||||
|
# pyre-fixme[13]: Attribute `image_feature_extractor` is never initialized.
|
||||||
image_feature_extractor: Optional[FeatureExtractorBase]
|
image_feature_extractor: Optional[FeatureExtractorBase]
|
||||||
image_feature_extractor_class_type: Optional[str] = None
|
image_feature_extractor_class_type: Optional[str] = None
|
||||||
# ---- view pooler settings
|
# ---- view pooler settings
|
||||||
view_pooler_enabled: bool = False
|
view_pooler_enabled: bool = False
|
||||||
|
# pyre-fixme[13]: Attribute `view_pooler` is never initialized.
|
||||||
view_pooler: Optional[ViewPooler]
|
view_pooler: Optional[ViewPooler]
|
||||||
|
|
||||||
# ---- implicit function settings
|
# ---- implicit function settings
|
||||||
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
|
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
|
||||||
# This is just a model, never constructed.
|
# This is just a model, never constructed.
|
||||||
# The actual implicit functions live in self._implicit_functions
|
# The actual implicit functions live in self._implicit_functions
|
||||||
|
# pyre-fixme[13]: Attribute `implicit_function` is never initialized.
|
||||||
implicit_function: ImplicitFunctionBase
|
implicit_function: ImplicitFunctionBase
|
||||||
|
|
||||||
# ----- metrics
|
# ----- metrics
|
||||||
|
# pyre-fixme[13]: Attribute `view_metrics` is never initialized.
|
||||||
view_metrics: ViewMetricsBase
|
view_metrics: ViewMetricsBase
|
||||||
view_metrics_class_type: str = "ViewMetrics"
|
view_metrics_class_type: str = "ViewMetrics"
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `regularization_metrics` is never initialized.
|
||||||
regularization_metrics: RegularizationMetricsBase
|
regularization_metrics: RegularizationMetricsBase
|
||||||
regularization_metrics_class_type: str = "RegularizationMetrics"
|
regularization_metrics_class_type: str = "RegularizationMetrics"
|
||||||
|
|
||||||
@@ -470,6 +478,8 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
)
|
)
|
||||||
custom_args["global_code"] = global_code
|
custom_args["global_code"] = global_code
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
|
||||||
|
# function.
|
||||||
for func in self._implicit_functions:
|
for func in self._implicit_functions:
|
||||||
func.bind_args(**custom_args)
|
func.bind_args(**custom_args)
|
||||||
|
|
||||||
@@ -492,6 +502,8 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
# Unbind the custom arguments to prevent pytorch from storing
|
# Unbind the custom arguments to prevent pytorch from storing
|
||||||
# large buffers of intermediate results due to points in the
|
# large buffers of intermediate results due to points in the
|
||||||
# bound arguments.
|
# bound arguments.
|
||||||
|
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Tensor, Module]` is not a
|
||||||
|
# function.
|
||||||
for func in self._implicit_functions:
|
for func in self._implicit_functions:
|
||||||
func.unbind_args()
|
func.unbind_args()
|
||||||
|
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||||
|
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute `weight`.
|
||||||
return (self._autodecoder_codes.weight**2).mean()
|
return (self._autodecoder_codes.weight**2).mean()
|
||||||
|
|
||||||
def get_encoding_dim(self) -> int:
|
def get_encoding_dim(self) -> int:
|
||||||
@@ -95,6 +96,7 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
# pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
|
# pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
|
||||||
# `Tensor`.
|
# `Tensor`.
|
||||||
x = torch.tensor(
|
x = torch.tensor(
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
|
||||||
[self._key_map[elem] for elem in x],
|
[self._key_map[elem] for elem in x],
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
@@ -102,6 +104,7 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise ValueError("Not enough n_instances in the autodecoder") from None
|
raise ValueError("Not enough n_instances in the autodecoder") from None
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
return self._autodecoder_codes(x)
|
return self._autodecoder_codes(x)
|
||||||
|
|
||||||
def _load_key_map_hook(
|
def _load_key_map_hook(
|
||||||
|
|||||||
@@ -59,12 +59,13 @@ class GlobalEncoderBase(ReplaceableBase):
|
|||||||
|
|
||||||
# TODO: probabilistic embeddings?
|
# TODO: probabilistic embeddings?
|
||||||
@registry.register
|
@registry.register
|
||||||
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 13
|
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
A global encoder implementation which provides an autodecoder encoding
|
A global encoder implementation which provides an autodecoder encoding
|
||||||
of the frame's sequence identifier.
|
of the frame's sequence identifier.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Attribute `autodecoder` is never initialized.
|
||||||
autodecoder: Autodecoder
|
autodecoder: Autodecoder
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -121,6 +122,7 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
|
|||||||
if frame_timestamp.shape[-1] != 1:
|
if frame_timestamp.shape[-1] != 1:
|
||||||
raise ValueError("Frame timestamp's last dimensions should be one.")
|
raise ValueError("Frame timestamp's last dimensions should be one.")
|
||||||
time = frame_timestamp / self.time_divisor
|
time = frame_timestamp / self.time_divisor
|
||||||
|
# pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
|
||||||
return self._harmonic_embedding(time)
|
return self._harmonic_embedding(time)
|
||||||
|
|
||||||
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||||
|
|||||||
@@ -232,9 +232,14 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
|
|||||||
# if the skip tensor is None, we use `x` instead.
|
# if the skip tensor is None, we use `x` instead.
|
||||||
z = x
|
z = x
|
||||||
skipi = 0
|
skipi = 0
|
||||||
|
# pyre-fixme[6]: For 1st argument expected `Iterable[_T]` but got
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
for li, layer in enumerate(self.mlp):
|
for li, layer in enumerate(self.mlp):
|
||||||
|
# pyre-fixme[58]: `in` is not supported for right operand type
|
||||||
|
# `Union[Tensor, Module]`.
|
||||||
if li in self._input_skips:
|
if li in self._input_skips:
|
||||||
if self._skip_affine_trans:
|
if self._skip_affine_trans:
|
||||||
|
# pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, ...
|
||||||
y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
|
y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
|
||||||
else:
|
else:
|
||||||
y = torch.cat((y, z), dim=-1)
|
y = torch.cat((y, z), dim=-1)
|
||||||
@@ -244,7 +249,6 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
# pyre-fixme[13]: Attribute `network` is never initialized.
|
|
||||||
class MLPDecoder(DecoderFunctionBase):
|
class MLPDecoder(DecoderFunctionBase):
|
||||||
"""
|
"""
|
||||||
Decoding function which uses `MLPWithIputSkips` to convert the embedding to output.
|
Decoding function which uses `MLPWithIputSkips` to convert the embedding to output.
|
||||||
@@ -272,6 +276,7 @@ class MLPDecoder(DecoderFunctionBase):
|
|||||||
|
|
||||||
input_dim: int = 3
|
input_dim: int = 3
|
||||||
param_groups: Dict[str, str] = field(default_factory=lambda: {})
|
param_groups: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
|
# pyre-fixme[13]: Attribute `network` is never initialized.
|
||||||
network: MLPWithInputSkips
|
network: MLPWithInputSkips
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user