mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-04-22 04:45:58 +08:00
Compare commits
106 Commits
v0.6.2
...
classner-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e7c1f026ea | ||
|
|
cb49550486 | ||
|
|
36edf2b302 | ||
|
|
78bb6d17fa | ||
|
|
54c75b4114 | ||
|
|
3783437d2f | ||
|
|
b2dc520210 | ||
|
|
8597d4c5c1 | ||
|
|
38fd8380f7 | ||
|
|
67840f8320 | ||
|
|
9b2e570536 | ||
|
|
0f966217e5 | ||
|
|
379c8b2780 | ||
|
|
8e0c82b89a | ||
|
|
8ba9a694ee | ||
|
|
36ba079bef | ||
|
|
b95ec190af | ||
|
|
55f67b0d18 | ||
|
|
4261e59f51 | ||
|
|
af55ba01f8 | ||
|
|
d3b7f5f421 | ||
|
|
4ecc9ea89d | ||
|
|
8d10ba52b2 | ||
|
|
aa8b03f31d | ||
|
|
57a40b3688 | ||
|
|
522e5f0644 | ||
|
|
e8390d3500 | ||
|
|
4300030d7a | ||
|
|
00acf0b0c7 | ||
|
|
a94f3f4c4b | ||
|
|
efb721320a | ||
|
|
40fb189c29 | ||
|
|
4e87c2b7f1 | ||
|
|
771cf8a328 | ||
|
|
0dce883241 | ||
|
|
ae35824f21 | ||
|
|
f4dd151037 | ||
|
|
7ce8ed55e1 | ||
|
|
7e0146ece4 | ||
|
|
0e4c53c612 | ||
|
|
879495d38f | ||
|
|
5c1ca757bb | ||
|
|
3e4fb0b9d9 | ||
|
|
731ea53c80 | ||
|
|
2e42ef793f | ||
|
|
81d63c6382 | ||
|
|
28c1afaa9d | ||
|
|
cba26506b6 | ||
|
|
65f667fd2e | ||
|
|
7978ffd1e4 | ||
|
|
ea4f3260e4 | ||
|
|
023a2369ae | ||
|
|
c0f88e04a0 | ||
|
|
6275283202 | ||
|
|
1d43251391 | ||
|
|
1fb268dea6 | ||
|
|
8bc0a04e86 | ||
|
|
5cd70067e2 | ||
|
|
5b74a2cc27 | ||
|
|
49ed7b07b1 | ||
|
|
c6519f29f0 | ||
|
|
a42a89a5ba | ||
|
|
c31bf85a23 | ||
|
|
fbd3c679ac | ||
|
|
34f648ede0 | ||
|
|
f625fe1f8b | ||
|
|
7c25d34d22 | ||
|
|
c5a83f46ef | ||
|
|
1702c85bec | ||
|
|
90d00f1b2b | ||
|
|
d27ef14ec7 | ||
|
|
2d1c6d5d93 | ||
|
|
9fe15da3cd | ||
|
|
0f12c51646 | ||
|
|
79c61a2d86 | ||
|
|
69c6d06ed8 | ||
|
|
73dc109dba | ||
|
|
9ec9d057cc | ||
|
|
cd7b885169 | ||
|
|
f632c423ef | ||
|
|
f36b11fe49 | ||
|
|
ea5df60d72 | ||
|
|
4372001981 | ||
|
|
61e2b87019 | ||
|
|
0143d63ba8 | ||
|
|
899a3192b6 | ||
|
|
3b2300641a | ||
|
|
b5f3d3ce12 | ||
|
|
2c1901522a | ||
|
|
90ab219d88 | ||
|
|
9e57b994ca | ||
|
|
e767c4b548 | ||
|
|
e85fa03c5a | ||
|
|
47d06c8924 | ||
|
|
bef959c755 | ||
|
|
c21ba144e7 | ||
|
|
d737a05e55 | ||
|
|
2374d19da5 | ||
|
|
1f3953795c | ||
|
|
a6dada399d | ||
|
|
5c59841863 | ||
|
|
2c64635daa | ||
|
|
ec9580a1d4 | ||
|
|
44cb00e468 | ||
|
|
44ca5f95d9 | ||
|
|
a51a300827 |
@@ -81,7 +81,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
export LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64
|
export LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64
|
||||||
python3 setup.py build_ext --inplace
|
python3 setup.py build_ext --inplace
|
||||||
- run: LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64 python -m unittest discover -v -s tests
|
- run: LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64 python -m unittest discover -v -s tests -t .
|
||||||
- run: python3 setup.py bdist_wheel
|
- run: python3 setup.py bdist_wheel
|
||||||
|
|
||||||
binary_linux_wheel:
|
binary_linux_wheel:
|
||||||
@@ -182,23 +182,23 @@ workflows:
|
|||||||
# context: DOCKERHUB_TOKEN
|
# context: DOCKERHUB_TOKEN
|
||||||
{{workflows()}}
|
{{workflows()}}
|
||||||
- binary_linux_conda_cuda:
|
- binary_linux_conda_cuda:
|
||||||
name: testrun_conda_cuda_py37_cu102_pyt170
|
name: testrun_conda_cuda_py37_cu102_pyt190
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
python_version: "3.7"
|
python_version: "3.7"
|
||||||
pytorch_version: '1.7.0'
|
pytorch_version: '1.9.0'
|
||||||
cu_version: "cu102"
|
cu_version: "cu102"
|
||||||
- binary_macos_wheel:
|
- binary_macos_wheel:
|
||||||
cu_version: cpu
|
cu_version: cpu
|
||||||
name: macos_wheel_py37_cpu
|
name: macos_wheel_py37_cpu
|
||||||
python_version: '3.7'
|
python_version: '3.7'
|
||||||
pytorch_version: '1.9.0'
|
pytorch_version: '1.12.0'
|
||||||
- binary_macos_wheel:
|
- binary_macos_wheel:
|
||||||
cu_version: cpu
|
cu_version: cpu
|
||||||
name: macos_wheel_py38_cpu
|
name: macos_wheel_py38_cpu
|
||||||
python_version: '3.8'
|
python_version: '3.8'
|
||||||
pytorch_version: '1.9.0'
|
pytorch_version: '1.12.0'
|
||||||
- binary_macos_wheel:
|
- binary_macos_wheel:
|
||||||
cu_version: cpu
|
cu_version: cpu
|
||||||
name: macos_wheel_py39_cpu
|
name: macos_wheel_py39_cpu
|
||||||
python_version: '3.9'
|
python_version: '3.9'
|
||||||
pytorch_version: '1.9.0'
|
pytorch_version: '1.12.0'
|
||||||
|
|||||||
@@ -81,7 +81,7 @@ jobs:
|
|||||||
command: |
|
command: |
|
||||||
export LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64
|
export LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64
|
||||||
python3 setup.py build_ext --inplace
|
python3 setup.py build_ext --inplace
|
||||||
- run: LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64 python -m unittest discover -v -s tests
|
- run: LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64 python -m unittest discover -v -s tests -t .
|
||||||
- run: python3 setup.py bdist_wheel
|
- run: python3 setup.py bdist_wheel
|
||||||
|
|
||||||
binary_linux_wheel:
|
binary_linux_wheel:
|
||||||
@@ -180,42 +180,6 @@ workflows:
|
|||||||
jobs:
|
jobs:
|
||||||
# - main:
|
# - main:
|
||||||
# context: DOCKERHUB_TOKEN
|
# context: DOCKERHUB_TOKEN
|
||||||
- binary_linux_conda:
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu101
|
|
||||||
name: linux_conda_py37_cu101_pyt170
|
|
||||||
python_version: '3.7'
|
|
||||||
pytorch_version: 1.7.0
|
|
||||||
- binary_linux_conda:
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu102
|
|
||||||
name: linux_conda_py37_cu102_pyt170
|
|
||||||
python_version: '3.7'
|
|
||||||
pytorch_version: 1.7.0
|
|
||||||
- binary_linux_conda:
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu110
|
|
||||||
name: linux_conda_py37_cu110_pyt170
|
|
||||||
python_version: '3.7'
|
|
||||||
pytorch_version: 1.7.0
|
|
||||||
- binary_linux_conda:
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu101
|
|
||||||
name: linux_conda_py37_cu101_pyt171
|
|
||||||
python_version: '3.7'
|
|
||||||
pytorch_version: 1.7.1
|
|
||||||
- binary_linux_conda:
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu102
|
|
||||||
name: linux_conda_py37_cu102_pyt171
|
|
||||||
python_version: '3.7'
|
|
||||||
pytorch_version: 1.7.1
|
|
||||||
- binary_linux_conda:
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu110
|
|
||||||
name: linux_conda_py37_cu110_pyt171
|
|
||||||
python_version: '3.7'
|
|
||||||
pytorch_version: 1.7.1
|
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu101
|
cu_version: cu101
|
||||||
@@ -359,42 +323,26 @@ workflows:
|
|||||||
name: linux_conda_py37_cu115_pyt1110
|
name: linux_conda_py37_cu115_pyt1110
|
||||||
python_version: '3.7'
|
python_version: '3.7'
|
||||||
pytorch_version: 1.11.0
|
pytorch_version: 1.11.0
|
||||||
- binary_linux_conda:
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu101
|
|
||||||
name: linux_conda_py38_cu101_pyt170
|
|
||||||
python_version: '3.8'
|
|
||||||
pytorch_version: 1.7.0
|
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu102
|
cu_version: cu102
|
||||||
name: linux_conda_py38_cu102_pyt170
|
name: linux_conda_py37_cu102_pyt1120
|
||||||
python_version: '3.8'
|
python_version: '3.7'
|
||||||
pytorch_version: 1.7.0
|
pytorch_version: 1.12.0
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda113
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu110
|
cu_version: cu113
|
||||||
name: linux_conda_py38_cu110_pyt170
|
name: linux_conda_py37_cu113_pyt1120
|
||||||
python_version: '3.8'
|
python_version: '3.7'
|
||||||
pytorch_version: 1.7.0
|
pytorch_version: 1.12.0
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda116
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu101
|
cu_version: cu116
|
||||||
name: linux_conda_py38_cu101_pyt171
|
name: linux_conda_py37_cu116_pyt1120
|
||||||
python_version: '3.8'
|
python_version: '3.7'
|
||||||
pytorch_version: 1.7.1
|
pytorch_version: 1.12.0
|
||||||
- binary_linux_conda:
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu102
|
|
||||||
name: linux_conda_py38_cu102_pyt171
|
|
||||||
python_version: '3.8'
|
|
||||||
pytorch_version: 1.7.1
|
|
||||||
- binary_linux_conda:
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu110
|
|
||||||
name: linux_conda_py38_cu110_pyt171
|
|
||||||
python_version: '3.8'
|
|
||||||
pytorch_version: 1.7.1
|
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu101
|
cu_version: cu101
|
||||||
@@ -538,24 +486,26 @@ workflows:
|
|||||||
name: linux_conda_py38_cu115_pyt1110
|
name: linux_conda_py38_cu115_pyt1110
|
||||||
python_version: '3.8'
|
python_version: '3.8'
|
||||||
pytorch_version: 1.11.0
|
pytorch_version: 1.11.0
|
||||||
- binary_linux_conda:
|
|
||||||
context: DOCKERHUB_TOKEN
|
|
||||||
cu_version: cu101
|
|
||||||
name: linux_conda_py39_cu101_pyt171
|
|
||||||
python_version: '3.9'
|
|
||||||
pytorch_version: 1.7.1
|
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu102
|
cu_version: cu102
|
||||||
name: linux_conda_py39_cu102_pyt171
|
name: linux_conda_py38_cu102_pyt1120
|
||||||
python_version: '3.9'
|
python_version: '3.8'
|
||||||
pytorch_version: 1.7.1
|
pytorch_version: 1.12.0
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda113
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu110
|
cu_version: cu113
|
||||||
name: linux_conda_py39_cu110_pyt171
|
name: linux_conda_py38_cu113_pyt1120
|
||||||
python_version: '3.9'
|
python_version: '3.8'
|
||||||
pytorch_version: 1.7.1
|
pytorch_version: 1.12.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda116
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu116
|
||||||
|
name: linux_conda_py38_cu116_pyt1120
|
||||||
|
python_version: '3.8'
|
||||||
|
pytorch_version: 1.12.0
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu101
|
cu_version: cu101
|
||||||
@@ -699,6 +649,26 @@ workflows:
|
|||||||
name: linux_conda_py39_cu115_pyt1110
|
name: linux_conda_py39_cu115_pyt1110
|
||||||
python_version: '3.9'
|
python_version: '3.9'
|
||||||
pytorch_version: 1.11.0
|
pytorch_version: 1.11.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu102
|
||||||
|
name: linux_conda_py39_cu102_pyt1120
|
||||||
|
python_version: '3.9'
|
||||||
|
pytorch_version: 1.12.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda113
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu113
|
||||||
|
name: linux_conda_py39_cu113_pyt1120
|
||||||
|
python_version: '3.9'
|
||||||
|
pytorch_version: 1.12.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda116
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu116
|
||||||
|
name: linux_conda_py39_cu116_pyt1120
|
||||||
|
python_version: '3.9'
|
||||||
|
pytorch_version: 1.12.0
|
||||||
- binary_linux_conda:
|
- binary_linux_conda:
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
cu_version: cu102
|
cu_version: cu102
|
||||||
@@ -725,24 +695,44 @@ workflows:
|
|||||||
name: linux_conda_py310_cu115_pyt1110
|
name: linux_conda_py310_cu115_pyt1110
|
||||||
python_version: '3.10'
|
python_version: '3.10'
|
||||||
pytorch_version: 1.11.0
|
pytorch_version: 1.11.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu102
|
||||||
|
name: linux_conda_py310_cu102_pyt1120
|
||||||
|
python_version: '3.10'
|
||||||
|
pytorch_version: 1.12.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda113
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu113
|
||||||
|
name: linux_conda_py310_cu113_pyt1120
|
||||||
|
python_version: '3.10'
|
||||||
|
pytorch_version: 1.12.0
|
||||||
|
- binary_linux_conda:
|
||||||
|
conda_docker_image: pytorch/conda-builder:cuda116
|
||||||
|
context: DOCKERHUB_TOKEN
|
||||||
|
cu_version: cu116
|
||||||
|
name: linux_conda_py310_cu116_pyt1120
|
||||||
|
python_version: '3.10'
|
||||||
|
pytorch_version: 1.12.0
|
||||||
- binary_linux_conda_cuda:
|
- binary_linux_conda_cuda:
|
||||||
name: testrun_conda_cuda_py37_cu102_pyt170
|
name: testrun_conda_cuda_py37_cu102_pyt190
|
||||||
context: DOCKERHUB_TOKEN
|
context: DOCKERHUB_TOKEN
|
||||||
python_version: "3.7"
|
python_version: "3.7"
|
||||||
pytorch_version: '1.7.0'
|
pytorch_version: '1.9.0'
|
||||||
cu_version: "cu102"
|
cu_version: "cu102"
|
||||||
- binary_macos_wheel:
|
- binary_macos_wheel:
|
||||||
cu_version: cpu
|
cu_version: cpu
|
||||||
name: macos_wheel_py37_cpu
|
name: macos_wheel_py37_cpu
|
||||||
python_version: '3.7'
|
python_version: '3.7'
|
||||||
pytorch_version: '1.9.0'
|
pytorch_version: '1.12.0'
|
||||||
- binary_macos_wheel:
|
- binary_macos_wheel:
|
||||||
cu_version: cpu
|
cu_version: cpu
|
||||||
name: macos_wheel_py38_cpu
|
name: macos_wheel_py38_cpu
|
||||||
python_version: '3.8'
|
python_version: '3.8'
|
||||||
pytorch_version: '1.9.0'
|
pytorch_version: '1.12.0'
|
||||||
- binary_macos_wheel:
|
- binary_macos_wheel:
|
||||||
cu_version: cpu
|
cu_version: cpu
|
||||||
name: macos_wheel_py39_cpu
|
name: macos_wheel_py39_cpu
|
||||||
python_version: '3.9'
|
python_version: '3.9'
|
||||||
pytorch_version: '1.9.0'
|
pytorch_version: '1.12.0'
|
||||||
|
|||||||
@@ -20,8 +20,6 @@ from packaging import version
|
|||||||
# version of pytorch.
|
# version of pytorch.
|
||||||
# Pytorch 1.4 also supports cuda 10.0 but we no longer build for cuda 10.0 at all.
|
# Pytorch 1.4 also supports cuda 10.0 but we no longer build for cuda 10.0 at all.
|
||||||
CONDA_CUDA_VERSIONS = {
|
CONDA_CUDA_VERSIONS = {
|
||||||
"1.7.0": ["cu101", "cu102", "cu110"],
|
|
||||||
"1.7.1": ["cu101", "cu102", "cu110"],
|
|
||||||
"1.8.0": ["cu101", "cu102", "cu111"],
|
"1.8.0": ["cu101", "cu102", "cu111"],
|
||||||
"1.8.1": ["cu101", "cu102", "cu111"],
|
"1.8.1": ["cu101", "cu102", "cu111"],
|
||||||
"1.9.0": ["cu102", "cu111"],
|
"1.9.0": ["cu102", "cu111"],
|
||||||
@@ -30,15 +28,20 @@ CONDA_CUDA_VERSIONS = {
|
|||||||
"1.10.1": ["cu102", "cu111", "cu113"],
|
"1.10.1": ["cu102", "cu111", "cu113"],
|
||||||
"1.10.2": ["cu102", "cu111", "cu113"],
|
"1.10.2": ["cu102", "cu111", "cu113"],
|
||||||
"1.11.0": ["cu102", "cu111", "cu113", "cu115"],
|
"1.11.0": ["cu102", "cu111", "cu113", "cu115"],
|
||||||
|
"1.12.0": ["cu102", "cu113", "cu116"],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def conda_docker_image_for_cuda(cuda_version):
|
def conda_docker_image_for_cuda(cuda_version):
|
||||||
|
if cuda_version in ("cu101", "cu102", "cu111"):
|
||||||
|
return None
|
||||||
if cuda_version == "cu113":
|
if cuda_version == "cu113":
|
||||||
return "pytorch/conda-builder:cuda113"
|
return "pytorch/conda-builder:cuda113"
|
||||||
if cuda_version == "cu115":
|
if cuda_version == "cu115":
|
||||||
return "pytorch/conda-builder:cuda115"
|
return "pytorch/conda-builder:cuda115"
|
||||||
return None
|
if cuda_version == "cu116":
|
||||||
|
return "pytorch/conda-builder:cuda116"
|
||||||
|
raise ValueError("Unknown cuda version")
|
||||||
|
|
||||||
|
|
||||||
def pytorch_versions_for_python(python_version):
|
def pytorch_versions_for_python(python_version):
|
||||||
|
|||||||
14
INSTALL.md
14
INSTALL.md
@@ -9,7 +9,7 @@ The core library is written in PyTorch. Several components have underlying imple
|
|||||||
|
|
||||||
- Linux or macOS or Windows
|
- Linux or macOS or Windows
|
||||||
- Python 3.6, 3.7, 3.8 or 3.9
|
- Python 3.6, 3.7, 3.8 or 3.9
|
||||||
- PyTorch 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2 or 1.11.0.
|
- PyTorch 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0 or 1.12.0.
|
||||||
- torchvision that matches the PyTorch installation. You can install them together as explained at pytorch.org to make sure of this.
|
- torchvision that matches the PyTorch installation. You can install them together as explained at pytorch.org to make sure of this.
|
||||||
- gcc & g++ ≥ 4.9
|
- gcc & g++ ≥ 4.9
|
||||||
- [fvcore](https://github.com/facebookresearch/fvcore)
|
- [fvcore](https://github.com/facebookresearch/fvcore)
|
||||||
@@ -78,7 +78,7 @@ Or, to install a nightly (non-official, alpha) build:
|
|||||||
conda install pytorch3d -c pytorch3d-nightly
|
conda install pytorch3d -c pytorch3d-nightly
|
||||||
```
|
```
|
||||||
### 2. Install from PyPI, on Mac only.
|
### 2. Install from PyPI, on Mac only.
|
||||||
This works with pytorch 1.9.0 only. The build is CPU only.
|
This works with pytorch 1.12.0 only. The build is CPU only.
|
||||||
```
|
```
|
||||||
pip install pytorch3d
|
pip install pytorch3d
|
||||||
```
|
```
|
||||||
@@ -87,9 +87,9 @@ pip install pytorch3d
|
|||||||
We have prebuilt wheels with CUDA for Linux for PyTorch 1.11.0, for each of the supported CUDA versions,
|
We have prebuilt wheels with CUDA for Linux for PyTorch 1.11.0, for each of the supported CUDA versions,
|
||||||
for Python 3.7, 3.8 and 3.9. This is for ease of use on Google Colab.
|
for Python 3.7, 3.8 and 3.9. This is for ease of use on Google Colab.
|
||||||
These are installed in a special way.
|
These are installed in a special way.
|
||||||
For example, to install for Python 3.8, PyTorch 1.11.0 and CUDA 10.2
|
For example, to install for Python 3.8, PyTorch 1.11.0 and CUDA 11.3
|
||||||
```
|
```
|
||||||
pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu102_pyt1110/download.html
|
pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1110/download.html
|
||||||
```
|
```
|
||||||
|
|
||||||
In general, from inside IPython, or in Google Colab or a jupyter notebook, you can install with
|
In general, from inside IPython, or in Google Colab or a jupyter notebook, you can install with
|
||||||
@@ -147,10 +147,10 @@ After any necessary patching, you can go to "x64 Native Tools Command Prompt for
|
|||||||
cd pytorch3d
|
cd pytorch3d
|
||||||
python3 setup.py install
|
python3 setup.py install
|
||||||
```
|
```
|
||||||
After installing, verify whether all unit tests have passed
|
|
||||||
|
After installing, you can run **unit tests**
|
||||||
```
|
```
|
||||||
cd tests
|
python3 -m unittest discover -v -s tests -t .
|
||||||
python3 -m unittest discover -p *.py
|
|
||||||
```
|
```
|
||||||
|
|
||||||
# FAQ
|
# FAQ
|
||||||
|
|||||||
@@ -46,3 +46,26 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
SOFTWARE.
|
SOFTWARE.
|
||||||
|
|
||||||
|
|
||||||
|
NeRF https://github.com/bmild/nerf/
|
||||||
|
|
||||||
|
Copyright (c) 2020 bmild
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|||||||
@@ -7,22 +7,10 @@
|
|||||||
|
|
||||||
# Run this script at project root by "./dev/linter.sh" before you commit
|
# Run this script at project root by "./dev/linter.sh" before you commit
|
||||||
|
|
||||||
{
|
|
||||||
V=$(black --version|cut '-d ' -f3)
|
|
||||||
code='import distutils.version; assert "19.3" < distutils.version.LooseVersion("'$V'")'
|
|
||||||
PYTHON=false
|
|
||||||
command -v python > /dev/null && PYTHON=python
|
|
||||||
command -v python3 > /dev/null && PYTHON=python3
|
|
||||||
${PYTHON} -c "${code}" 2> /dev/null
|
|
||||||
} || {
|
|
||||||
echo "Linter requires black 19.3b0 or higher!"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||||
DIR=$(dirname "${DIR}")
|
DIR=$(dirname "${DIR}")
|
||||||
|
|
||||||
if [[ -f "${DIR}/tests/TARGETS" ]]
|
if [[ -f "${DIR}/TARGETS" ]]
|
||||||
then
|
then
|
||||||
pyfmt "${DIR}"
|
pyfmt "${DIR}"
|
||||||
else
|
else
|
||||||
@@ -42,7 +30,7 @@ clangformat=$(command -v clang-format-8 || echo clang-format)
|
|||||||
find "${DIR}" -regex ".*\.\(cpp\|c\|cc\|cu\|cuh\|cxx\|h\|hh\|hpp\|hxx\|tcc\|mm\|m\)" -print0 | xargs -0 "${clangformat}" -i
|
find "${DIR}" -regex ".*\.\(cpp\|c\|cc\|cu\|cuh\|cxx\|h\|hh\|hpp\|hxx\|tcc\|mm\|m\)" -print0 | xargs -0 "${clangformat}" -i
|
||||||
|
|
||||||
# Run arc and pyre internally only.
|
# Run arc and pyre internally only.
|
||||||
if [[ -f "${DIR}/tests/TARGETS" ]]
|
if [[ -f "${DIR}/TARGETS" ]]
|
||||||
then
|
then
|
||||||
(cd "${DIR}"; command -v arc > /dev/null && arc lint) || true
|
(cd "${DIR}"; command -v arc > /dev/null && arc lint) || true
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ def tests_from_file(path: Path, base: str) -> List[str]:
|
|||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
files = get_test_files()
|
files = get_test_files()
|
||||||
test_root = Path(__file__).parent.parent / "tests"
|
test_root = Path(__file__).parent.parent
|
||||||
all_tests = []
|
all_tests = []
|
||||||
for f in files:
|
for f in files:
|
||||||
file_base = str(f.relative_to(test_root))[:-3].replace("/", ".")
|
file_base = str(f.relative_to(test_root))[:-3].replace("/", ".")
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ API Documentation
|
|||||||
|
|
||||||
.. toctree::
|
.. toctree::
|
||||||
|
|
||||||
common
|
|
||||||
structures
|
structures
|
||||||
io
|
io
|
||||||
loss
|
loss
|
||||||
@@ -12,3 +11,5 @@ API Documentation
|
|||||||
transforms
|
transforms
|
||||||
utils
|
utils
|
||||||
datasets
|
datasets
|
||||||
|
common
|
||||||
|
vis
|
||||||
|
|||||||
6
docs/modules/vis.rst
Normal file
6
docs/modules/vis.rst
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
pytorch3d.vis
|
||||||
|
===========================
|
||||||
|
|
||||||
|
.. automodule:: pytorch3d.vis
|
||||||
|
:members:
|
||||||
|
:undoc-members:
|
||||||
@@ -8,3 +8,4 @@
|
|||||||
sudo docker run --rm -v "$PWD/../../:/inside" pytorch/conda-cuda bash inside/packaging/linux_wheels/inside.sh
|
sudo docker run --rm -v "$PWD/../../:/inside" pytorch/conda-cuda bash inside/packaging/linux_wheels/inside.sh
|
||||||
sudo docker run --rm -v "$PWD/../../:/inside" -e SELECTED_CUDA=cu113 pytorch/conda-builder:cuda113 bash inside/packaging/linux_wheels/inside.sh
|
sudo docker run --rm -v "$PWD/../../:/inside" -e SELECTED_CUDA=cu113 pytorch/conda-builder:cuda113 bash inside/packaging/linux_wheels/inside.sh
|
||||||
sudo docker run --rm -v "$PWD/../../:/inside" -e SELECTED_CUDA=cu115 pytorch/conda-builder:cuda115 bash inside/packaging/linux_wheels/inside.sh
|
sudo docker run --rm -v "$PWD/../../:/inside" -e SELECTED_CUDA=cu115 pytorch/conda-builder:cuda115 bash inside/packaging/linux_wheels/inside.sh
|
||||||
|
sudo docker run --rm -v "$PWD/../../:/inside" -e SELECTED_CUDA=cu116 pytorch/conda-builder:cuda116 bash inside/packaging/linux_wheels/inside.sh
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ do
|
|||||||
|
|
||||||
for cu_version in ${CONDA_CUDA_VERSIONS[$pytorch_version]}
|
for cu_version in ${CONDA_CUDA_VERSIONS[$pytorch_version]}
|
||||||
do
|
do
|
||||||
if [[ "cu113 cu115" == *$cu_version* ]]
|
if [[ "cu113 cu115 cu116" == *$cu_version* ]]
|
||||||
# ^^^ CUDA versions listed here have to be built
|
# ^^^ CUDA versions listed here have to be built
|
||||||
# in their own containers.
|
# in their own containers.
|
||||||
then
|
then
|
||||||
@@ -74,6 +74,11 @@ do
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
case "$cu_version" in
|
case "$cu_version" in
|
||||||
|
cu116)
|
||||||
|
export CUDA_HOME=/usr/local/cuda-11.6/
|
||||||
|
export CUDA_TAG=11.6
|
||||||
|
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_50,code=compute_50"
|
||||||
|
;;
|
||||||
cu115)
|
cu115)
|
||||||
export CUDA_HOME=/usr/local/cuda-11.5/
|
export CUDA_HOME=/usr/local/cuda-11.5/
|
||||||
export CUDA_TAG=11.5
|
export CUDA_TAG=11.5
|
||||||
@@ -124,6 +129,7 @@ do
|
|||||||
|
|
||||||
conda create -y -n "$tag" "python=$python_version"
|
conda create -y -n "$tag" "python=$python_version"
|
||||||
conda activate "$tag"
|
conda activate "$tag"
|
||||||
|
# shellcheck disable=SC2086
|
||||||
conda install -y -c pytorch $extra_channel "pytorch=$pytorch_version" "cudatoolkit=$CUDA_TAG" torchvision
|
conda install -y -c pytorch $extra_channel "pytorch=$pytorch_version" "cudatoolkit=$CUDA_TAG" torchvision
|
||||||
pip install fvcore iopath
|
pip install fvcore iopath
|
||||||
echo "python version" "$python_version" "pytorch version" "$pytorch_version" "cuda version" "$cu_version" "tag" "$tag"
|
echo "python version" "$python_version" "pytorch version" "$pytorch_version" "cuda version" "$cu_version" "tag" "$tag"
|
||||||
|
|||||||
@@ -55,6 +55,17 @@ setup_cuda() {
|
|||||||
|
|
||||||
# Now work out the CUDA settings
|
# Now work out the CUDA settings
|
||||||
case "$CU_VERSION" in
|
case "$CU_VERSION" in
|
||||||
|
cu116)
|
||||||
|
if [[ "$OSTYPE" == "msys" ]]; then
|
||||||
|
export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.6"
|
||||||
|
else
|
||||||
|
export CUDA_HOME=/usr/local/cuda-11.6/
|
||||||
|
fi
|
||||||
|
export FORCE_CUDA=1
|
||||||
|
# Hard-coding gencode flags is temporary situation until
|
||||||
|
# https://github.com/pytorch/pytorch/pull/23408 lands
|
||||||
|
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_50,code=compute_50"
|
||||||
|
;;
|
||||||
cu115)
|
cu115)
|
||||||
if [[ "$OSTYPE" == "msys" ]]; then
|
if [[ "$OSTYPE" == "msys" ]]; then
|
||||||
export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.5"
|
export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.5"
|
||||||
@@ -304,6 +315,9 @@ setup_conda_cudatoolkit_constraint() {
|
|||||||
export CONDA_CUDATOOLKIT_CONSTRAINT=""
|
export CONDA_CUDATOOLKIT_CONSTRAINT=""
|
||||||
else
|
else
|
||||||
case "$CU_VERSION" in
|
case "$CU_VERSION" in
|
||||||
|
cu116)
|
||||||
|
export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.6,<11.7 # [not osx]"
|
||||||
|
;;
|
||||||
cu115)
|
cu115)
|
||||||
export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.5,<11.6 # [not osx]"
|
export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.5,<11.6 # [not osx]"
|
||||||
;;
|
;;
|
||||||
|
|||||||
@@ -45,9 +45,12 @@ test:
|
|||||||
- docs
|
- docs
|
||||||
requires:
|
requires:
|
||||||
- imageio
|
- imageio
|
||||||
|
- hydra-core
|
||||||
|
- accelerate
|
||||||
|
- lpips
|
||||||
commands:
|
commands:
|
||||||
#pytest .
|
#pytest .
|
||||||
python -m unittest discover -v -s tests
|
python -m unittest discover -v -s tests -t .
|
||||||
|
|
||||||
|
|
||||||
about:
|
about:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ Implicitron is a PyTorch3D-based framework for new-view synthesis via modeling t
|
|||||||
# License
|
# License
|
||||||
|
|
||||||
Implicitron is distributed as part of PyTorch3D under the [BSD license](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE).
|
Implicitron is distributed as part of PyTorch3D under the [BSD license](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE).
|
||||||
It includes code from [SRN](http://github.com/vsitzmann/scene-representation-networks) and [IDR](http://github.com/lioryariv/idr) repos.
|
It includes code from the [NeRF](https://github.com/bmild/nerf), [SRN](http://github.com/vsitzmann/scene-representation-networks) and [IDR](http://github.com/lioryariv/idr) repos.
|
||||||
See [LICENSE-3RD-PARTY](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE-3RD-PARTY) for their licenses.
|
See [LICENSE-3RD-PARTY](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE-3RD-PARTY) for their licenses.
|
||||||
|
|
||||||
|
|
||||||
@@ -27,7 +27,7 @@ Only configuration can be changed (see [Configuration system](#configuration-sys
|
|||||||
For this setup, install the dependencies and PyTorch3D from conda following [the guide](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#1-install-with-cuda-support-from-anaconda-cloud-on-linux-only). Then, install implicitron-specific dependencies:
|
For this setup, install the dependencies and PyTorch3D from conda following [the guide](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#1-install-with-cuda-support-from-anaconda-cloud-on-linux-only). Then, install implicitron-specific dependencies:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
pip install "hydra-core>=1.1" visdom lpips matplotlib
|
pip install "hydra-core>=1.1" visdom lpips matplotlib accelerate
|
||||||
```
|
```
|
||||||
|
|
||||||
Runner executable is available as `pytorch3d_implicitron_runner` shell command.
|
Runner executable is available as `pytorch3d_implicitron_runner` shell command.
|
||||||
@@ -49,7 +49,7 @@ Please follow the instructions to [install PyTorch3D from a local clone](https:/
|
|||||||
Then, install Implicitron-specific dependencies:
|
Then, install Implicitron-specific dependencies:
|
||||||
|
|
||||||
```shell
|
```shell
|
||||||
pip install "hydra-core>=1.1" visdom lpips matplotlib
|
pip install "hydra-core>=1.1" visdom lpips matplotlib accelerate
|
||||||
```
|
```
|
||||||
|
|
||||||
You are still encouraged to implement custom plugins as above where possible as it makes reusing the code easier.
|
You are still encouraged to implement custom plugins as above where possible as it makes reusing the code easier.
|
||||||
@@ -66,7 +66,8 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace
|
|||||||
To run training, pass a yaml config file, followed by a list of overridden arguments.
|
To run training, pass a yaml config file, followed by a list of overridden arguments.
|
||||||
For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
|
For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
|
||||||
```shell
|
```shell
|
||||||
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
|
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||||
|
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
|
||||||
```
|
```
|
||||||
|
|
||||||
Here, `--config-path` points to the config path relative to `pytorch3d_implicitron_runner` location;
|
Here, `--config-path` points to the config path relative to `pytorch3d_implicitron_runner` location;
|
||||||
@@ -84,7 +85,8 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad
|
|||||||
|
|
||||||
E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
|
E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
|
||||||
```shell
|
```shell
|
||||||
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<CO3D_DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
|
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||||
|
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
|
||||||
```
|
```
|
||||||
Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`.
|
Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`.
|
||||||
|
|
||||||
@@ -202,7 +204,7 @@ to replace the implementation and potentially override the parameters.
|
|||||||
# Code and config structure
|
# Code and config structure
|
||||||
|
|
||||||
As per above, the config structure is parsed automatically from the module hierarchy.
|
As per above, the config structure is parsed automatically from the module hierarchy.
|
||||||
In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `dataset_args` node.
|
In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `data_source_args` node.
|
||||||
|
|
||||||
Here is the class structure (single-line edges show aggregation, while double lines show available implementations):
|
Here is the class structure (single-line edges show aggregation, while double lines show available implementations):
|
||||||
```
|
```
|
||||||
@@ -224,7 +226,8 @@ generic_model_args: GenericModel
|
|||||||
└-- hypernet_args: SRNRaymarchHyperNet
|
└-- hypernet_args: SRNRaymarchHyperNet
|
||||||
└-- pixel_generator_args: SRNPixelGenerator
|
└-- pixel_generator_args: SRNPixelGenerator
|
||||||
╘== IdrFeatureField
|
╘== IdrFeatureField
|
||||||
└-- image_feature_extractor_args: ResNetFeatureExtractor
|
└-- image_feature_extractor_*_args: FeatureExtractorBase
|
||||||
|
╘== ResNetFeatureExtractor
|
||||||
└-- view_sampler_args: ViewSampler
|
└-- view_sampler_args: ViewSampler
|
||||||
└-- feature_aggregator_*_args: FeatureAggregatorBase
|
└-- feature_aggregator_*_args: FeatureAggregatorBase
|
||||||
╘== IdentityFeatureAggregator
|
╘== IdentityFeatureAggregator
|
||||||
@@ -232,8 +235,9 @@ generic_model_args: GenericModel
|
|||||||
╘== AngleWeightedReductionFeatureAggregator
|
╘== AngleWeightedReductionFeatureAggregator
|
||||||
╘== ReductionFeatureAggregator
|
╘== ReductionFeatureAggregator
|
||||||
solver_args: init_optimizer
|
solver_args: init_optimizer
|
||||||
dataset_args: dataset_zoo
|
data_source_args: ImplicitronDataSource
|
||||||
dataloader_args: dataloader_zoo
|
└-- dataset_map_provider_*_args
|
||||||
|
└-- data_loader_map_provider_*_args
|
||||||
```
|
```
|
||||||
|
|
||||||
Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
|
Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
|
||||||
|
|||||||
@@ -5,29 +5,22 @@ exp_dir: ./data/exps/base/
|
|||||||
architecture: generic
|
architecture: generic
|
||||||
visualize_interval: 0
|
visualize_interval: 0
|
||||||
visdom_port: 8097
|
visdom_port: 8097
|
||||||
dataloader_args:
|
data_source_args:
|
||||||
batch_size: 10
|
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
|
||||||
dataset_len: 1000
|
dataset_map_provider_class_type: JsonIndexDatasetMapProvider
|
||||||
dataset_len_val: 1
|
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||||
|
dataset_length_train: 1000
|
||||||
|
dataset_length_val: 1
|
||||||
num_workers: 8
|
num_workers: 8
|
||||||
images_per_seq_options:
|
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||||
- 2
|
dataset_root: ${oc.env:CO3D_DATASET_ROOT}
|
||||||
- 3
|
|
||||||
- 4
|
|
||||||
- 5
|
|
||||||
- 6
|
|
||||||
- 7
|
|
||||||
- 8
|
|
||||||
- 9
|
|
||||||
- 10
|
|
||||||
dataset_args:
|
|
||||||
dataset_root: ${oc.env:CO3D_DATASET_ROOT}"
|
|
||||||
load_point_clouds: false
|
|
||||||
mask_depths: false
|
|
||||||
mask_images: false
|
|
||||||
n_frames_per_sequence: -1
|
n_frames_per_sequence: -1
|
||||||
test_on_train: true
|
test_on_train: true
|
||||||
test_restrict_sequence_id: 0
|
test_restrict_sequence_id: 0
|
||||||
|
dataset_JsonIndexDataset_args:
|
||||||
|
load_point_clouds: false
|
||||||
|
mask_depths: false
|
||||||
|
mask_images: false
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
loss_weights:
|
loss_weights:
|
||||||
loss_mask_bce: 1.0
|
loss_mask_bce: 1.0
|
||||||
@@ -49,10 +42,8 @@ generic_model_args:
|
|||||||
append_xyz:
|
append_xyz:
|
||||||
- 5
|
- 5
|
||||||
latent_dim: 0
|
latent_dim: 0
|
||||||
raysampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 1024
|
n_rays_per_image_sampled_from_mask: 1024
|
||||||
min_depth: 0.0
|
|
||||||
max_depth: 0.0
|
|
||||||
scene_extent: 8.0
|
scene_extent: 8.0
|
||||||
n_pts_per_ray_training: 64
|
n_pts_per_ray_training: 64
|
||||||
n_pts_per_ray_evaluation: 64
|
n_pts_per_ray_evaluation: 64
|
||||||
@@ -63,9 +54,10 @@ generic_model_args:
|
|||||||
n_pts_per_ray_fine_evaluation: 64
|
n_pts_per_ray_fine_evaluation: 64
|
||||||
append_coarse_samples_to_fine: true
|
append_coarse_samples_to_fine: true
|
||||||
density_noise_std_train: 1.0
|
density_noise_std_train: 1.0
|
||||||
|
view_pooler_args:
|
||||||
view_sampler_args:
|
view_sampler_args:
|
||||||
masked_sampling: false
|
masked_sampling: false
|
||||||
image_feature_extractor_args:
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
stages:
|
stages:
|
||||||
- 1
|
- 1
|
||||||
- 2
|
- 2
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
image_feature_extractor_args:
|
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||||
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
add_images: true
|
add_images: true
|
||||||
add_masks: true
|
add_masks: true
|
||||||
first_max_pool: true
|
first_max_pool: true
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
image_feature_extractor_args:
|
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||||
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
add_images: true
|
add_images: true
|
||||||
add_masks: true
|
add_masks: true
|
||||||
first_max_pool: false
|
first_max_pool: false
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
image_feature_extractor_args:
|
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||||
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
stages:
|
stages:
|
||||||
- 1
|
- 1
|
||||||
- 2
|
- 2
|
||||||
@@ -11,6 +12,7 @@ generic_model_args:
|
|||||||
name: resnet34
|
name: resnet34
|
||||||
normalize_image: true
|
normalize_image: true
|
||||||
pretrained: true
|
pretrained: true
|
||||||
|
view_pooler_args:
|
||||||
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
|
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
|
||||||
reduction_functions:
|
reduction_functions:
|
||||||
- AVG
|
- AVG
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_base.yaml
|
- repro_base.yaml
|
||||||
- _self_
|
- _self_
|
||||||
dataloader_args:
|
data_source_args:
|
||||||
|
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||||
batch_size: 10
|
batch_size: 10
|
||||||
dataset_len: 1000
|
dataset_length_train: 1000
|
||||||
dataset_len_val: 1
|
dataset_length_val: 1
|
||||||
num_workers: 8
|
num_workers: 8
|
||||||
|
train_conditioning_type: SAME
|
||||||
|
val_conditioning_type: SAME
|
||||||
|
test_conditioning_type: SAME
|
||||||
images_per_seq_options:
|
images_per_seq_options:
|
||||||
- 2
|
- 2
|
||||||
- 3
|
- 3
|
||||||
@@ -16,12 +20,9 @@ dataloader_args:
|
|||||||
- 8
|
- 8
|
||||||
- 9
|
- 9
|
||||||
- 10
|
- 10
|
||||||
dataset_args:
|
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||||
assert_single_seq: false
|
assert_single_seq: false
|
||||||
dataset_name: co3d_multisequence
|
task_str: multisequence
|
||||||
load_point_clouds: false
|
|
||||||
mask_depths: false
|
|
||||||
mask_images: false
|
|
||||||
n_frames_per_sequence: -1
|
n_frames_per_sequence: -1
|
||||||
test_on_train: true
|
test_on_train: true
|
||||||
test_restrict_sequence_id: 0
|
test_restrict_sequence_id: 0
|
||||||
@@ -29,3 +30,6 @@ solver_args:
|
|||||||
max_epochs: 3000
|
max_epochs: 3000
|
||||||
milestones:
|
milestones:
|
||||||
- 1000
|
- 1000
|
||||||
|
camera_difficulty_bin_breaks:
|
||||||
|
- 0.666667
|
||||||
|
- 0.833334
|
||||||
|
|||||||
@@ -11,8 +11,9 @@ generic_model_args:
|
|||||||
num_passes: 1
|
num_passes: 1
|
||||||
output_rasterized_mc: true
|
output_rasterized_mc: true
|
||||||
sampling_mode_training: mask_sample
|
sampling_mode_training: mask_sample
|
||||||
view_pool: false
|
global_encoder_class_type: SequenceAutodecoder
|
||||||
sequence_autodecoder_args:
|
global_encoder_SequenceAutodecoder_args:
|
||||||
|
autodecoder_args:
|
||||||
n_instances: 20000
|
n_instances: 20000
|
||||||
init_scale: 1.0
|
init_scale: 1.0
|
||||||
encoding_dim: 256
|
encoding_dim: 256
|
||||||
@@ -55,7 +56,7 @@ generic_model_args:
|
|||||||
n_harmonic_functions_dir: 4
|
n_harmonic_functions_dir: 4
|
||||||
pooled_feature_dim: 0
|
pooled_feature_dim: 0
|
||||||
weight_norm: true
|
weight_norm: true
|
||||||
raysampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 1024
|
n_rays_per_image_sampled_from_mask: 1024
|
||||||
n_pts_per_ray_training: 0
|
n_pts_per_ray_training: 0
|
||||||
n_pts_per_ray_evaluation: 0
|
n_pts_per_ray_evaluation: 0
|
||||||
|
|||||||
@@ -3,7 +3,9 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pool: false
|
view_pooler_enabled: false
|
||||||
sequence_autodecoder_args:
|
global_encoder_class_type: SequenceAutodecoder
|
||||||
|
global_encoder_SequenceAutodecoder_args:
|
||||||
|
autodecoder_args:
|
||||||
n_instances: 20000
|
n_instances: 20000
|
||||||
encoding_dim: 256
|
encoding_dim: 256
|
||||||
|
|||||||
@@ -5,6 +5,6 @@ defaults:
|
|||||||
clip_grad: 1.0
|
clip_grad: 1.0
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pool: true
|
view_pooler_enabled: true
|
||||||
raysampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 850
|
n_rays_per_image_sampled_from_mask: 850
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pool: true
|
raysampler_AdaptiveRaySampler_args:
|
||||||
raysampler_args:
|
|
||||||
n_rays_per_image_sampled_from_mask: 800
|
n_rays_per_image_sampled_from_mask: 800
|
||||||
n_pts_per_ray_training: 32
|
n_pts_per_ray_training: 32
|
||||||
n_pts_per_ray_evaluation: 32
|
n_pts_per_ray_evaluation: 32
|
||||||
@@ -13,4 +12,6 @@ generic_model_args:
|
|||||||
n_pts_per_ray_fine_training: 16
|
n_pts_per_ray_fine_training: 16
|
||||||
n_pts_per_ray_fine_evaluation: 16
|
n_pts_per_ray_fine_evaluation: 16
|
||||||
implicit_function_class_type: NeRFormerImplicitFunction
|
implicit_function_class_type: NeRFormerImplicitFunction
|
||||||
|
view_pooler_enabled: true
|
||||||
|
view_pooler_args:
|
||||||
feature_aggregator_class_type: IdentityFeatureAggregator
|
feature_aggregator_class_type: IdentityFeatureAggregator
|
||||||
|
|||||||
@@ -1,16 +1,6 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_multiseq_base.yaml
|
- repro_multiseq_nerformer.yaml
|
||||||
- repro_feat_extractor_transformer.yaml
|
|
||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
view_pooler_args:
|
||||||
view_pool: true
|
|
||||||
raysampler_args:
|
|
||||||
n_rays_per_image_sampled_from_mask: 800
|
|
||||||
n_pts_per_ray_training: 32
|
|
||||||
n_pts_per_ray_evaluation: 32
|
|
||||||
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
|
||||||
n_pts_per_ray_fine_training: 16
|
|
||||||
n_pts_per_ray_fine_evaluation: 16
|
|
||||||
implicit_function_class_type: NeRFormerImplicitFunction
|
|
||||||
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
|
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pool: false
|
view_pooler_enabled: false
|
||||||
n_train_target_views: -1
|
n_train_target_views: -1
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
loss_weights:
|
loss_weights:
|
||||||
@@ -13,14 +13,16 @@ generic_model_args:
|
|||||||
loss_prev_stage_mask_bce: 0.0
|
loss_prev_stage_mask_bce: 0.0
|
||||||
loss_autodecoder_norm: 0.001
|
loss_autodecoder_norm: 0.001
|
||||||
depth_neg_penalty: 10000.0
|
depth_neg_penalty: 10000.0
|
||||||
sequence_autodecoder_args:
|
global_encoder_class_type: SequenceAutodecoder
|
||||||
|
global_encoder_SequenceAutodecoder_args:
|
||||||
|
autodecoder_args:
|
||||||
encoding_dim: 256
|
encoding_dim: 256
|
||||||
n_instances: 20000
|
n_instances: 20000
|
||||||
raysampler_args:
|
raysampler_class_type: NearFarRaySampler
|
||||||
|
raysampler_NearFarRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 2048
|
n_rays_per_image_sampled_from_mask: 2048
|
||||||
min_depth: 0.05
|
min_depth: 0.05
|
||||||
max_depth: 0.05
|
max_depth: 0.05
|
||||||
scene_extent: 0.0
|
|
||||||
n_pts_per_ray_training: 1
|
n_pts_per_ray_training: 1
|
||||||
n_pts_per_ray_evaluation: 1
|
n_pts_per_ray_evaluation: 1
|
||||||
stratified_point_sampling_training: false
|
stratified_point_sampling_training: false
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 32000
|
chunk_size_grid: 32000
|
||||||
view_pool: true
|
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
n_train_target_views: -1
|
n_train_target_views: -1
|
||||||
loss_weights:
|
loss_weights:
|
||||||
@@ -14,17 +13,18 @@ generic_model_args:
|
|||||||
loss_prev_stage_mask_bce: 0.0
|
loss_prev_stage_mask_bce: 0.0
|
||||||
loss_autodecoder_norm: 0.0
|
loss_autodecoder_norm: 0.0
|
||||||
depth_neg_penalty: 10000.0
|
depth_neg_penalty: 10000.0
|
||||||
raysampler_args:
|
raysampler_class_type: NearFarRaySampler
|
||||||
|
raysampler_NearFarRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 2048
|
n_rays_per_image_sampled_from_mask: 2048
|
||||||
min_depth: 0.05
|
min_depth: 0.05
|
||||||
max_depth: 0.05
|
max_depth: 0.05
|
||||||
scene_extent: 0.0
|
|
||||||
n_pts_per_ray_training: 1
|
n_pts_per_ray_training: 1
|
||||||
n_pts_per_ray_evaluation: 1
|
n_pts_per_ray_evaluation: 1
|
||||||
stratified_point_sampling_training: false
|
stratified_point_sampling_training: false
|
||||||
stratified_point_sampling_evaluation: false
|
stratified_point_sampling_evaluation: false
|
||||||
renderer_class_type: LSTMRenderer
|
renderer_class_type: LSTMRenderer
|
||||||
implicit_function_class_type: SRNImplicitFunction
|
implicit_function_class_type: SRNImplicitFunction
|
||||||
|
view_pooler_enabled: true
|
||||||
solver_args:
|
solver_args:
|
||||||
breed: adam
|
breed: adam
|
||||||
lr: 5.0e-05
|
lr: 5.0e-05
|
||||||
|
|||||||
@@ -1,15 +1,13 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_base
|
- repro_base
|
||||||
- _self_
|
- _self_
|
||||||
dataloader_args:
|
data_source_args:
|
||||||
|
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||||
batch_size: 1
|
batch_size: 1
|
||||||
dataset_len: 1000
|
dataset_length_train: 1000
|
||||||
dataset_len_val: 1
|
dataset_length_val: 1
|
||||||
num_workers: 8
|
num_workers: 8
|
||||||
images_per_seq_options:
|
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||||
- 2
|
|
||||||
dataset_args:
|
|
||||||
dataset_name: co3d_singlesequence
|
|
||||||
assert_single_seq: true
|
assert_single_seq: true
|
||||||
n_frames_per_sequence: -1
|
n_frames_per_sequence: -1
|
||||||
test_restrict_sequence_id: 0
|
test_restrict_sequence_id: 0
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ generic_model_args:
|
|||||||
loss_eikonal: 0.1
|
loss_eikonal: 0.1
|
||||||
chunk_size_grid: 65536
|
chunk_size_grid: 65536
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
view_pool: false
|
view_pooler_enabled: false
|
||||||
implicit_function_IdrFeatureField_args:
|
implicit_function_IdrFeatureField_args:
|
||||||
n_harmonic_functions_xyz: 6
|
n_harmonic_functions_xyz: 6
|
||||||
bias: 0.6
|
bias: 0.6
|
||||||
@@ -49,7 +49,7 @@ generic_model_args:
|
|||||||
n_harmonic_functions_dir: 4
|
n_harmonic_functions_dir: 4
|
||||||
pooled_feature_dim: 0
|
pooled_feature_dim: 0
|
||||||
weight_norm: true
|
weight_norm: true
|
||||||
raysampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 1024
|
n_rays_per_image_sampled_from_mask: 1024
|
||||||
n_pts_per_ray_training: 0
|
n_pts_per_ray_training: 0
|
||||||
n_pts_per_ray_evaluation: 0
|
n_pts_per_ray_evaluation: 0
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_singleseq_base
|
- repro_singleseq_base
|
||||||
- _self_
|
- _self_
|
||||||
exp_dir: ./data/nerf_single_apple/
|
|
||||||
|
|||||||
@@ -4,6 +4,6 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pool: true
|
view_pooler_enabled: true
|
||||||
raysampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 850
|
n_rays_per_image_sampled_from_mask: 850
|
||||||
|
|||||||
@@ -4,13 +4,14 @@ defaults:
|
|||||||
- _self_
|
- _self_
|
||||||
generic_model_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pool: true
|
view_pooler_enabled: true
|
||||||
implicit_function_class_type: NeRFormerImplicitFunction
|
implicit_function_class_type: NeRFormerImplicitFunction
|
||||||
raysampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 800
|
n_rays_per_image_sampled_from_mask: 800
|
||||||
n_pts_per_ray_training: 32
|
n_pts_per_ray_training: 32
|
||||||
n_pts_per_ray_evaluation: 32
|
n_pts_per_ray_evaluation: 32
|
||||||
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
||||||
n_pts_per_ray_fine_training: 16
|
n_pts_per_ray_fine_training: 16
|
||||||
n_pts_per_ray_fine_evaluation: 16
|
n_pts_per_ray_fine_evaluation: 16
|
||||||
|
view_pooler_args:
|
||||||
feature_aggregator_class_type: IdentityFeatureAggregator
|
feature_aggregator_class_type: IdentityFeatureAggregator
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ defaults:
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
chunk_size_grid: 32000
|
chunk_size_grid: 32000
|
||||||
view_pool: false
|
view_pooler_enabled: false
|
||||||
loss_weights:
|
loss_weights:
|
||||||
loss_rgb_mse: 200.0
|
loss_rgb_mse: 200.0
|
||||||
loss_prev_stage_rgb_mse: 0.0
|
loss_prev_stage_rgb_mse: 0.0
|
||||||
@@ -12,11 +12,11 @@ generic_model_args:
|
|||||||
loss_prev_stage_mask_bce: 0.0
|
loss_prev_stage_mask_bce: 0.0
|
||||||
loss_autodecoder_norm: 0.0
|
loss_autodecoder_norm: 0.0
|
||||||
depth_neg_penalty: 10000.0
|
depth_neg_penalty: 10000.0
|
||||||
raysampler_args:
|
raysampler_class_type: NearFarRaySampler
|
||||||
|
raysampler_NearFarRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 2048
|
n_rays_per_image_sampled_from_mask: 2048
|
||||||
min_depth: 0.05
|
min_depth: 0.05
|
||||||
max_depth: 0.05
|
max_depth: 0.05
|
||||||
scene_extent: 0.0
|
|
||||||
n_pts_per_ray_training: 1
|
n_pts_per_ray_training: 1
|
||||||
n_pts_per_ray_evaluation: 1
|
n_pts_per_ray_evaluation: 1
|
||||||
stratified_point_sampling_training: false
|
stratified_point_sampling_training: false
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ defaults:
|
|||||||
generic_model_args:
|
generic_model_args:
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
chunk_size_grid: 32000
|
chunk_size_grid: 32000
|
||||||
view_pool: true
|
view_pooler_enabled: true
|
||||||
loss_weights:
|
loss_weights:
|
||||||
loss_rgb_mse: 200.0
|
loss_rgb_mse: 200.0
|
||||||
loss_prev_stage_rgb_mse: 0.0
|
loss_prev_stage_rgb_mse: 0.0
|
||||||
@@ -13,11 +13,11 @@ generic_model_args:
|
|||||||
loss_prev_stage_mask_bce: 0.0
|
loss_prev_stage_mask_bce: 0.0
|
||||||
loss_autodecoder_norm: 0.0
|
loss_autodecoder_norm: 0.0
|
||||||
depth_neg_penalty: 10000.0
|
depth_neg_penalty: 10000.0
|
||||||
raysampler_args:
|
raysampler_class_type: NearFarRaySampler
|
||||||
|
raysampler_NearFarRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 2048
|
n_rays_per_image_sampled_from_mask: 2048
|
||||||
min_depth: 0.05
|
min_depth: 0.05
|
||||||
max_depth: 0.05
|
max_depth: 0.05
|
||||||
scene_extent: 0.0
|
|
||||||
n_pts_per_ray_training: 1
|
n_pts_per_ray_training: 1
|
||||||
n_pts_per_ray_evaluation: 1
|
n_pts_per_ray_evaluation: 1
|
||||||
stratified_point_sampling_training: false
|
stratified_point_sampling_training: false
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_singleseq_base
|
- repro_singleseq_base
|
||||||
- _self_
|
- _self_
|
||||||
dataloader_args:
|
data_source_args:
|
||||||
|
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||||
batch_size: 10
|
batch_size: 10
|
||||||
dataset_len: 1000
|
dataset_length_train: 1000
|
||||||
dataset_len_val: 1
|
dataset_length_val: 1
|
||||||
num_workers: 8
|
num_workers: 8
|
||||||
|
train_conditioning_type: SAME
|
||||||
|
val_conditioning_type: SAME
|
||||||
|
test_conditioning_type: SAME
|
||||||
images_per_seq_options:
|
images_per_seq_options:
|
||||||
- 2
|
- 2
|
||||||
- 3
|
- 3
|
||||||
|
|||||||
@@ -45,7 +45,6 @@ The outputs of the experiment are saved and logged in multiple ways:
|
|||||||
config file.
|
config file.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -53,7 +52,6 @@ import os
|
|||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
@@ -61,26 +59,29 @@ import lpips
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
|
from accelerate import Accelerator
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from pytorch3d.implicitron.dataset import utils as ds_utils
|
from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||||
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
|
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
|
||||||
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
|
||||||
FrameData,
|
|
||||||
ImplicitronDataset,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||||
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
|
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
|
||||||
|
from pytorch3d.implicitron.models.renderer.multipass_ea import (
|
||||||
|
MultiPassEmissionAbsorptionRenderer,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.models.renderer.ray_sampler import AdaptiveRaySampler
|
||||||
from pytorch3d.implicitron.tools import model_io, vis_utils
|
from pytorch3d.implicitron.tools import model_io, vis_utils
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
enable_get_default_args,
|
expand_args_fields,
|
||||||
get_default_args_field,
|
|
||||||
remove_unused_components,
|
remove_unused_components,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.tools.stats import Stats
|
from pytorch3d.implicitron.tools.stats import Stats
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
|
from .impl.experiment_config import ExperimentConfig
|
||||||
|
from .impl.optimization import init_optimizer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -96,9 +97,13 @@ try:
|
|||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None
|
||||||
|
|
||||||
|
|
||||||
def init_model(
|
def init_model(
|
||||||
|
*,
|
||||||
cfg: DictConfig,
|
cfg: DictConfig,
|
||||||
|
accelerator: Optional[Accelerator] = None,
|
||||||
force_load: bool = False,
|
force_load: bool = False,
|
||||||
clear_stats: bool = False,
|
clear_stats: bool = False,
|
||||||
load_model_only: bool = False,
|
load_model_only: bool = False,
|
||||||
@@ -162,12 +167,20 @@ def init_model(
|
|||||||
logger.info("found previous model %s" % model_path)
|
logger.info("found previous model %s" % model_path)
|
||||||
if force_load or cfg.resume:
|
if force_load or cfg.resume:
|
||||||
logger.info(" -> resuming")
|
logger.info(" -> resuming")
|
||||||
|
|
||||||
|
map_location = None
|
||||||
|
if accelerator is not None and not accelerator.is_local_main_process:
|
||||||
|
map_location = {
|
||||||
|
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
||||||
|
}
|
||||||
if load_model_only:
|
if load_model_only:
|
||||||
model_state_dict = torch.load(model_io.get_model_path(model_path))
|
model_state_dict = torch.load(
|
||||||
|
model_io.get_model_path(model_path), map_location=map_location
|
||||||
|
)
|
||||||
stats_load, optimizer_state = None, None
|
stats_load, optimizer_state = None, None
|
||||||
else:
|
else:
|
||||||
model_state_dict, stats_load, optimizer_state = model_io.load_model(
|
model_state_dict, stats_load, optimizer_state = model_io.load_model(
|
||||||
model_path
|
model_path, map_location=map_location
|
||||||
)
|
)
|
||||||
|
|
||||||
# Determine if stats should be reset
|
# Determine if stats should be reset
|
||||||
@@ -211,116 +224,21 @@ def init_model(
|
|||||||
return model, stats, optimizer_state
|
return model, stats, optimizer_state
|
||||||
|
|
||||||
|
|
||||||
def init_optimizer(
|
|
||||||
model: GenericModel,
|
|
||||||
optimizer_state: Optional[Dict[str, Any]],
|
|
||||||
last_epoch: int,
|
|
||||||
breed: str = "adam",
|
|
||||||
weight_decay: float = 0.0,
|
|
||||||
lr_policy: str = "multistep",
|
|
||||||
lr: float = 0.0005,
|
|
||||||
gamma: float = 0.1,
|
|
||||||
momentum: float = 0.9,
|
|
||||||
betas: Tuple[float] = (0.9, 0.999),
|
|
||||||
milestones: tuple = (),
|
|
||||||
max_epochs: int = 1000,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the optimizer (optionally from checkpoint state)
|
|
||||||
and the learning rate scheduler.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: The model with optionally loaded weights
|
|
||||||
optimizer_state: The state dict for the optimizer. If None
|
|
||||||
it has not been loaded from checkpoint
|
|
||||||
last_epoch: If the model was loaded from checkpoint this will be the
|
|
||||||
number of the last epoch that was saved
|
|
||||||
breed: The type of optimizer to use e.g. adam
|
|
||||||
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
|
|
||||||
lr_policy: The policy to use for learning rate. Currently, only "multistep:
|
|
||||||
is supported.
|
|
||||||
lr: The value for the initial learning rate
|
|
||||||
gamma: Multiplicative factor of learning rate decay
|
|
||||||
momentum: Momentum factor for SGD optimizer
|
|
||||||
betas: Coefficients used for computing running averages of gradient and its square
|
|
||||||
in the Adam optimizer
|
|
||||||
milestones: List of increasing epoch indices at which the learning rate is
|
|
||||||
modified
|
|
||||||
max_epochs: The maximum number of epochs to run the optimizer for
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
optimizer: Optimizer module, optionally loaded from checkpoint
|
|
||||||
scheduler: Learning rate scheduler module
|
|
||||||
|
|
||||||
Raise:
|
|
||||||
ValueError if `breed` or `lr_policy` are not supported.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Get the parameters to optimize
|
|
||||||
if hasattr(model, "_get_param_groups"): # use the model function
|
|
||||||
p_groups = model._get_param_groups(lr, wd=weight_decay)
|
|
||||||
else:
|
|
||||||
allprm = [prm for prm in model.parameters() if prm.requires_grad]
|
|
||||||
p_groups = [{"params": allprm, "lr": lr}]
|
|
||||||
|
|
||||||
# Intialize the optimizer
|
|
||||||
if breed == "sgd":
|
|
||||||
optimizer = torch.optim.SGD(
|
|
||||||
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
|
|
||||||
)
|
|
||||||
elif breed == "adagrad":
|
|
||||||
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
|
|
||||||
elif breed == "adam":
|
|
||||||
optimizer = torch.optim.Adam(
|
|
||||||
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("no such solver type %s" % breed)
|
|
||||||
logger.info(" -> solver type = %s" % breed)
|
|
||||||
|
|
||||||
# Load state from checkpoint
|
|
||||||
if optimizer_state is not None:
|
|
||||||
logger.info(" -> setting loaded optimizer state")
|
|
||||||
optimizer.load_state_dict(optimizer_state)
|
|
||||||
|
|
||||||
# Initialize the learning rate scheduler
|
|
||||||
if lr_policy == "multistep":
|
|
||||||
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
|
||||||
optimizer,
|
|
||||||
milestones=milestones,
|
|
||||||
gamma=gamma,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise ValueError("no such lr policy %s" % lr_policy)
|
|
||||||
|
|
||||||
# When loading from checkpoint, this will make sure that the
|
|
||||||
# lr is correctly set even after returning
|
|
||||||
for _ in range(last_epoch):
|
|
||||||
scheduler.step()
|
|
||||||
|
|
||||||
# Add the max epochs here
|
|
||||||
scheduler.max_epochs = max_epochs
|
|
||||||
|
|
||||||
optimizer.zero_grad()
|
|
||||||
return optimizer, scheduler
|
|
||||||
|
|
||||||
|
|
||||||
enable_get_default_args(init_optimizer)
|
|
||||||
|
|
||||||
|
|
||||||
def trainvalidate(
|
def trainvalidate(
|
||||||
model,
|
model,
|
||||||
stats,
|
stats,
|
||||||
epoch,
|
epoch,
|
||||||
loader,
|
loader,
|
||||||
optimizer,
|
optimizer,
|
||||||
validation,
|
validation: bool,
|
||||||
|
*,
|
||||||
|
accelerator: Optional[Accelerator],
|
||||||
|
device: torch.device,
|
||||||
bp_var: str = "objective",
|
bp_var: str = "objective",
|
||||||
metric_print_interval: int = 5,
|
metric_print_interval: int = 5,
|
||||||
visualize_interval: int = 100,
|
visualize_interval: int = 100,
|
||||||
visdom_env_root: str = "trainvalidate",
|
visdom_env_root: str = "trainvalidate",
|
||||||
clip_grad: float = 0.0,
|
clip_grad: float = 0.0,
|
||||||
device: str = "cuda:0",
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
@@ -368,11 +286,11 @@ def trainvalidate(
|
|||||||
|
|
||||||
# Iterate through the batches
|
# Iterate through the batches
|
||||||
n_batches = len(loader)
|
n_batches = len(loader)
|
||||||
for it, batch in enumerate(loader):
|
for it, net_input in enumerate(loader):
|
||||||
last_iter = it == n_batches - 1
|
last_iter = it == n_batches - 1
|
||||||
|
|
||||||
# move to gpu where possible (in place)
|
# move to gpu where possible (in place)
|
||||||
net_input = batch.to(device)
|
net_input = net_input.to(device)
|
||||||
|
|
||||||
# run the forward pass
|
# run the forward pass
|
||||||
if not validation:
|
if not validation:
|
||||||
@@ -398,7 +316,11 @@ def trainvalidate(
|
|||||||
stats.print(stat_set=trainmode, max_it=n_batches)
|
stats.print(stat_set=trainmode, max_it=n_batches)
|
||||||
|
|
||||||
# visualize results
|
# visualize results
|
||||||
if visualize_interval > 0 and it % visualize_interval == 0:
|
if (
|
||||||
|
(accelerator is None or accelerator.is_local_main_process)
|
||||||
|
and visualize_interval > 0
|
||||||
|
and it % visualize_interval == 0
|
||||||
|
):
|
||||||
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
|
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
|
||||||
|
|
||||||
model.visualize(
|
model.visualize(
|
||||||
@@ -413,7 +335,10 @@ def trainvalidate(
|
|||||||
loss = preds[bp_var]
|
loss = preds[bp_var]
|
||||||
assert torch.isfinite(loss).all(), "Non-finite loss!"
|
assert torch.isfinite(loss).all(), "Non-finite loss!"
|
||||||
# backprop
|
# backprop
|
||||||
|
if accelerator is None:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
else:
|
||||||
|
accelerator.backward(loss)
|
||||||
if clip_grad > 0.0:
|
if clip_grad > 0.0:
|
||||||
# Optionally clip the gradient norms.
|
# Optionally clip the gradient norms.
|
||||||
total_norm = torch.nn.utils.clip_grad_norm(
|
total_norm = torch.nn.utils.clip_grad_norm(
|
||||||
@@ -422,18 +347,29 @@ def trainvalidate(
|
|||||||
if total_norm > clip_grad:
|
if total_norm > clip_grad:
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Clipping gradient: {total_norm}"
|
f"Clipping gradient: {total_norm}"
|
||||||
+ f" with coef {clip_grad / total_norm}."
|
+ f" with coef {clip_grad / float(total_norm)}."
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
def run_training(cfg: DictConfig, device: str = "cpu"):
|
def run_training(cfg: DictConfig) -> None:
|
||||||
"""
|
"""
|
||||||
Entry point to run the training and validation loops
|
Entry point to run the training and validation loops
|
||||||
based on the specified config file.
|
based on the specified config file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Initialize the accelerator
|
||||||
|
if no_accelerate:
|
||||||
|
accelerator = None
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
else:
|
||||||
|
accelerator = Accelerator(device_placement=False)
|
||||||
|
logger.info(accelerator.state)
|
||||||
|
device = accelerator.device
|
||||||
|
|
||||||
|
logger.info(f"Running experiment on device: {device}")
|
||||||
|
|
||||||
# set the debug mode
|
# set the debug mode
|
||||||
if cfg.detect_anomaly:
|
if cfg.detect_anomaly:
|
||||||
logger.info("Anomaly detection!")
|
logger.info("Anomaly detection!")
|
||||||
@@ -452,12 +388,12 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
|||||||
warnings.warn("Cant dump config due to insufficient permissions!")
|
warnings.warn("Cant dump config due to insufficient permissions!")
|
||||||
|
|
||||||
# setup datasets
|
# setup datasets
|
||||||
datasets = dataset_zoo(**cfg.dataset_args)
|
datasource = ImplicitronDataSource(**cfg.data_source_args)
|
||||||
cfg.dataloader_args["dataset_name"] = cfg.dataset_args["dataset_name"]
|
datasets, dataloaders = datasource.get_datasets_and_dataloaders()
|
||||||
dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)
|
task = datasource.get_task()
|
||||||
|
|
||||||
# init the model
|
# init the model
|
||||||
model, stats, optimizer_state = init_model(cfg)
|
model, stats, optimizer_state = init_model(cfg=cfg, accelerator=accelerator)
|
||||||
start_epoch = stats.epoch + 1
|
start_epoch = stats.epoch + 1
|
||||||
|
|
||||||
# move model to gpu
|
# move model to gpu
|
||||||
@@ -465,7 +401,16 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
|||||||
|
|
||||||
# only run evaluation on the test dataloader
|
# only run evaluation on the test dataloader
|
||||||
if cfg.eval_only:
|
if cfg.eval_only:
|
||||||
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
|
_eval_and_dump(
|
||||||
|
cfg,
|
||||||
|
task,
|
||||||
|
datasource.all_train_cameras,
|
||||||
|
datasets,
|
||||||
|
dataloaders,
|
||||||
|
model,
|
||||||
|
stats,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
# init the optimizer
|
# init the optimizer
|
||||||
@@ -480,6 +425,19 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
|||||||
assert scheduler.last_epoch == stats.epoch + 1
|
assert scheduler.last_epoch == stats.epoch + 1
|
||||||
assert scheduler.last_epoch == start_epoch
|
assert scheduler.last_epoch == start_epoch
|
||||||
|
|
||||||
|
# Wrap all modules in the distributed library
|
||||||
|
# Note: we don't pass the scheduler to prepare as it
|
||||||
|
# doesn't need to be stepped at each optimizer step
|
||||||
|
train_loader = dataloaders.train
|
||||||
|
val_loader = dataloaders.val
|
||||||
|
if accelerator is not None:
|
||||||
|
(
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
train_loader,
|
||||||
|
val_loader,
|
||||||
|
) = accelerator.prepare(model, optimizer, train_loader, val_loader)
|
||||||
|
|
||||||
past_scheduler_lrs = []
|
past_scheduler_lrs = []
|
||||||
# loop through epochs
|
# loop through epochs
|
||||||
for epoch in range(start_epoch, cfg.solver_args.max_epochs):
|
for epoch in range(start_epoch, cfg.solver_args.max_epochs):
|
||||||
@@ -499,46 +457,62 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
|||||||
model,
|
model,
|
||||||
stats,
|
stats,
|
||||||
epoch,
|
epoch,
|
||||||
dataloaders["train"],
|
train_loader,
|
||||||
optimizer,
|
optimizer,
|
||||||
False,
|
False,
|
||||||
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
||||||
device=device,
|
device=device,
|
||||||
|
accelerator=accelerator,
|
||||||
**cfg,
|
**cfg,
|
||||||
)
|
)
|
||||||
|
|
||||||
# val loop (optional)
|
# val loop (optional)
|
||||||
if "val" in dataloaders and epoch % cfg.validation_interval == 0:
|
if val_loader is not None and epoch % cfg.validation_interval == 0:
|
||||||
trainvalidate(
|
trainvalidate(
|
||||||
model,
|
model,
|
||||||
stats,
|
stats,
|
||||||
epoch,
|
epoch,
|
||||||
dataloaders["val"],
|
val_loader,
|
||||||
optimizer,
|
optimizer,
|
||||||
True,
|
True,
|
||||||
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
||||||
device=device,
|
device=device,
|
||||||
|
accelerator=accelerator,
|
||||||
**cfg,
|
**cfg,
|
||||||
)
|
)
|
||||||
|
|
||||||
# eval loop (optional)
|
# eval loop (optional)
|
||||||
if (
|
if (
|
||||||
"test" in dataloaders
|
dataloaders.test is not None
|
||||||
and cfg.test_interval > 0
|
and cfg.test_interval > 0
|
||||||
and epoch % cfg.test_interval == 0
|
and epoch % cfg.test_interval == 0
|
||||||
):
|
):
|
||||||
run_eval(cfg, model, stats, dataloaders["test"], device=device)
|
_run_eval(
|
||||||
|
model,
|
||||||
|
datasource.all_train_cameras,
|
||||||
|
dataloaders.test,
|
||||||
|
task,
|
||||||
|
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
assert stats.epoch == epoch, "inconsistent stats!"
|
assert stats.epoch == epoch, "inconsistent stats!"
|
||||||
|
|
||||||
# delete previous models if required
|
# delete previous models if required
|
||||||
# save model
|
# save model only on the main process
|
||||||
if cfg.store_checkpoints:
|
if cfg.store_checkpoints and (
|
||||||
|
accelerator is None or accelerator.is_local_main_process
|
||||||
|
):
|
||||||
if cfg.store_checkpoints_purge > 0:
|
if cfg.store_checkpoints_purge > 0:
|
||||||
for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
|
for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
|
||||||
model_io.purge_epoch(cfg.exp_dir, prev_epoch)
|
model_io.purge_epoch(cfg.exp_dir, prev_epoch)
|
||||||
outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
|
outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
|
||||||
model_io.safe_save_model(model, stats, outfile, optimizer=optimizer)
|
unwrapped_model = (
|
||||||
|
model if accelerator is None else accelerator.unwrap_model(model)
|
||||||
|
)
|
||||||
|
model_io.safe_save_model(
|
||||||
|
unwrapped_model, stats, outfile, optimizer=optimizer
|
||||||
|
)
|
||||||
|
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
@@ -547,26 +521,45 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
|||||||
logger.info(f"LR change! {cur_lr} -> {new_lr}")
|
logger.info(f"LR change! {cur_lr} -> {new_lr}")
|
||||||
|
|
||||||
if cfg.test_when_finished:
|
if cfg.test_when_finished:
|
||||||
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
|
_eval_and_dump(
|
||||||
|
cfg,
|
||||||
|
task,
|
||||||
|
datasource.all_train_cameras,
|
||||||
|
datasets,
|
||||||
|
dataloaders,
|
||||||
|
model,
|
||||||
|
stats,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _eval_and_dump(cfg, datasets, dataloaders, model, stats, device):
|
def _eval_and_dump(
|
||||||
|
cfg,
|
||||||
|
task: Task,
|
||||||
|
all_train_cameras: Optional[CamerasBase],
|
||||||
|
datasets: DatasetMap,
|
||||||
|
dataloaders: DataLoaderMap,
|
||||||
|
model,
|
||||||
|
stats,
|
||||||
|
device,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Run the evaluation loop with the test data loader and
|
Run the evaluation loop with the test data loader and
|
||||||
save the predictions to the `exp_dir`.
|
save the predictions to the `exp_dir`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if "test" not in dataloaders:
|
dataloader = dataloaders.test
|
||||||
raise ValueError('Dataloaders have to contain the "test" entry for eval!')
|
|
||||||
|
|
||||||
eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]
|
if dataloader is None:
|
||||||
all_source_cameras = (
|
raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
|
||||||
_get_all_source_cameras(datasets["train"])
|
|
||||||
if eval_task == "singlesequence"
|
results = _run_eval(
|
||||||
else None
|
model,
|
||||||
)
|
all_train_cameras,
|
||||||
results = run_eval(
|
dataloader,
|
||||||
cfg, model, all_source_cameras, dataloaders["test"], eval_task, device=device
|
task,
|
||||||
|
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
# add the evaluation epoch to the results
|
# add the evaluation epoch to the results
|
||||||
@@ -594,7 +587,14 @@ def _get_eval_frame_data(frame_data):
|
|||||||
return frame_data_for_eval
|
return frame_data_for_eval
|
||||||
|
|
||||||
|
|
||||||
def run_eval(cfg, model, all_source_cameras, loader, task, device):
|
def _run_eval(
|
||||||
|
model,
|
||||||
|
all_train_cameras,
|
||||||
|
loader,
|
||||||
|
task: Task,
|
||||||
|
camera_difficulty_bin_breaks: Tuple[float, float],
|
||||||
|
device,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Run the evaluation loop on the test dataloader
|
Run the evaluation loop on the test dataloader
|
||||||
"""
|
"""
|
||||||
@@ -615,104 +615,91 @@ def run_eval(cfg, model, all_source_cameras, loader, task, device):
|
|||||||
preds = model(
|
preds = model(
|
||||||
**{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
|
**{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
|
||||||
)
|
)
|
||||||
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
|
|
||||||
|
# TODO: Cannot use accelerate gather for two reasons:.
|
||||||
|
# (1) TypeError: Can't apply _gpu_gather_one on object of type
|
||||||
|
# <class 'pytorch3d.implicitron.models.base_model.ImplicitronRender'>,
|
||||||
|
# only of nested list/tuple/dicts of objects that satisfy is_torch_tensor.
|
||||||
|
# (2) Same error above but for frame_data which contains Cameras.
|
||||||
|
|
||||||
|
implicitron_render = copy.deepcopy(preds["implicitron_render"])
|
||||||
|
|
||||||
per_batch_eval_results.append(
|
per_batch_eval_results.append(
|
||||||
evaluate.eval_batch(
|
evaluate.eval_batch(
|
||||||
frame_data,
|
frame_data,
|
||||||
nvs_prediction,
|
implicitron_render,
|
||||||
bg_color="black",
|
bg_color="black",
|
||||||
lpips_model=lpips_model,
|
lpips_model=lpips_model,
|
||||||
source_cameras=all_source_cameras,
|
source_cameras=all_train_cameras,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
_, category_result = evaluate.summarize_nvs_eval_results(
|
_, category_result = evaluate.summarize_nvs_eval_results(
|
||||||
per_batch_eval_results, task
|
per_batch_eval_results, task, camera_difficulty_bin_breaks
|
||||||
)
|
)
|
||||||
|
|
||||||
return category_result["results"]
|
return category_result["results"]
|
||||||
|
|
||||||
|
|
||||||
def _get_all_source_cameras(
|
def _seed_all_random_engines(seed: int) -> None:
|
||||||
dataset: ImplicitronDataset,
|
|
||||||
num_workers: int = 8,
|
|
||||||
) -> CamerasBase:
|
|
||||||
"""
|
|
||||||
Load and return all the source cameras in the training dataset
|
|
||||||
"""
|
|
||||||
|
|
||||||
all_frame_data = next(
|
|
||||||
iter(
|
|
||||||
torch.utils.data.DataLoader(
|
|
||||||
dataset,
|
|
||||||
shuffle=False,
|
|
||||||
batch_size=len(dataset),
|
|
||||||
num_workers=num_workers,
|
|
||||||
collate_fn=FrameData.collate,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
is_source = ds_utils.is_known_frame(all_frame_data.frame_type)
|
|
||||||
source_cameras = all_frame_data.camera[torch.where(is_source)[0]]
|
|
||||||
return source_cameras
|
|
||||||
|
|
||||||
|
|
||||||
def _seed_all_random_engines(seed: int):
|
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False)
|
def _setup_envvars_for_cluster() -> bool:
|
||||||
class ExperimentConfig:
|
"""
|
||||||
generic_model_args: DictConfig = get_default_args_field(GenericModel)
|
Prepares to run on cluster if relevant.
|
||||||
solver_args: DictConfig = get_default_args_field(init_optimizer)
|
Returns whether FAIR cluster in use.
|
||||||
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
|
"""
|
||||||
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
|
# TODO: How much of this is needed in general?
|
||||||
architecture: str = "generic"
|
|
||||||
detect_anomaly: bool = False
|
|
||||||
eval_only: bool = False
|
|
||||||
exp_dir: str = "./data/default_experiment/"
|
|
||||||
exp_idx: int = 0
|
|
||||||
gpu_idx: int = 0
|
|
||||||
metric_print_interval: int = 5
|
|
||||||
resume: bool = True
|
|
||||||
resume_epoch: int = -1
|
|
||||||
seed: int = 0
|
|
||||||
store_checkpoints: bool = True
|
|
||||||
store_checkpoints_purge: int = 1
|
|
||||||
test_interval: int = -1
|
|
||||||
test_when_finished: bool = False
|
|
||||||
validation_interval: int = 1
|
|
||||||
visdom_env: str = ""
|
|
||||||
visdom_port: int = 8097
|
|
||||||
visdom_server: str = "http://127.0.0.1"
|
|
||||||
visualize_interval: int = 1000
|
|
||||||
clip_grad: float = 0.0
|
|
||||||
|
|
||||||
hydra: dict = field(
|
try:
|
||||||
default_factory=lambda: {
|
import submitit
|
||||||
"run": {"dir": "."}, # Make hydra not change the working dir.
|
except ImportError:
|
||||||
"output_subdir": None, # disable storing the .hydra logs
|
return False
|
||||||
}
|
|
||||||
|
try:
|
||||||
|
# Only needed when launching on cluster with slurm and submitit
|
||||||
|
job_env = submitit.JobEnvironment()
|
||||||
|
except RuntimeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
os.environ["LOCAL_RANK"] = str(job_env.local_rank)
|
||||||
|
os.environ["RANK"] = str(job_env.global_rank)
|
||||||
|
os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
|
||||||
|
os.environ["MASTER_ADDR"] = "localhost"
|
||||||
|
os.environ["MASTER_PORT"] = "42918"
|
||||||
|
logger.info(
|
||||||
|
"Num tasks %s, global_rank %s"
|
||||||
|
% (str(job_env.num_tasks), str(job_env.global_rank))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
expand_args_fields(ExperimentConfig)
|
||||||
cs = hydra.core.config_store.ConfigStore.instance()
|
cs = hydra.core.config_store.ConfigStore.instance()
|
||||||
cs.store(name="default_config", node=ExperimentConfig)
|
cs.store(name="default_config", node=ExperimentConfig)
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(config_path="./configs/", config_name="default_config")
|
@hydra.main(config_path="./configs/", config_name="default_config")
|
||||||
def experiment(cfg: DictConfig) -> None:
|
def experiment(cfg: DictConfig) -> None:
|
||||||
|
# CUDA_VISIBLE_DEVICES must have been set.
|
||||||
|
|
||||||
|
if "CUDA_DEVICE_ORDER" not in os.environ:
|
||||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
|
||||||
# Set the device
|
if not _setup_envvars_for_cluster():
|
||||||
device = "cpu"
|
logger.info("Running locally")
|
||||||
if torch.cuda.is_available() and cfg.gpu_idx < torch.cuda.device_count():
|
|
||||||
device = f"cuda:{cfg.gpu_idx}"
|
# TODO: The following may be needed for hydra/submitit it to work
|
||||||
logger.info(f"Running experiment on device: {device}")
|
expand_args_fields(GenericModel)
|
||||||
run_training(cfg, device)
|
expand_args_fields(AdaptiveRaySampler)
|
||||||
|
expand_args_fields(MultiPassEmissionAbsorptionRenderer)
|
||||||
|
expand_args_fields(ImplicitronDataSource)
|
||||||
|
|
||||||
|
run_training(cfg)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
49
projects/implicitron_trainer/impl/experiment_config.py
Normal file
49
projects/implicitron_trainer/impl/experiment_config.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from dataclasses import field
|
||||||
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||||
|
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||||
|
from pytorch3d.implicitron.tools.config import Configurable, get_default_args_field
|
||||||
|
|
||||||
|
from .optimization import init_optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class ExperimentConfig(Configurable):
|
||||||
|
generic_model_args: DictConfig = get_default_args_field(GenericModel)
|
||||||
|
solver_args: DictConfig = get_default_args_field(init_optimizer)
|
||||||
|
data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
|
||||||
|
architecture: str = "generic"
|
||||||
|
detect_anomaly: bool = False
|
||||||
|
eval_only: bool = False
|
||||||
|
exp_dir: str = "./data/default_experiment/"
|
||||||
|
exp_idx: int = 0
|
||||||
|
gpu_idx: int = 0
|
||||||
|
metric_print_interval: int = 5
|
||||||
|
resume: bool = True
|
||||||
|
resume_epoch: int = -1
|
||||||
|
seed: int = 0
|
||||||
|
store_checkpoints: bool = True
|
||||||
|
store_checkpoints_purge: int = 1
|
||||||
|
test_interval: int = -1
|
||||||
|
test_when_finished: bool = False
|
||||||
|
validation_interval: int = 1
|
||||||
|
visdom_env: str = ""
|
||||||
|
visdom_port: int = 8097
|
||||||
|
visdom_server: str = "http://127.0.0.1"
|
||||||
|
visualize_interval: int = 1000
|
||||||
|
clip_grad: float = 0.0
|
||||||
|
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
|
||||||
|
|
||||||
|
hydra: Dict[str, Any] = field(
|
||||||
|
default_factory=lambda: {
|
||||||
|
"run": {"dir": "."}, # Make hydra not change the working dir.
|
||||||
|
"output_subdir": None, # disable storing the .hydra logs
|
||||||
|
}
|
||||||
|
)
|
||||||
109
projects/implicitron_trainer/impl/optimization.py
Normal file
109
projects/implicitron_trainer/impl/optimization.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||||
|
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def init_optimizer(
|
||||||
|
model: GenericModel,
|
||||||
|
optimizer_state: Optional[Dict[str, Any]],
|
||||||
|
last_epoch: int,
|
||||||
|
breed: str = "adam",
|
||||||
|
weight_decay: float = 0.0,
|
||||||
|
lr_policy: str = "multistep",
|
||||||
|
lr: float = 0.0005,
|
||||||
|
gamma: float = 0.1,
|
||||||
|
momentum: float = 0.9,
|
||||||
|
betas: Tuple[float, ...] = (0.9, 0.999),
|
||||||
|
milestones: Tuple[int, ...] = (),
|
||||||
|
max_epochs: int = 1000,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Initialize the optimizer (optionally from checkpoint state)
|
||||||
|
and the learning rate scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model with optionally loaded weights
|
||||||
|
optimizer_state: The state dict for the optimizer. If None
|
||||||
|
it has not been loaded from checkpoint
|
||||||
|
last_epoch: If the model was loaded from checkpoint this will be the
|
||||||
|
number of the last epoch that was saved
|
||||||
|
breed: The type of optimizer to use e.g. adam
|
||||||
|
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
|
||||||
|
lr_policy: The policy to use for learning rate. Currently, only "multistep:
|
||||||
|
is supported.
|
||||||
|
lr: The value for the initial learning rate
|
||||||
|
gamma: Multiplicative factor of learning rate decay
|
||||||
|
momentum: Momentum factor for SGD optimizer
|
||||||
|
betas: Coefficients used for computing running averages of gradient and its square
|
||||||
|
in the Adam optimizer
|
||||||
|
milestones: List of increasing epoch indices at which the learning rate is
|
||||||
|
modified
|
||||||
|
max_epochs: The maximum number of epochs to run the optimizer for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
optimizer: Optimizer module, optionally loaded from checkpoint
|
||||||
|
scheduler: Learning rate scheduler module
|
||||||
|
|
||||||
|
Raise:
|
||||||
|
ValueError if `breed` or `lr_policy` are not supported.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Get the parameters to optimize
|
||||||
|
if hasattr(model, "_get_param_groups"): # use the model function
|
||||||
|
# pyre-ignore[29]
|
||||||
|
p_groups = model._get_param_groups(lr, wd=weight_decay)
|
||||||
|
else:
|
||||||
|
allprm = [prm for prm in model.parameters() if prm.requires_grad]
|
||||||
|
p_groups = [{"params": allprm, "lr": lr}]
|
||||||
|
|
||||||
|
# Intialize the optimizer
|
||||||
|
if breed == "sgd":
|
||||||
|
optimizer = torch.optim.SGD(
|
||||||
|
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
|
||||||
|
)
|
||||||
|
elif breed == "adagrad":
|
||||||
|
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
|
||||||
|
elif breed == "adam":
|
||||||
|
optimizer = torch.optim.Adam(
|
||||||
|
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("no such solver type %s" % breed)
|
||||||
|
logger.info(" -> solver type = %s" % breed)
|
||||||
|
|
||||||
|
# Load state from checkpoint
|
||||||
|
if optimizer_state is not None:
|
||||||
|
logger.info(" -> setting loaded optimizer state")
|
||||||
|
optimizer.load_state_dict(optimizer_state)
|
||||||
|
|
||||||
|
# Initialize the learning rate scheduler
|
||||||
|
if lr_policy == "multistep":
|
||||||
|
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||||
|
optimizer,
|
||||||
|
milestones=milestones,
|
||||||
|
gamma=gamma,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("no such lr policy %s" % lr_policy)
|
||||||
|
|
||||||
|
# When loading from checkpoint, this will make sure that the
|
||||||
|
# lr is correctly set even after returning
|
||||||
|
for _ in range(last_epoch):
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
optimizer.zero_grad()
|
||||||
|
return optimizer, scheduler
|
||||||
|
|
||||||
|
|
||||||
|
enable_get_default_args(init_optimizer)
|
||||||
5
projects/implicitron_trainer/tests/__init__.py
Normal file
5
projects/implicitron_trainer/tests/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
425
projects/implicitron_trainer/tests/experiment.yaml
Normal file
425
projects/implicitron_trainer/tests/experiment.yaml
Normal file
@@ -0,0 +1,425 @@
|
|||||||
|
generic_model_args:
|
||||||
|
mask_images: true
|
||||||
|
mask_depths: true
|
||||||
|
render_image_width: 400
|
||||||
|
render_image_height: 400
|
||||||
|
mask_threshold: 0.5
|
||||||
|
output_rasterized_mc: false
|
||||||
|
bg_color:
|
||||||
|
- 0.0
|
||||||
|
- 0.0
|
||||||
|
- 0.0
|
||||||
|
num_passes: 1
|
||||||
|
chunk_size_grid: 4096
|
||||||
|
render_features_dimensions: 3
|
||||||
|
tqdm_trigger_threshold: 16
|
||||||
|
n_train_target_views: 1
|
||||||
|
sampling_mode_training: mask_sample
|
||||||
|
sampling_mode_evaluation: full_grid
|
||||||
|
global_encoder_class_type: null
|
||||||
|
raysampler_class_type: AdaptiveRaySampler
|
||||||
|
renderer_class_type: MultiPassEmissionAbsorptionRenderer
|
||||||
|
image_feature_extractor_class_type: null
|
||||||
|
view_pooler_enabled: false
|
||||||
|
implicit_function_class_type: NeuralRadianceFieldImplicitFunction
|
||||||
|
view_metrics_class_type: ViewMetrics
|
||||||
|
regularization_metrics_class_type: RegularizationMetrics
|
||||||
|
loss_weights:
|
||||||
|
loss_rgb_mse: 1.0
|
||||||
|
loss_prev_stage_rgb_mse: 1.0
|
||||||
|
loss_mask_bce: 0.0
|
||||||
|
loss_prev_stage_mask_bce: 0.0
|
||||||
|
log_vars:
|
||||||
|
- loss_rgb_psnr_fg
|
||||||
|
- loss_rgb_psnr
|
||||||
|
- loss_rgb_mse
|
||||||
|
- loss_rgb_huber
|
||||||
|
- loss_depth_abs
|
||||||
|
- loss_depth_abs_fg
|
||||||
|
- loss_mask_neg_iou
|
||||||
|
- loss_mask_bce
|
||||||
|
- loss_mask_beta_prior
|
||||||
|
- loss_eikonal
|
||||||
|
- loss_density_tv
|
||||||
|
- loss_depth_neg_penalty
|
||||||
|
- loss_autodecoder_norm
|
||||||
|
- loss_prev_stage_rgb_mse
|
||||||
|
- loss_prev_stage_rgb_psnr_fg
|
||||||
|
- loss_prev_stage_rgb_psnr
|
||||||
|
- loss_prev_stage_mask_bce
|
||||||
|
- objective
|
||||||
|
- epoch
|
||||||
|
- sec/it
|
||||||
|
global_encoder_HarmonicTimeEncoder_args:
|
||||||
|
n_harmonic_functions: 10
|
||||||
|
append_input: true
|
||||||
|
time_divisor: 1.0
|
||||||
|
global_encoder_SequenceAutodecoder_args:
|
||||||
|
autodecoder_args:
|
||||||
|
encoding_dim: 0
|
||||||
|
n_instances: 0
|
||||||
|
init_scale: 1.0
|
||||||
|
ignore_input: false
|
||||||
|
raysampler_AdaptiveRaySampler_args:
|
||||||
|
image_width: 400
|
||||||
|
image_height: 400
|
||||||
|
sampling_mode_training: mask_sample
|
||||||
|
sampling_mode_evaluation: full_grid
|
||||||
|
n_pts_per_ray_training: 64
|
||||||
|
n_pts_per_ray_evaluation: 64
|
||||||
|
n_rays_per_image_sampled_from_mask: 1024
|
||||||
|
stratified_point_sampling_training: true
|
||||||
|
stratified_point_sampling_evaluation: false
|
||||||
|
scene_extent: 8.0
|
||||||
|
scene_center:
|
||||||
|
- 0.0
|
||||||
|
- 0.0
|
||||||
|
- 0.0
|
||||||
|
raysampler_NearFarRaySampler_args:
|
||||||
|
image_width: 400
|
||||||
|
image_height: 400
|
||||||
|
sampling_mode_training: mask_sample
|
||||||
|
sampling_mode_evaluation: full_grid
|
||||||
|
n_pts_per_ray_training: 64
|
||||||
|
n_pts_per_ray_evaluation: 64
|
||||||
|
n_rays_per_image_sampled_from_mask: 1024
|
||||||
|
stratified_point_sampling_training: true
|
||||||
|
stratified_point_sampling_evaluation: false
|
||||||
|
min_depth: 0.1
|
||||||
|
max_depth: 8.0
|
||||||
|
renderer_LSTMRenderer_args:
|
||||||
|
num_raymarch_steps: 10
|
||||||
|
init_depth: 17.0
|
||||||
|
init_depth_noise_std: 0.0005
|
||||||
|
hidden_size: 16
|
||||||
|
n_feature_channels: 256
|
||||||
|
bg_color: null
|
||||||
|
verbose: false
|
||||||
|
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
||||||
|
raymarcher_class_type: EmissionAbsorptionRaymarcher
|
||||||
|
n_pts_per_ray_fine_training: 64
|
||||||
|
n_pts_per_ray_fine_evaluation: 64
|
||||||
|
stratified_sampling_coarse_training: true
|
||||||
|
stratified_sampling_coarse_evaluation: false
|
||||||
|
append_coarse_samples_to_fine: true
|
||||||
|
density_noise_std_train: 0.0
|
||||||
|
return_weights: false
|
||||||
|
raymarcher_CumsumRaymarcher_args:
|
||||||
|
surface_thickness: 1
|
||||||
|
bg_color:
|
||||||
|
- 0.0
|
||||||
|
background_opacity: 0.0
|
||||||
|
density_relu: true
|
||||||
|
blend_output: false
|
||||||
|
raymarcher_EmissionAbsorptionRaymarcher_args:
|
||||||
|
surface_thickness: 1
|
||||||
|
bg_color:
|
||||||
|
- 0.0
|
||||||
|
background_opacity: 10000000000.0
|
||||||
|
density_relu: true
|
||||||
|
blend_output: false
|
||||||
|
renderer_SignedDistanceFunctionRenderer_args:
|
||||||
|
render_features_dimensions: 3
|
||||||
|
ray_tracer_args:
|
||||||
|
object_bounding_sphere: 1.0
|
||||||
|
sdf_threshold: 5.0e-05
|
||||||
|
line_search_step: 0.5
|
||||||
|
line_step_iters: 1
|
||||||
|
sphere_tracing_iters: 10
|
||||||
|
n_steps: 100
|
||||||
|
n_secant_steps: 8
|
||||||
|
ray_normal_coloring_network_args:
|
||||||
|
feature_vector_size: 3
|
||||||
|
mode: idr
|
||||||
|
d_in: 9
|
||||||
|
d_out: 3
|
||||||
|
dims:
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
weight_norm: true
|
||||||
|
n_harmonic_functions_dir: 0
|
||||||
|
pooled_feature_dim: 0
|
||||||
|
bg_color:
|
||||||
|
- 0.0
|
||||||
|
soft_mask_alpha: 50.0
|
||||||
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
|
name: resnet34
|
||||||
|
pretrained: true
|
||||||
|
stages:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 3
|
||||||
|
- 4
|
||||||
|
normalize_image: true
|
||||||
|
image_rescale: 0.16
|
||||||
|
first_max_pool: true
|
||||||
|
proj_dim: 32
|
||||||
|
l2_norm: true
|
||||||
|
add_masks: true
|
||||||
|
add_images: true
|
||||||
|
global_average_pool: false
|
||||||
|
feature_rescale: 1.0
|
||||||
|
view_pooler_args:
|
||||||
|
feature_aggregator_class_type: AngleWeightedReductionFeatureAggregator
|
||||||
|
view_sampler_args:
|
||||||
|
masked_sampling: false
|
||||||
|
sampling_mode: bilinear
|
||||||
|
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
|
||||||
|
exclude_target_view: true
|
||||||
|
exclude_target_view_mask_features: true
|
||||||
|
concatenate_output: true
|
||||||
|
weight_by_ray_angle_gamma: 1.0
|
||||||
|
min_ray_angle_weight: 0.1
|
||||||
|
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
|
||||||
|
exclude_target_view: true
|
||||||
|
exclude_target_view_mask_features: true
|
||||||
|
concatenate_output: true
|
||||||
|
reduction_functions:
|
||||||
|
- AVG
|
||||||
|
- STD
|
||||||
|
weight_by_ray_angle_gamma: 1.0
|
||||||
|
min_ray_angle_weight: 0.1
|
||||||
|
feature_aggregator_IdentityFeatureAggregator_args:
|
||||||
|
exclude_target_view: true
|
||||||
|
exclude_target_view_mask_features: true
|
||||||
|
concatenate_output: true
|
||||||
|
feature_aggregator_ReductionFeatureAggregator_args:
|
||||||
|
exclude_target_view: true
|
||||||
|
exclude_target_view_mask_features: true
|
||||||
|
concatenate_output: true
|
||||||
|
reduction_functions:
|
||||||
|
- AVG
|
||||||
|
- STD
|
||||||
|
implicit_function_IdrFeatureField_args:
|
||||||
|
feature_vector_size: 3
|
||||||
|
d_in: 3
|
||||||
|
d_out: 1
|
||||||
|
dims:
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
geometric_init: true
|
||||||
|
bias: 1.0
|
||||||
|
skip_in: []
|
||||||
|
weight_norm: true
|
||||||
|
n_harmonic_functions_xyz: 0
|
||||||
|
pooled_feature_dim: 0
|
||||||
|
encoding_dim: 0
|
||||||
|
implicit_function_NeRFormerImplicitFunction_args:
|
||||||
|
n_harmonic_functions_xyz: 10
|
||||||
|
n_harmonic_functions_dir: 4
|
||||||
|
n_hidden_neurons_dir: 128
|
||||||
|
latent_dim: 0
|
||||||
|
input_xyz: true
|
||||||
|
xyz_ray_dir_in_camera_coords: false
|
||||||
|
color_dim: 3
|
||||||
|
transformer_dim_down_factor: 2.0
|
||||||
|
n_hidden_neurons_xyz: 80
|
||||||
|
n_layers_xyz: 2
|
||||||
|
append_xyz:
|
||||||
|
- 1
|
||||||
|
implicit_function_NeuralRadianceFieldImplicitFunction_args:
|
||||||
|
n_harmonic_functions_xyz: 10
|
||||||
|
n_harmonic_functions_dir: 4
|
||||||
|
n_hidden_neurons_dir: 128
|
||||||
|
latent_dim: 0
|
||||||
|
input_xyz: true
|
||||||
|
xyz_ray_dir_in_camera_coords: false
|
||||||
|
color_dim: 3
|
||||||
|
transformer_dim_down_factor: 1.0
|
||||||
|
n_hidden_neurons_xyz: 256
|
||||||
|
n_layers_xyz: 8
|
||||||
|
append_xyz:
|
||||||
|
- 5
|
||||||
|
implicit_function_SRNHyperNetImplicitFunction_args:
|
||||||
|
hypernet_args:
|
||||||
|
n_harmonic_functions: 3
|
||||||
|
n_hidden_units: 256
|
||||||
|
n_layers: 2
|
||||||
|
n_hidden_units_hypernet: 256
|
||||||
|
n_layers_hypernet: 1
|
||||||
|
in_features: 3
|
||||||
|
out_features: 256
|
||||||
|
latent_dim_hypernet: 0
|
||||||
|
latent_dim: 0
|
||||||
|
xyz_in_camera_coords: false
|
||||||
|
pixel_generator_args:
|
||||||
|
n_harmonic_functions: 4
|
||||||
|
n_hidden_units: 256
|
||||||
|
n_hidden_units_color: 128
|
||||||
|
n_layers: 2
|
||||||
|
in_features: 256
|
||||||
|
out_features: 3
|
||||||
|
ray_dir_in_camera_coords: false
|
||||||
|
implicit_function_SRNImplicitFunction_args:
|
||||||
|
raymarch_function_args:
|
||||||
|
n_harmonic_functions: 3
|
||||||
|
n_hidden_units: 256
|
||||||
|
n_layers: 2
|
||||||
|
in_features: 3
|
||||||
|
out_features: 256
|
||||||
|
latent_dim: 0
|
||||||
|
xyz_in_camera_coords: false
|
||||||
|
raymarch_function: null
|
||||||
|
pixel_generator_args:
|
||||||
|
n_harmonic_functions: 4
|
||||||
|
n_hidden_units: 256
|
||||||
|
n_hidden_units_color: 128
|
||||||
|
n_layers: 2
|
||||||
|
in_features: 256
|
||||||
|
out_features: 3
|
||||||
|
ray_dir_in_camera_coords: false
|
||||||
|
view_metrics_ViewMetrics_args: {}
|
||||||
|
regularization_metrics_RegularizationMetrics_args: {}
|
||||||
|
solver_args:
|
||||||
|
breed: adam
|
||||||
|
weight_decay: 0.0
|
||||||
|
lr_policy: multistep
|
||||||
|
lr: 0.0005
|
||||||
|
gamma: 0.1
|
||||||
|
momentum: 0.9
|
||||||
|
betas:
|
||||||
|
- 0.9
|
||||||
|
- 0.999
|
||||||
|
milestones: []
|
||||||
|
max_epochs: 1000
|
||||||
|
data_source_args:
|
||||||
|
dataset_map_provider_class_type: ???
|
||||||
|
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
|
||||||
|
dataset_map_provider_BlenderDatasetMapProvider_args:
|
||||||
|
base_dir: ???
|
||||||
|
object_name: ???
|
||||||
|
path_manager_factory_class_type: PathManagerFactory
|
||||||
|
n_known_frames_for_test: null
|
||||||
|
path_manager_factory_PathManagerFactory_args:
|
||||||
|
silence_logs: true
|
||||||
|
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||||
|
category: ???
|
||||||
|
task_str: singlesequence
|
||||||
|
dataset_root: ''
|
||||||
|
n_frames_per_sequence: -1
|
||||||
|
test_on_train: false
|
||||||
|
restrict_sequence_name: []
|
||||||
|
test_restrict_sequence_id: -1
|
||||||
|
assert_single_seq: false
|
||||||
|
only_test_set: false
|
||||||
|
dataset_class_type: JsonIndexDataset
|
||||||
|
path_manager_factory_class_type: PathManagerFactory
|
||||||
|
dataset_JsonIndexDataset_args:
|
||||||
|
limit_to: 0
|
||||||
|
limit_sequences_to: 0
|
||||||
|
exclude_sequence: []
|
||||||
|
limit_category_to: []
|
||||||
|
load_images: true
|
||||||
|
load_depths: true
|
||||||
|
load_depth_masks: true
|
||||||
|
load_masks: true
|
||||||
|
load_point_clouds: false
|
||||||
|
max_points: 0
|
||||||
|
mask_images: false
|
||||||
|
mask_depths: false
|
||||||
|
image_height: 800
|
||||||
|
image_width: 800
|
||||||
|
box_crop: true
|
||||||
|
box_crop_mask_thr: 0.4
|
||||||
|
box_crop_context: 0.3
|
||||||
|
remove_empty_masks: true
|
||||||
|
seed: 0
|
||||||
|
sort_frames: false
|
||||||
|
path_manager_factory_PathManagerFactory_args:
|
||||||
|
silence_logs: true
|
||||||
|
dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
|
||||||
|
category: ???
|
||||||
|
subset_name: ???
|
||||||
|
dataset_root: ''
|
||||||
|
test_on_train: false
|
||||||
|
only_test_set: false
|
||||||
|
load_eval_batches: true
|
||||||
|
dataset_class_type: JsonIndexDataset
|
||||||
|
path_manager_factory_class_type: PathManagerFactory
|
||||||
|
dataset_JsonIndexDataset_args:
|
||||||
|
path_manager: null
|
||||||
|
frame_annotations_file: ''
|
||||||
|
sequence_annotations_file: ''
|
||||||
|
subset_lists_file: ''
|
||||||
|
subsets: null
|
||||||
|
limit_to: 0
|
||||||
|
limit_sequences_to: 0
|
||||||
|
pick_sequence: []
|
||||||
|
exclude_sequence: []
|
||||||
|
limit_category_to: []
|
||||||
|
dataset_root: ''
|
||||||
|
load_images: true
|
||||||
|
load_depths: true
|
||||||
|
load_depth_masks: true
|
||||||
|
load_masks: true
|
||||||
|
load_point_clouds: false
|
||||||
|
max_points: 0
|
||||||
|
mask_images: false
|
||||||
|
mask_depths: false
|
||||||
|
image_height: 800
|
||||||
|
image_width: 800
|
||||||
|
box_crop: true
|
||||||
|
box_crop_mask_thr: 0.4
|
||||||
|
box_crop_context: 0.3
|
||||||
|
remove_empty_masks: true
|
||||||
|
n_frames_per_sequence: -1
|
||||||
|
seed: 0
|
||||||
|
sort_frames: false
|
||||||
|
eval_batches: null
|
||||||
|
path_manager_factory_PathManagerFactory_args:
|
||||||
|
silence_logs: true
|
||||||
|
dataset_map_provider_LlffDatasetMapProvider_args:
|
||||||
|
base_dir: ???
|
||||||
|
object_name: ???
|
||||||
|
path_manager_factory_class_type: PathManagerFactory
|
||||||
|
n_known_frames_for_test: null
|
||||||
|
path_manager_factory_PathManagerFactory_args:
|
||||||
|
silence_logs: true
|
||||||
|
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||||
|
batch_size: 1
|
||||||
|
num_workers: 0
|
||||||
|
dataset_length_train: 0
|
||||||
|
dataset_length_val: 0
|
||||||
|
dataset_length_test: 0
|
||||||
|
train_conditioning_type: SAME
|
||||||
|
val_conditioning_type: SAME
|
||||||
|
test_conditioning_type: KNOWN
|
||||||
|
images_per_seq_options: []
|
||||||
|
sample_consecutive_frames: false
|
||||||
|
consecutive_frames_max_gap: 0
|
||||||
|
consecutive_frames_max_gap_seconds: 0.1
|
||||||
|
architecture: generic
|
||||||
|
detect_anomaly: false
|
||||||
|
eval_only: false
|
||||||
|
exp_dir: ./data/default_experiment/
|
||||||
|
exp_idx: 0
|
||||||
|
gpu_idx: 0
|
||||||
|
metric_print_interval: 5
|
||||||
|
resume: true
|
||||||
|
resume_epoch: -1
|
||||||
|
seed: 0
|
||||||
|
store_checkpoints: true
|
||||||
|
store_checkpoints_purge: 1
|
||||||
|
test_interval: -1
|
||||||
|
test_when_finished: false
|
||||||
|
validation_interval: 1
|
||||||
|
visdom_env: ''
|
||||||
|
visdom_port: 8097
|
||||||
|
visdom_server: http://127.0.0.1
|
||||||
|
visualize_interval: 1000
|
||||||
|
clip_grad: 0.0
|
||||||
|
camera_difficulty_bin_breaks:
|
||||||
|
- 0.97
|
||||||
|
- 0.98
|
||||||
|
hydra:
|
||||||
|
run:
|
||||||
|
dir: .
|
||||||
|
output_subdir: null
|
||||||
91
projects/implicitron_trainer/tests/test_experiment.py
Normal file
91
projects/implicitron_trainer/tests/test_experiment.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from hydra import compose, initialize_config_dir
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from .. import experiment
|
||||||
|
|
||||||
|
|
||||||
|
def interactive_testing_requested() -> bool:
|
||||||
|
"""
|
||||||
|
Certain tests are only useful when run interactively, and so are not regularly run.
|
||||||
|
These are activated by this funciton returning True, which the user requests by
|
||||||
|
setting the environment variable `PYTORCH3D_INTERACTIVE_TESTING` to 1.
|
||||||
|
"""
|
||||||
|
return os.environ.get("PYTORCH3D_INTERACTIVE_TESTING", "") == "1"
|
||||||
|
|
||||||
|
|
||||||
|
internal = os.environ.get("FB_TEST", False)
|
||||||
|
|
||||||
|
|
||||||
|
DATA_DIR = Path(__file__).resolve().parent
|
||||||
|
IMPLICITRON_CONFIGS_DIR = Path(__file__).resolve().parent.parent / "configs"
|
||||||
|
DEBUG: bool = False
|
||||||
|
|
||||||
|
# TODO:
|
||||||
|
# - add enough files to skateboard_first_5 that this works on RE.
|
||||||
|
# - share common code with PyTorch3D tests?
|
||||||
|
# - deal with the temporary output files this test creates
|
||||||
|
|
||||||
|
|
||||||
|
class TestExperiment(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.maxDiff = None
|
||||||
|
|
||||||
|
def test_from_defaults(self):
|
||||||
|
# Test making minimal changes to the dataclass defaults.
|
||||||
|
if not interactive_testing_requested() or not internal:
|
||||||
|
return
|
||||||
|
cfg = OmegaConf.structured(experiment.ExperimentConfig)
|
||||||
|
cfg.data_source_args.dataset_map_provider_class_type = (
|
||||||
|
"JsonIndexDatasetMapProvider"
|
||||||
|
)
|
||||||
|
dataset_args = (
|
||||||
|
cfg.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||||
|
)
|
||||||
|
dataloader_args = (
|
||||||
|
cfg.data_source_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||||
|
)
|
||||||
|
dataset_args.category = "skateboard"
|
||||||
|
dataset_args.test_restrict_sequence_id = 0
|
||||||
|
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||||
|
dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5
|
||||||
|
dataset_args.dataset_JsonIndexDataset_args.image_height = 80
|
||||||
|
dataset_args.dataset_JsonIndexDataset_args.image_width = 80
|
||||||
|
dataloader_args.dataset_length_train = 1
|
||||||
|
dataloader_args.dataset_length_val = 1
|
||||||
|
cfg.solver_args.max_epochs = 2
|
||||||
|
|
||||||
|
experiment.run_training(cfg)
|
||||||
|
|
||||||
|
def test_yaml_contents(self):
|
||||||
|
cfg = OmegaConf.structured(experiment.ExperimentConfig)
|
||||||
|
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
|
||||||
|
if DEBUG:
|
||||||
|
(DATA_DIR / "experiment.yaml").write_text(yaml)
|
||||||
|
self.assertEqual(yaml, (DATA_DIR / "experiment.yaml").read_text())
|
||||||
|
|
||||||
|
def test_load_configs(self):
|
||||||
|
config_files = []
|
||||||
|
|
||||||
|
for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"):
|
||||||
|
config_files.extend(
|
||||||
|
[
|
||||||
|
f
|
||||||
|
for f in IMPLICITRON_CONFIGS_DIR.glob(pattern)
|
||||||
|
if not f.name.endswith("_base.yaml")
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
for file in config_files:
|
||||||
|
with self.subTest(file.name):
|
||||||
|
with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
|
||||||
|
compose(file.name)
|
||||||
@@ -21,15 +21,11 @@ from typing import Optional, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as Fu
|
import torch.nn.functional as Fu
|
||||||
from experiment import init_model
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
||||||
FrameData,
|
|
||||||
ImplicitronDataset,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
||||||
from pytorch3d.implicitron.models.base import EvaluationMode
|
from pytorch3d.implicitron.models.base_model import EvaluationMode
|
||||||
from pytorch3d.implicitron.tools.configurable import get_default_args
|
from pytorch3d.implicitron.tools.configurable import get_default_args
|
||||||
from pytorch3d.implicitron.tools.eval_video_trajectory import (
|
from pytorch3d.implicitron.tools.eval_video_trajectory import (
|
||||||
generate_eval_video_cameras,
|
generate_eval_video_cameras,
|
||||||
@@ -41,9 +37,11 @@ from pytorch3d.implicitron.tools.vis_utils import (
|
|||||||
)
|
)
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from .experiment import init_model
|
||||||
|
|
||||||
|
|
||||||
def render_sequence(
|
def render_sequence(
|
||||||
dataset: ImplicitronDataset,
|
dataset: DatasetBase,
|
||||||
sequence_name: str,
|
sequence_name: str,
|
||||||
model: torch.nn.Module,
|
model: torch.nn.Module,
|
||||||
video_path,
|
video_path,
|
||||||
@@ -66,6 +64,12 @@ def render_sequence(
|
|||||||
):
|
):
|
||||||
if seed is None:
|
if seed is None:
|
||||||
seed = hash(sequence_name)
|
seed = hash(sequence_name)
|
||||||
|
|
||||||
|
if visdom_show_preds:
|
||||||
|
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
|
||||||
|
else:
|
||||||
|
viz = None
|
||||||
|
|
||||||
print(f"Loading all data of sequence '{sequence_name}'.")
|
print(f"Loading all data of sequence '{sequence_name}'.")
|
||||||
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
|
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
|
||||||
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
|
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
|
||||||
@@ -84,7 +88,7 @@ def render_sequence(
|
|||||||
up=up,
|
up=up,
|
||||||
focal_length=None,
|
focal_length=None,
|
||||||
principal_point=torch.zeros(n_eval_cameras, 2),
|
principal_point=torch.zeros(n_eval_cameras, 2),
|
||||||
traj_offset_canonical=[0.0, 0.0, traj_offset],
|
traj_offset_canonical=(0.0, 0.0, traj_offset),
|
||||||
)
|
)
|
||||||
|
|
||||||
# sample the source views reproducibly
|
# sample the source views reproducibly
|
||||||
@@ -120,7 +124,6 @@ def render_sequence(
|
|||||||
if visdom_show_preds and (
|
if visdom_show_preds and (
|
||||||
n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
|
n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
|
||||||
):
|
):
|
||||||
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
|
|
||||||
show_predictions(
|
show_predictions(
|
||||||
preds_total,
|
preds_total,
|
||||||
sequence_name=batch.sequence_name[0],
|
sequence_name=batch.sequence_name[0],
|
||||||
@@ -248,7 +251,7 @@ def show_predictions(
|
|||||||
def generate_prediction_videos(
|
def generate_prediction_videos(
|
||||||
preds,
|
preds,
|
||||||
sequence_name,
|
sequence_name,
|
||||||
viz,
|
viz=None,
|
||||||
viz_env="visualizer",
|
viz_env="visualizer",
|
||||||
predicted_keys=(
|
predicted_keys=(
|
||||||
"images_render",
|
"images_render",
|
||||||
@@ -276,13 +279,14 @@ def generate_prediction_videos(
|
|||||||
for rendered_pred in tqdm(preds):
|
for rendered_pred in tqdm(preds):
|
||||||
for k in predicted_keys:
|
for k in predicted_keys:
|
||||||
vws[k].write_frame(
|
vws[k].write_frame(
|
||||||
rendered_pred[k][0].detach().cpu().numpy(),
|
rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
|
||||||
resize=resize,
|
resize=resize,
|
||||||
)
|
)
|
||||||
|
|
||||||
for k in predicted_keys:
|
for k in predicted_keys:
|
||||||
vws[k].get_video(quiet=True)
|
vws[k].get_video(quiet=True)
|
||||||
print(f"Generated {vws[k].out_path}.")
|
print(f"Generated {vws[k].out_path}.")
|
||||||
|
if viz is not None:
|
||||||
viz.video(
|
viz.video(
|
||||||
videofile=vws[k].out_path,
|
videofile=vws[k].out_path,
|
||||||
env=viz_env,
|
env=viz_env,
|
||||||
@@ -297,7 +301,7 @@ def export_scenes(
|
|||||||
output_directory: Optional[str] = None,
|
output_directory: Optional[str] = None,
|
||||||
render_size: Tuple[int, int] = (512, 512),
|
render_size: Tuple[int, int] = (512, 512),
|
||||||
video_size: Optional[Tuple[int, int]] = None,
|
video_size: Optional[Tuple[int, int]] = None,
|
||||||
split: str = "train", # train | test
|
split: str = "train", # train | val | test
|
||||||
n_source_views: int = 9,
|
n_source_views: int = 9,
|
||||||
n_eval_cameras: int = 40,
|
n_eval_cameras: int = 40,
|
||||||
visdom_server="http://127.0.0.1",
|
visdom_server="http://127.0.0.1",
|
||||||
@@ -325,24 +329,31 @@ def export_scenes(
|
|||||||
config.gpu_idx = gpu_idx
|
config.gpu_idx = gpu_idx
|
||||||
config.exp_dir = exp_dir
|
config.exp_dir = exp_dir
|
||||||
# important so that the CO3D dataset gets loaded in full
|
# important so that the CO3D dataset gets loaded in full
|
||||||
config.dataset_args.test_on_train = False
|
dataset_args = (
|
||||||
|
config.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||||
|
)
|
||||||
|
dataset_args.test_on_train = False
|
||||||
# Set the rendering image size
|
# Set the rendering image size
|
||||||
config.generic_model_args.render_image_width = render_size[0]
|
config.generic_model_args.render_image_width = render_size[0]
|
||||||
config.generic_model_args.render_image_height = render_size[1]
|
config.generic_model_args.render_image_height = render_size[1]
|
||||||
if restrict_sequence_name is not None:
|
if restrict_sequence_name is not None:
|
||||||
config.dataset_args.restrict_sequence_name = restrict_sequence_name
|
dataset_args.restrict_sequence_name = restrict_sequence_name
|
||||||
|
|
||||||
# Set up the CUDA env for the visualization
|
# Set up the CUDA env for the visualization
|
||||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
|
||||||
|
|
||||||
# Load the previously trained model
|
# Load the previously trained model
|
||||||
model, _, _ = init_model(config, force_load=True, load_model_only=True)
|
model, _, _ = init_model(cfg=config, force_load=True, load_model_only=True)
|
||||||
model.cuda()
|
model.cuda()
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# Setup the dataset
|
# Setup the dataset
|
||||||
dataset = dataset_zoo(**config.dataset_args)[split]
|
datasource = ImplicitronDataSource(**config.data_source_args)
|
||||||
|
dataset_map = datasource.dataset_map_provider.get_dataset_map()
|
||||||
|
dataset = dataset_map[split]
|
||||||
|
if dataset is None:
|
||||||
|
raise ValueError(f"{split} dataset not provided")
|
||||||
|
|
||||||
# iterate over the sequences in the dataset
|
# iterate over the sequences in the dataset
|
||||||
for sequence_name in dataset.sequence_names():
|
for sequence_name in dataset.sequence_names():
|
||||||
|
|||||||
@@ -97,7 +97,7 @@ def generate_eval_video_cameras(
|
|||||||
cam_centers_on_plane.t() @ cam_centers_on_plane
|
cam_centers_on_plane.t() @ cam_centers_on_plane
|
||||||
) / cam_centers_on_plane.shape[0]
|
) / cam_centers_on_plane.shape[0]
|
||||||
_, e_vec = torch.symeig(cov, eigenvectors=True)
|
_, e_vec = torch.symeig(cov, eigenvectors=True)
|
||||||
traj_radius = (cam_centers_on_plane ** 2).sum(dim=1).sqrt().mean()
|
traj_radius = (cam_centers_on_plane**2).sum(dim=1).sqrt().mean()
|
||||||
angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams)
|
angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams)
|
||||||
traj = traj_radius * torch.stack(
|
traj = traj_radius * torch.stack(
|
||||||
(torch.zeros_like(angle), angle.cos(), angle.sin()), dim=-1
|
(torch.zeros_like(angle), angle.cos(), angle.sin()), dim=-1
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
|
|||||||
# PyTorch version >= 1.8.0
|
# PyTorch version >= 1.8.0
|
||||||
return torch.linalg.solve(A, B)
|
return torch.linalg.solve(A, B)
|
||||||
|
|
||||||
|
# pyre-fixme[16]: `Tuple` has no attribute `solution`.
|
||||||
return torch.solve(B, A).solution
|
return torch.solve(B, A).solution
|
||||||
|
|
||||||
|
|
||||||
@@ -67,9 +68,14 @@ def meshgrid_ij(
|
|||||||
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
|
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
|
||||||
"""
|
"""
|
||||||
if (
|
if (
|
||||||
|
# pyre-fixme[16]: Callable `meshgrid` has no attribute `__kwdefaults__`.
|
||||||
torch.meshgrid.__kwdefaults__ is not None
|
torch.meshgrid.__kwdefaults__ is not None
|
||||||
and "indexing" in torch.meshgrid.__kwdefaults__
|
and "indexing" in torch.meshgrid.__kwdefaults__
|
||||||
):
|
):
|
||||||
# PyTorch >= 1.10.0
|
# PyTorch >= 1.10.0
|
||||||
|
# pyre-fixme[6]: For 1st param expected `Union[List[Tensor], Tensor]` but
|
||||||
|
# got `Union[Sequence[Tensor], Tensor]`.
|
||||||
return torch.meshgrid(*A, indexing="ij")
|
return torch.meshgrid(*A, indexing="ij")
|
||||||
|
# pyre-fixme[6]: For 1st param expected `Union[List[Tensor], Tensor]` but got
|
||||||
|
# `Union[Sequence[Tensor], Tensor]`.
|
||||||
return torch.meshgrid(*A)
|
return torch.meshgrid(*A)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ def make_device(device: Device) -> torch.device:
|
|||||||
A matching torch.device object
|
A matching torch.device object
|
||||||
"""
|
"""
|
||||||
device = torch.device(device) if isinstance(device, str) else device
|
device = torch.device(device) if isinstance(device, str) else device
|
||||||
if device.type == "cuda" and device.index is None: # pyre-ignore[16]
|
if device.type == "cuda" and device.index is None:
|
||||||
# If cuda but with no index, then the current cuda device is indicated.
|
# If cuda but with no index, then the current cuda device is indicated.
|
||||||
# In that case, we fix to that device
|
# In that case, we fix to that device
|
||||||
device = torch.device(f"cuda:{torch.cuda.current_device()}")
|
device = torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||||
@@ -71,6 +71,5 @@ elif sys.version_info >= (3, 7, 0):
|
|||||||
def get_args(cls): # pragma: no cover
|
def get_args(cls): # pragma: no cover
|
||||||
return getattr(cls, "__args__", None)
|
return getattr(cls, "__args__", None)
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ImportError("This module requires Python 3.7+")
|
raise ImportError("This module requires Python 3.7+")
|
||||||
|
|||||||
@@ -75,12 +75,14 @@ class _SymEig3x3(nn.Module):
|
|||||||
if inputs.shape[-2:] != (3, 3):
|
if inputs.shape[-2:] != (3, 3):
|
||||||
raise ValueError("Only inputs of shape (..., 3, 3) are supported.")
|
raise ValueError("Only inputs of shape (..., 3, 3) are supported.")
|
||||||
|
|
||||||
inputs_diag = inputs.diagonal(dim1=-2, dim2=-1) # pyre-ignore[16]
|
inputs_diag = inputs.diagonal(dim1=-2, dim2=-1)
|
||||||
inputs_trace = inputs_diag.sum(-1)
|
inputs_trace = inputs_diag.sum(-1)
|
||||||
q = inputs_trace / 3.0
|
q = inputs_trace / 3.0
|
||||||
|
|
||||||
# Calculate squared sum of elements outside the main diagonal / 2
|
# Calculate squared sum of elements outside the main diagonal / 2
|
||||||
p1 = ((inputs ** 2).sum(dim=(-1, -2)) - (inputs_diag ** 2).sum(-1)) / 2
|
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||||
|
p1 = ((inputs**2).sum(dim=(-1, -2)) - (inputs_diag**2).sum(-1)) / 2
|
||||||
|
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||||
p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps)
|
p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps)
|
||||||
|
|
||||||
p = torch.sqrt(p2 / 6.0)
|
p = torch.sqrt(p2 / 6.0)
|
||||||
@@ -195,8 +197,9 @@ class _SymEig3x3(nn.Module):
|
|||||||
cross_products[..., :1, :]
|
cross_products[..., :1, :]
|
||||||
)
|
)
|
||||||
|
|
||||||
norms_sq = (cross_products ** 2).sum(dim=-1)
|
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||||
max_norms_index = norms_sq.argmax(dim=-1) # pyre-ignore[16]
|
norms_sq = (cross_products**2).sum(dim=-1)
|
||||||
|
max_norms_index = norms_sq.argmax(dim=-1)
|
||||||
|
|
||||||
# Pick only the cross-product with highest squared norm for each input
|
# Pick only the cross-product with highest squared norm for each input
|
||||||
max_cross_products = self._gather_by_index(
|
max_cross_products = self._gather_by_index(
|
||||||
@@ -227,9 +230,7 @@ class _SymEig3x3(nn.Module):
|
|||||||
index_shape = list(source.shape)
|
index_shape = list(source.shape)
|
||||||
index_shape[dim] = 1
|
index_shape[dim] = 1
|
||||||
|
|
||||||
return source.gather(dim, index.expand(index_shape)).squeeze( # pyre-ignore[16]
|
return source.gather(dim, index.expand(index_shape)).squeeze(dim)
|
||||||
dim
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_uv(self, w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
def _get_uv(self, w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
@@ -243,7 +244,7 @@ class _SymEig3x3(nn.Module):
|
|||||||
Tuple of U and V unit-length vector tensors of shape (..., 3)
|
Tuple of U and V unit-length vector tensors of shape (..., 3)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
min_idx = w.abs().argmin(dim=-1) # pyre-ignore[16]
|
min_idx = w.abs().argmin(dim=-1)
|
||||||
rotation_2d = self._rotations_3d[min_idx].to(w)
|
rotation_2d = self._rotations_3d[min_idx].to(w)
|
||||||
|
|
||||||
u = F.normalize((rotation_2d @ w[..., None])[..., 0], dim=-1)
|
u = F.normalize((rotation_2d @ w[..., None])[..., 0], dim=-1)
|
||||||
|
|||||||
@@ -377,7 +377,7 @@ class R2N2(ShapeNetBase): # pragma: no cover
|
|||||||
view_idxs: Optional[List[int]] = None,
|
view_idxs: Optional[List[int]] = None,
|
||||||
shader_type=HardPhongShader,
|
shader_type=HardPhongShader,
|
||||||
device: Device = "cpu",
|
device: Device = "cpu",
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Render models with BlenderCamera by default to achieve the same orientations as the
|
Render models with BlenderCamera by default to achieve the same orientations as the
|
||||||
|
|||||||
@@ -140,7 +140,6 @@ def compute_extrinsic_matrix(
|
|||||||
# rotates the model 90 degrees about the x axis. To compensate for this quirk we
|
# rotates the model 90 degrees about the x axis. To compensate for this quirk we
|
||||||
# roll that rotation into the extrinsic matrix here
|
# roll that rotation into the extrinsic matrix here
|
||||||
rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
|
rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
|
||||||
# pyre-fixme[16]: `Tensor` has no attribute `mm`.
|
|
||||||
RT = RT.mm(rot.to(RT))
|
RT = RT.mm(rot.to(RT))
|
||||||
|
|
||||||
return RT
|
return RT
|
||||||
@@ -180,6 +179,7 @@ def read_binvox_coords(
|
|||||||
size, translation, scale = _read_binvox_header(f)
|
size, translation, scale = _read_binvox_header(f)
|
||||||
storage = torch.ByteStorage.from_buffer(f.read())
|
storage = torch.ByteStorage.from_buffer(f.read())
|
||||||
data = torch.tensor([], dtype=torch.uint8)
|
data = torch.tensor([], dtype=torch.uint8)
|
||||||
|
# pyre-fixme[28]: Unexpected keyword argument `source`.
|
||||||
data.set_(source=storage)
|
data.set_(source=storage)
|
||||||
vals, counts = data[::2], data[1::2]
|
vals, counts = data[::2], data[1::2]
|
||||||
idxs = _compute_idxs(vals, counts)
|
idxs = _compute_idxs(vals, counts)
|
||||||
@@ -276,7 +276,7 @@ def _read_binvox_header(f): # pragma: no cover
|
|||||||
try:
|
try:
|
||||||
dims = [int(d) for d in dims[1:]]
|
dims = [int(d) for d in dims[1:]]
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError("Invalid header (line 2)")
|
raise ValueError("Invalid header (line 2)") from None
|
||||||
if len(dims) != 3 or dims[0] != dims[1] or dims[0] != dims[2]:
|
if len(dims) != 3 or dims[0] != dims[1] or dims[0] != dims[2]:
|
||||||
raise ValueError("Invalid header (line 2)")
|
raise ValueError("Invalid header (line 2)")
|
||||||
size = dims[0]
|
size = dims[0]
|
||||||
@@ -291,7 +291,7 @@ def _read_binvox_header(f): # pragma: no cover
|
|||||||
try:
|
try:
|
||||||
translation = tuple(float(t) for t in translation[1:])
|
translation = tuple(float(t) for t in translation[1:])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError("Invalid header (line 3)")
|
raise ValueError("Invalid header (line 3)") from None
|
||||||
|
|
||||||
# Fourth line of the header should be "scale [float]"
|
# Fourth line of the header should be "scale [float]"
|
||||||
line = f.readline().strip()
|
line = f.readline().strip()
|
||||||
|
|||||||
@@ -113,7 +113,7 @@ class ShapeNetBase(torch.utils.data.Dataset): # pragma: no cover
|
|||||||
idxs: Optional[List[int]] = None,
|
idxs: Optional[List[int]] = None,
|
||||||
shader_type=HardPhongShader,
|
shader_type=HardPhongShader,
|
||||||
device: Device = "cpu",
|
device: Device = "cpu",
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
If a list of model_ids are supplied, render all the objects by the given model_ids.
|
If a list of model_ids are supplied, render all the objects by the given model_ids.
|
||||||
@@ -227,6 +227,8 @@ class ShapeNetBase(torch.utils.data.Dataset): # pragma: no cover
|
|||||||
sampled_idxs = self._sample_idxs_from_category(
|
sampled_idxs = self._sample_idxs_from_category(
|
||||||
sample_num=sample_num, category=category
|
sample_num=sample_num, category=category
|
||||||
)
|
)
|
||||||
|
# pyre-fixme[6]: For 1st param expected `Union[List[Tensor],
|
||||||
|
# typing.Tuple[Tensor, ...]]` but got `Tuple[Tensor, List[int]]`.
|
||||||
idxs_tensor = torch.cat((idxs_tensor, sampled_idxs))
|
idxs_tensor = torch.cat((idxs_tensor, sampled_idxs))
|
||||||
idxs = idxs_tensor.tolist()
|
idxs = idxs_tensor.tolist()
|
||||||
# Check if the indices are valid if idxs are supplied.
|
# Check if the indices are valid if idxs are supplied.
|
||||||
@@ -283,4 +285,5 @@ class ShapeNetBase(torch.utils.data.Dataset): # pragma: no cover
|
|||||||
"category " + category if category is not None else "all categories",
|
"category " + category if category is not None else "all categories",
|
||||||
)
|
)
|
||||||
warnings.warn(msg)
|
warnings.warn(msg)
|
||||||
|
# pyre-fixme[7]: Expected `List[int]` but got `Tensor`.
|
||||||
return sampled_idxs
|
return sampled_idxs
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.tools.config import registry
|
||||||
|
|
||||||
|
from .load_blender import load_blender_data
|
||||||
|
from .single_sequence_dataset import (
|
||||||
|
_interpret_blender_cameras,
|
||||||
|
SingleSceneDatasetMapProviderBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
|
||||||
|
"""
|
||||||
|
Provides data for one scene from Blender synthetic dataset.
|
||||||
|
Uses the code in load_blender.py
|
||||||
|
|
||||||
|
Members:
|
||||||
|
base_dir: directory holding the data for the scene.
|
||||||
|
object_name: The name of the scene (e.g. "lego"). This is just used as a label.
|
||||||
|
It will typically be equal to the name of the directory self.base_dir.
|
||||||
|
path_manager_factory: Creates path manager which may be used for
|
||||||
|
interpreting paths.
|
||||||
|
n_known_frames_for_test: If set, training frames are included in the val
|
||||||
|
and test datasets, and this many random training frames are added to
|
||||||
|
each test batch. If not set, test batches each contain just a single
|
||||||
|
testing frame.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _load_data(self) -> None:
|
||||||
|
path_manager = self.path_manager_factory.get()
|
||||||
|
images, poses, _, hwf, i_split = load_blender_data(
|
||||||
|
self.base_dir,
|
||||||
|
testskip=1,
|
||||||
|
path_manager=path_manager,
|
||||||
|
)
|
||||||
|
H, W, focal = hwf
|
||||||
|
images_masks = torch.from_numpy(images).permute(0, 3, 1, 2)
|
||||||
|
|
||||||
|
# pyre-ignore[16]
|
||||||
|
self.poses = _interpret_blender_cameras(poses, focal)
|
||||||
|
# pyre-ignore[16]
|
||||||
|
self.images = images_masks[:, :3]
|
||||||
|
# pyre-ignore[16]
|
||||||
|
self.fg_probabilities = images_masks[:, 3:4]
|
||||||
|
# pyre-ignore[16]
|
||||||
|
self.i_split = i_split
|
||||||
438
pytorch3d/implicitron/dataset/data_loader_map_provider.py
Normal file
438
pytorch3d/implicitron/dataset/data_loader_map_provider.py
Normal file
@@ -0,0 +1,438 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Iterator, List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
|
from torch.utils.data import (
|
||||||
|
BatchSampler,
|
||||||
|
ChainDataset,
|
||||||
|
DataLoader,
|
||||||
|
RandomSampler,
|
||||||
|
Sampler,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .dataset_base import DatasetBase, FrameData
|
||||||
|
from .dataset_map_provider import DatasetMap
|
||||||
|
from .scene_batch_sampler import SceneBatchSampler
|
||||||
|
from .utils import is_known_frame_scalar
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataLoaderMap:
|
||||||
|
"""
|
||||||
|
A collection of data loaders for Implicitron.
|
||||||
|
|
||||||
|
Members:
|
||||||
|
|
||||||
|
train: a data loader for training
|
||||||
|
val: a data loader for validating during training
|
||||||
|
test: a data loader for final evaluation
|
||||||
|
"""
|
||||||
|
|
||||||
|
train: Optional[DataLoader[FrameData]]
|
||||||
|
val: Optional[DataLoader[FrameData]]
|
||||||
|
test: Optional[DataLoader[FrameData]]
|
||||||
|
|
||||||
|
def __getitem__(self, split: str) -> Optional[DataLoader[FrameData]]:
|
||||||
|
"""
|
||||||
|
Get one of the data loaders by key (name of data split)
|
||||||
|
"""
|
||||||
|
if split not in ["train", "val", "test"]:
|
||||||
|
raise ValueError(f"{split} was not a valid split name (train/val/test)")
|
||||||
|
return getattr(self, split)
|
||||||
|
|
||||||
|
|
||||||
|
class DataLoaderMapProviderBase(ReplaceableBase):
|
||||||
|
"""
|
||||||
|
Provider of a collection of data loaders for a given collection of datasets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
|
||||||
|
"""
|
||||||
|
Returns a collection of data loaders for a given collection of datasets.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class DoublePoolBatchSampler(Sampler[List[int]]):
|
||||||
|
"""
|
||||||
|
Batch sampler for making random batches of a single frame
|
||||||
|
from one list and a number of known frames from another list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
first_indices: List[int],
|
||||||
|
rest_indices: List[int],
|
||||||
|
batch_size: int,
|
||||||
|
replacement: bool,
|
||||||
|
num_batches: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
first_indices: indexes of dataset items to use as the first element
|
||||||
|
of each batch.
|
||||||
|
rest_indices: indexes of dataset items to use as the subsequent
|
||||||
|
elements of each batch. Not used if batch_size==1.
|
||||||
|
batch_size: The common size of any batch.
|
||||||
|
replacement: Whether the sampling of first items is with replacement.
|
||||||
|
num_batches: The number of batches in an epoch. If 0 or None,
|
||||||
|
one epoch is the length of `first_indices`.
|
||||||
|
"""
|
||||||
|
self.first_indices = first_indices
|
||||||
|
self.rest_indices = rest_indices
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.replacement = replacement
|
||||||
|
self.num_batches = None if num_batches == 0 else num_batches
|
||||||
|
|
||||||
|
if batch_size - 1 > len(rest_indices):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot make up ({batch_size})-batches from {len(self.rest_indices)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# copied from RandomSampler
|
||||||
|
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
||||||
|
self.generator = torch.Generator()
|
||||||
|
self.generator.manual_seed(seed)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
if self.num_batches is not None:
|
||||||
|
return self.num_batches
|
||||||
|
return len(self.first_indices)
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[List[int]]:
|
||||||
|
num_batches = self.num_batches
|
||||||
|
if self.replacement:
|
||||||
|
i_first = torch.randint(
|
||||||
|
len(self.first_indices),
|
||||||
|
size=(len(self),),
|
||||||
|
generator=self.generator,
|
||||||
|
)
|
||||||
|
elif num_batches is not None:
|
||||||
|
n_copies = 1 + (num_batches - 1) // len(self.first_indices)
|
||||||
|
raw_indices = [
|
||||||
|
torch.randperm(len(self.first_indices), generator=self.generator)
|
||||||
|
for _ in range(n_copies)
|
||||||
|
]
|
||||||
|
i_first = torch.concat(raw_indices)[:num_batches]
|
||||||
|
else:
|
||||||
|
i_first = torch.randperm(len(self.first_indices), generator=self.generator)
|
||||||
|
first_indices = [self.first_indices[i] for i in i_first]
|
||||||
|
|
||||||
|
if self.batch_size == 1:
|
||||||
|
for first_index in first_indices:
|
||||||
|
yield [first_index]
|
||||||
|
return
|
||||||
|
|
||||||
|
for first_index in first_indices:
|
||||||
|
# Consider using this class in a program which sets the seed. This use
|
||||||
|
# of randperm means that rerunning with a higher batch_size
|
||||||
|
# results in batches whose first elements as the first run.
|
||||||
|
i_rest = torch.randperm(
|
||||||
|
len(self.rest_indices),
|
||||||
|
generator=self.generator,
|
||||||
|
)[: self.batch_size - 1]
|
||||||
|
yield [first_index] + [self.rest_indices[i] for i in i_rest]
|
||||||
|
|
||||||
|
|
||||||
|
class BatchConditioningType(Enum):
|
||||||
|
"""
|
||||||
|
Ways to add conditioning frames for the val and test batches.
|
||||||
|
|
||||||
|
SAME: Use the corresponding dataset for all elements of val batches
|
||||||
|
without regard to frame type.
|
||||||
|
TRAIN: Use the corresponding dataset for the first element of each
|
||||||
|
batch, and the training dataset for the extra conditioning
|
||||||
|
elements. No regard to frame type.
|
||||||
|
KNOWN: Use frames from the corresponding dataset but separate them
|
||||||
|
according to their frame_type. Each batch will contain one UNSEEN
|
||||||
|
frame followed by many KNOWN frames.
|
||||||
|
"""
|
||||||
|
|
||||||
|
SAME = "same"
|
||||||
|
TRAIN = "train"
|
||||||
|
KNOWN = "known"
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
|
||||||
|
"""
|
||||||
|
Default implementation of DataLoaderMapProviderBase.
|
||||||
|
|
||||||
|
If a dataset returns batches from get_eval_batches(), then
|
||||||
|
they will be what the corresponding dataloader returns,
|
||||||
|
independently of any of the fields on this class.
|
||||||
|
|
||||||
|
If conditioning is not required, then the batch size should
|
||||||
|
be set as 1, and most of the fields do not matter.
|
||||||
|
|
||||||
|
If conditioning is required, each batch will contain one main
|
||||||
|
frame first to predict and the, rest of the elements are for
|
||||||
|
conditioning.
|
||||||
|
|
||||||
|
If images_per_seq_options is left empty, the conditioning
|
||||||
|
frames are picked according to the conditioning type given.
|
||||||
|
This does not have regard to the order of frames in a
|
||||||
|
scene, or which frames belong to what scene.
|
||||||
|
|
||||||
|
If images_per_seq_options is given, then the conditioning types
|
||||||
|
must be SAME and the remaining fields are used.
|
||||||
|
|
||||||
|
Members:
|
||||||
|
batch_size: The size of the batch of the data loader.
|
||||||
|
num_workers: Number of data-loading threads in each data loader.
|
||||||
|
dataset_length_train: The number of batches in a training epoch. Or 0 to mean
|
||||||
|
an epoch is the length of the training set.
|
||||||
|
dataset_length_val: The number of batches in a validation epoch. Or 0 to mean
|
||||||
|
an epoch is the length of the validation set.
|
||||||
|
dataset_length_test: The number of batches in a testing epoch. Or 0 to mean
|
||||||
|
an epoch is the length of the test set.
|
||||||
|
train_conditioning_type: Whether the train data loader should use
|
||||||
|
only known frames for conditioning.
|
||||||
|
Only used if batch_size>1 and train dataset is
|
||||||
|
present and does not return eval_batches.
|
||||||
|
val_conditioning_type: Whether the val data loader should use
|
||||||
|
training frames or known frames for conditioning.
|
||||||
|
Only used if batch_size>1 and val dataset is
|
||||||
|
present and does not return eval_batches.
|
||||||
|
test_conditioning_type: Whether the test data loader should use
|
||||||
|
training frames or known frames for conditioning.
|
||||||
|
Only used if batch_size>1 and test dataset is
|
||||||
|
present and does not return eval_batches.
|
||||||
|
images_per_seq_options: Possible numbers of frames sampled per sequence in a batch.
|
||||||
|
If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial
|
||||||
|
value. Empty (the default) means that we are not careful about which frames
|
||||||
|
come from which scene.
|
||||||
|
sample_consecutive_frames: if True, will sample a contiguous interval of frames
|
||||||
|
in the sequence. It first sorts the frames by timestimps when available,
|
||||||
|
otherwise by frame numbers, finds the connected segments within the sequence
|
||||||
|
of sufficient length, then samples a random pivot element among them and
|
||||||
|
ideally uses it as a middle of the temporal window, shifting the borders
|
||||||
|
where necessary. This strategy mitigates the bias against shorter segments
|
||||||
|
and their boundaries.
|
||||||
|
consecutive_frames_max_gap: if a number > 0, then used to define the maximum
|
||||||
|
difference in frame_number of neighbouring frames when forming connected
|
||||||
|
segments; if both this and consecutive_frames_max_gap_seconds are 0s,
|
||||||
|
the whole sequence is considered a segment regardless of frame numbers.
|
||||||
|
consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
|
||||||
|
maximum difference in frame_timestamp of neighbouring frames when forming
|
||||||
|
connected segments; if both this and consecutive_frames_max_gap are 0s,
|
||||||
|
the whole sequence is considered a segment regardless of frame timestamps.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch_size: int = 1
|
||||||
|
num_workers: int = 0
|
||||||
|
dataset_length_train: int = 0
|
||||||
|
dataset_length_val: int = 0
|
||||||
|
dataset_length_test: int = 0
|
||||||
|
train_conditioning_type: BatchConditioningType = BatchConditioningType.SAME
|
||||||
|
val_conditioning_type: BatchConditioningType = BatchConditioningType.SAME
|
||||||
|
test_conditioning_type: BatchConditioningType = BatchConditioningType.KNOWN
|
||||||
|
images_per_seq_options: Tuple[int, ...] = ()
|
||||||
|
sample_consecutive_frames: bool = False
|
||||||
|
consecutive_frames_max_gap: int = 0
|
||||||
|
consecutive_frames_max_gap_seconds: float = 0.1
|
||||||
|
|
||||||
|
def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
|
||||||
|
"""
|
||||||
|
Returns a collection of data loaders for a given collection of datasets.
|
||||||
|
"""
|
||||||
|
return DataLoaderMap(
|
||||||
|
train=self._make_data_loader(
|
||||||
|
datasets.train,
|
||||||
|
self.dataset_length_train,
|
||||||
|
datasets.train,
|
||||||
|
self.train_conditioning_type,
|
||||||
|
),
|
||||||
|
val=self._make_data_loader(
|
||||||
|
datasets.val,
|
||||||
|
self.dataset_length_val,
|
||||||
|
datasets.train,
|
||||||
|
self.val_conditioning_type,
|
||||||
|
),
|
||||||
|
test=self._make_data_loader(
|
||||||
|
datasets.test,
|
||||||
|
self.dataset_length_test,
|
||||||
|
datasets.train,
|
||||||
|
self.test_conditioning_type,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_data_loader(
|
||||||
|
self,
|
||||||
|
dataset: Optional[DatasetBase],
|
||||||
|
num_batches: int,
|
||||||
|
train_dataset: Optional[DatasetBase],
|
||||||
|
conditioning_type: BatchConditioningType,
|
||||||
|
) -> Optional[DataLoader[FrameData]]:
|
||||||
|
"""
|
||||||
|
Returns the dataloader for a dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: the dataset
|
||||||
|
num_batches: possible ceiling on number of batches per epoch
|
||||||
|
train_dataset: the training dataset, used if conditioning_type==TRAIN
|
||||||
|
conditioning_type: source for padding of batches
|
||||||
|
"""
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data_loader_kwargs = {
|
||||||
|
"num_workers": self.num_workers,
|
||||||
|
"collate_fn": dataset.frame_data_type.collate,
|
||||||
|
}
|
||||||
|
|
||||||
|
eval_batches = dataset.get_eval_batches()
|
||||||
|
if eval_batches is not None:
|
||||||
|
return DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_sampler=eval_batches,
|
||||||
|
**data_loader_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
scenes_matter = len(self.images_per_seq_options) > 0
|
||||||
|
if scenes_matter and conditioning_type != BatchConditioningType.SAME:
|
||||||
|
raise ValueError(
|
||||||
|
f"{conditioning_type} cannot be used with images_per_seq "
|
||||||
|
+ str(self.images_per_seq_options)
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.batch_size == 1 or (
|
||||||
|
not scenes_matter and conditioning_type == BatchConditioningType.SAME
|
||||||
|
):
|
||||||
|
return self._simple_loader(dataset, num_batches, data_loader_kwargs)
|
||||||
|
|
||||||
|
if scenes_matter:
|
||||||
|
assert conditioning_type == BatchConditioningType.SAME
|
||||||
|
batch_sampler = SceneBatchSampler(
|
||||||
|
dataset,
|
||||||
|
self.batch_size,
|
||||||
|
num_batches=len(dataset) if num_batches <= 0 else num_batches,
|
||||||
|
images_per_seq_options=self.images_per_seq_options,
|
||||||
|
sample_consecutive_frames=self.sample_consecutive_frames,
|
||||||
|
consecutive_frames_max_gap=self.consecutive_frames_max_gap,
|
||||||
|
consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
|
||||||
|
)
|
||||||
|
return DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_sampler=batch_sampler,
|
||||||
|
**data_loader_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if conditioning_type == BatchConditioningType.TRAIN:
|
||||||
|
return self._train_loader(
|
||||||
|
dataset, train_dataset, num_batches, data_loader_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
assert conditioning_type == BatchConditioningType.KNOWN
|
||||||
|
return self._known_loader(dataset, num_batches, data_loader_kwargs)
|
||||||
|
|
||||||
|
def _simple_loader(
|
||||||
|
self,
|
||||||
|
dataset: DatasetBase,
|
||||||
|
num_batches: int,
|
||||||
|
data_loader_kwargs: dict,
|
||||||
|
) -> DataLoader[FrameData]:
|
||||||
|
"""
|
||||||
|
Return a simple loader for frames in the dataset.
|
||||||
|
|
||||||
|
This is equivalent to
|
||||||
|
Dataloader(dataset, batch_size=self.batch_size, **data_loader_kwargs)
|
||||||
|
except that num_batches is fixed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: the dataset
|
||||||
|
num_batches: possible ceiling on number of batches per epoch
|
||||||
|
data_loader_kwargs: common args for dataloader
|
||||||
|
"""
|
||||||
|
if num_batches > 0:
|
||||||
|
num_samples = self.batch_size * num_batches
|
||||||
|
replacement = True
|
||||||
|
else:
|
||||||
|
num_samples = None
|
||||||
|
replacement = False
|
||||||
|
sampler = RandomSampler(
|
||||||
|
dataset, replacement=replacement, num_samples=num_samples
|
||||||
|
)
|
||||||
|
batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
|
||||||
|
return DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_sampler=batch_sampler,
|
||||||
|
**data_loader_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _train_loader(
|
||||||
|
self,
|
||||||
|
dataset: DatasetBase,
|
||||||
|
train_dataset: Optional[DatasetBase],
|
||||||
|
num_batches: int,
|
||||||
|
data_loader_kwargs: dict,
|
||||||
|
) -> DataLoader[FrameData]:
|
||||||
|
"""
|
||||||
|
Return the loader for TRAIN conditioning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: the dataset
|
||||||
|
train_dataset: the training dataset
|
||||||
|
num_batches: possible ceiling on number of batches per epoch
|
||||||
|
data_loader_kwargs: common args for dataloader
|
||||||
|
"""
|
||||||
|
if train_dataset is None:
|
||||||
|
raise ValueError("No training data for conditioning.")
|
||||||
|
length = len(dataset)
|
||||||
|
first_indices = list(range(length))
|
||||||
|
rest_indices = list(range(length, length + len(train_dataset)))
|
||||||
|
sampler = DoublePoolBatchSampler(
|
||||||
|
first_indices=first_indices,
|
||||||
|
rest_indices=rest_indices,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
replacement=True,
|
||||||
|
num_batches=num_batches,
|
||||||
|
)
|
||||||
|
return DataLoader(
|
||||||
|
ChainDataset([dataset, train_dataset]),
|
||||||
|
batch_sampler=sampler,
|
||||||
|
**data_loader_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _known_loader(
|
||||||
|
self,
|
||||||
|
dataset: DatasetBase,
|
||||||
|
num_batches: int,
|
||||||
|
data_loader_kwargs: dict,
|
||||||
|
) -> DataLoader[FrameData]:
|
||||||
|
"""
|
||||||
|
Return the loader for KNOWN conditioning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: the dataset
|
||||||
|
num_batches: possible ceiling on number of batches per epoch
|
||||||
|
data_loader_kwargs: common args for dataloader
|
||||||
|
"""
|
||||||
|
first_indices, rest_indices = [], []
|
||||||
|
for idx in range(len(dataset)):
|
||||||
|
frame_type = dataset[idx].frame_type
|
||||||
|
assert isinstance(frame_type, str)
|
||||||
|
if is_known_frame_scalar(frame_type):
|
||||||
|
rest_indices.append(idx)
|
||||||
|
else:
|
||||||
|
first_indices.append(idx)
|
||||||
|
sampler = DoublePoolBatchSampler(
|
||||||
|
first_indices=first_indices,
|
||||||
|
rest_indices=rest_indices,
|
||||||
|
batch_size=self.batch_size,
|
||||||
|
replacement=True,
|
||||||
|
num_batches=num_batches,
|
||||||
|
)
|
||||||
|
return DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_sampler=sampler,
|
||||||
|
**data_loader_kwargs,
|
||||||
|
)
|
||||||
79
pytorch3d/implicitron/dataset/data_source.py
Normal file
79
pytorch3d/implicitron/dataset/data_source.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
from pytorch3d.implicitron.tools.config import (
|
||||||
|
registry,
|
||||||
|
ReplaceableBase,
|
||||||
|
run_auto_creation,
|
||||||
|
)
|
||||||
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
|
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
|
||||||
|
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
||||||
|
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
|
||||||
|
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
|
||||||
|
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
|
||||||
|
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
|
||||||
|
|
||||||
|
|
||||||
|
class DataSourceBase(ReplaceableBase):
|
||||||
|
"""
|
||||||
|
Base class for a data source in Implicitron. It encapsulates Dataset
|
||||||
|
and DataLoader configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
|
"""
|
||||||
|
If the data is all for a single scene, a list
|
||||||
|
of the known training cameras for that scene, which is
|
||||||
|
used for evaluating the viewpoint difficulty of the
|
||||||
|
unseen cameras.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
||||||
|
"""
|
||||||
|
Represents the data used in Implicitron. This is the only implementation
|
||||||
|
of DataSourceBase provided.
|
||||||
|
|
||||||
|
Members:
|
||||||
|
dataset_map_provider_class_type: identifies type for dataset_map_provider.
|
||||||
|
e.g. JsonIndexDatasetMapProvider for Co3D.
|
||||||
|
data_loader_map_provider_class_type: identifies type for data_loader_map_provider.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dataset_map_provider: DatasetMapProviderBase
|
||||||
|
dataset_map_provider_class_type: str
|
||||||
|
data_loader_map_provider: DataLoaderMapProviderBase
|
||||||
|
data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
run_auto_creation(self)
|
||||||
|
self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None
|
||||||
|
|
||||||
|
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
|
||||||
|
datasets = self.dataset_map_provider.get_dataset_map()
|
||||||
|
dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
|
||||||
|
return datasets, dataloaders
|
||||||
|
|
||||||
|
def get_task(self) -> Task:
|
||||||
|
return self.dataset_map_provider.get_task()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
|
if self._all_train_cameras_cache is None: # pyre-ignore[16]
|
||||||
|
all_train_cameras = self.dataset_map_provider.get_all_train_cameras()
|
||||||
|
self._all_train_cameras_cache = (all_train_cameras,)
|
||||||
|
|
||||||
|
return self._all_train_cameras_cache[0]
|
||||||
@@ -1,100 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the BSD-style license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
from typing import Dict, Sequence
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
|
||||||
|
|
||||||
from .implicitron_dataset import FrameData, ImplicitronDatasetBase
|
|
||||||
from .scene_batch_sampler import SceneBatchSampler
|
|
||||||
|
|
||||||
|
|
||||||
def dataloader_zoo(
|
|
||||||
datasets: Dict[str, ImplicitronDatasetBase],
|
|
||||||
dataset_name: str = "co3d_singlesequence",
|
|
||||||
batch_size: int = 1,
|
|
||||||
num_workers: int = 0,
|
|
||||||
dataset_len: int = 1000,
|
|
||||||
dataset_len_val: int = 1,
|
|
||||||
images_per_seq_options: Sequence[int] = (2,),
|
|
||||||
sample_consecutive_frames: bool = False,
|
|
||||||
consecutive_frames_max_gap: int = 0,
|
|
||||||
consecutive_frames_max_gap_seconds: float = 0.1,
|
|
||||||
) -> Dict[str, torch.utils.data.DataLoader]:
|
|
||||||
"""
|
|
||||||
Returns a set of dataloaders for a given set of datasets.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
datasets: A dictionary containing the
|
|
||||||
`"dataset_subset_name": torch_dataset_object` key, value pairs.
|
|
||||||
dataset_name: The name of the returned dataset.
|
|
||||||
batch_size: The size of the batch of the dataloader.
|
|
||||||
num_workers: Number data-loading threads.
|
|
||||||
dataset_len: The number of batches in a training epoch.
|
|
||||||
dataset_len_val: The number of batches in a validation epoch.
|
|
||||||
images_per_seq_options: Possible numbers of images sampled per sequence.
|
|
||||||
sample_consecutive_frames: if True, will sample a contiguous interval of frames
|
|
||||||
in the sequence. It first sorts the frames by timestimps when available,
|
|
||||||
otherwise by frame numbers, finds the connected segments within the sequence
|
|
||||||
of sufficient length, then samples a random pivot element among them and
|
|
||||||
ideally uses it as a middle of the temporal window, shifting the borders
|
|
||||||
where necessary. This strategy mitigates the bias against shorter segments
|
|
||||||
and their boundaries.
|
|
||||||
consecutive_frames_max_gap: if a number > 0, then used to define the maximum
|
|
||||||
difference in frame_number of neighbouring frames when forming connected
|
|
||||||
segments; if both this and consecutive_frames_max_gap_seconds are 0s,
|
|
||||||
the whole sequence is considered a segment regardless of frame numbers.
|
|
||||||
consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
|
|
||||||
maximum difference in frame_timestamp of neighbouring frames when forming
|
|
||||||
connected segments; if both this and consecutive_frames_max_gap are 0s,
|
|
||||||
the whole sequence is considered a segment regardless of frame timestamps.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dataloaders: A dictionary containing the
|
|
||||||
`"dataset_subset_name": torch_dataloader_object` key, value pairs.
|
|
||||||
"""
|
|
||||||
if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
|
|
||||||
raise ValueError(f"Unsupported dataset: {dataset_name}")
|
|
||||||
|
|
||||||
dataloaders = {}
|
|
||||||
|
|
||||||
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
|
|
||||||
for dataset_set, dataset in datasets.items():
|
|
||||||
num_samples = {
|
|
||||||
"train": dataset_len,
|
|
||||||
"val": dataset_len_val,
|
|
||||||
"test": None,
|
|
||||||
}[dataset_set]
|
|
||||||
|
|
||||||
if dataset_set == "test":
|
|
||||||
batch_sampler = dataset.get_eval_batches()
|
|
||||||
else:
|
|
||||||
assert num_samples is not None
|
|
||||||
num_samples = len(dataset) if num_samples <= 0 else num_samples
|
|
||||||
batch_sampler = SceneBatchSampler(
|
|
||||||
dataset,
|
|
||||||
batch_size,
|
|
||||||
num_batches=num_samples,
|
|
||||||
images_per_seq_options=images_per_seq_options,
|
|
||||||
sample_consecutive_frames=sample_consecutive_frames,
|
|
||||||
consecutive_frames_max_gap=consecutive_frames_max_gap,
|
|
||||||
)
|
|
||||||
|
|
||||||
dataloaders[dataset_set] = torch.utils.data.DataLoader(
|
|
||||||
dataset,
|
|
||||||
num_workers=num_workers,
|
|
||||||
batch_sampler=batch_sampler,
|
|
||||||
collate_fn=FrameData.collate,
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported dataset: {dataset_name}")
|
|
||||||
|
|
||||||
return dataloaders
|
|
||||||
|
|
||||||
|
|
||||||
enable_get_default_args(dataloader_zoo)
|
|
||||||
306
pytorch3d/implicitron/dataset/dataset_base.py
Normal file
306
pytorch3d/implicitron/dataset/dataset_base.py
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass, field, fields
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
ClassVar,
|
||||||
|
Iterable,
|
||||||
|
Iterator,
|
||||||
|
List,
|
||||||
|
Mapping,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||||
|
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||||
|
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FrameData(Mapping[str, Any]):
|
||||||
|
"""
|
||||||
|
A type of the elements returned by indexing the dataset object.
|
||||||
|
It can represent both individual frames and batches of thereof;
|
||||||
|
in this documentation, the sizes of tensors refer to single frames;
|
||||||
|
add the first batch dimension for the collation result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
frame_number: The number of the frame within its sequence.
|
||||||
|
0-based continuous integers.
|
||||||
|
sequence_name: The unique name of the frame's sequence.
|
||||||
|
sequence_category: The object category of the sequence.
|
||||||
|
frame_timestamp: The time elapsed since the start of a sequence in sec.
|
||||||
|
image_size_hw: The size of the image in pixels; (height, width) tensor
|
||||||
|
of shape (2,).
|
||||||
|
image_path: The qualified path to the loaded image (with dataset_root).
|
||||||
|
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
|
||||||
|
of the frame; elements are floats in [0, 1].
|
||||||
|
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
|
||||||
|
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
|
||||||
|
are a result of zero-padding of the image after cropping around
|
||||||
|
the object bounding box; elements are floats in {0.0, 1.0}.
|
||||||
|
depth_path: The qualified path to the frame's depth map.
|
||||||
|
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
|
||||||
|
of the frame; values correspond to distances from the camera;
|
||||||
|
use `depth_mask` and `mask_crop` to filter for valid pixels.
|
||||||
|
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
|
||||||
|
depth map that are valid for evaluation, they have been checked for
|
||||||
|
consistency across views; elements are floats in {0.0, 1.0}.
|
||||||
|
mask_path: A qualified path to the foreground probability mask.
|
||||||
|
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
|
||||||
|
pixels belonging to the captured object; elements are floats
|
||||||
|
in [0, 1].
|
||||||
|
bbox_xywh: The bounding box tightly enclosing the foreground object in the
|
||||||
|
format (x0, y0, width, height). The convention assumes that
|
||||||
|
`x0+width` and `y0+height` includes the boundary of the box.
|
||||||
|
I.e., to slice out the corresponding crop from an image tensor `I`
|
||||||
|
we execute `crop = I[..., y0:y0+height, x0:x0+width]`
|
||||||
|
crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
|
||||||
|
in the original image coordinates in the format (x0, y0, width, height).
|
||||||
|
The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
|
||||||
|
from `bbox_xywh` due to padding (which can happen e.g. due to
|
||||||
|
setting `JsonIndexDataset.box_crop_context > 0`)
|
||||||
|
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
|
||||||
|
corrected for cropping if it happened.
|
||||||
|
camera_quality_score: The score proportional to the confidence of the
|
||||||
|
frame's camera estimation (the higher the more accurate).
|
||||||
|
point_cloud_quality_score: The score proportional to the accuracy of the
|
||||||
|
frame's sequence point cloud (the higher the more accurate).
|
||||||
|
sequence_point_cloud_path: The path to the sequence's point cloud.
|
||||||
|
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
|
||||||
|
point cloud corresponding to the frame's sequence. When the object
|
||||||
|
represents a batch of frames, point clouds may be deduplicated;
|
||||||
|
see `sequence_point_cloud_idx`.
|
||||||
|
sequence_point_cloud_idx: Integer indices mapping frame indices to the
|
||||||
|
corresponding point clouds in `sequence_point_cloud`; to get the
|
||||||
|
corresponding point cloud to `image_rgb[i]`, use
|
||||||
|
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
|
||||||
|
frame_type: The type of the loaded frame specified in
|
||||||
|
`subset_lists_file`, if provided.
|
||||||
|
meta: A dict for storing additional frame information.
|
||||||
|
"""
|
||||||
|
|
||||||
|
frame_number: Optional[torch.LongTensor]
|
||||||
|
sequence_name: Union[str, List[str]]
|
||||||
|
sequence_category: Union[str, List[str]]
|
||||||
|
frame_timestamp: Optional[torch.Tensor] = None
|
||||||
|
image_size_hw: Optional[torch.Tensor] = None
|
||||||
|
image_path: Union[str, List[str], None] = None
|
||||||
|
image_rgb: Optional[torch.Tensor] = None
|
||||||
|
# masks out padding added due to cropping the square bit
|
||||||
|
mask_crop: Optional[torch.Tensor] = None
|
||||||
|
depth_path: Union[str, List[str], None] = None
|
||||||
|
depth_map: Optional[torch.Tensor] = None
|
||||||
|
depth_mask: Optional[torch.Tensor] = None
|
||||||
|
mask_path: Union[str, List[str], None] = None
|
||||||
|
fg_probability: Optional[torch.Tensor] = None
|
||||||
|
bbox_xywh: Optional[torch.Tensor] = None
|
||||||
|
crop_bbox_xywh: Optional[torch.Tensor] = None
|
||||||
|
camera: Optional[PerspectiveCameras] = None
|
||||||
|
camera_quality_score: Optional[torch.Tensor] = None
|
||||||
|
point_cloud_quality_score: Optional[torch.Tensor] = None
|
||||||
|
sequence_point_cloud_path: Union[str, List[str], None] = None
|
||||||
|
sequence_point_cloud: Optional[Pointclouds] = None
|
||||||
|
sequence_point_cloud_idx: Optional[torch.Tensor] = None
|
||||||
|
frame_type: Union[str, List[str], None] = None # known | unseen
|
||||||
|
meta: dict = field(default_factory=lambda: {})
|
||||||
|
|
||||||
|
def to(self, *args, **kwargs):
|
||||||
|
new_params = {}
|
||||||
|
for f in fields(self):
|
||||||
|
value = getattr(self, f.name)
|
||||||
|
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
||||||
|
new_params[f.name] = value.to(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
new_params[f.name] = value
|
||||||
|
return type(self)(**new_params)
|
||||||
|
|
||||||
|
def cpu(self):
|
||||||
|
return self.to(device=torch.device("cpu"))
|
||||||
|
|
||||||
|
def cuda(self):
|
||||||
|
return self.to(device=torch.device("cuda"))
|
||||||
|
|
||||||
|
# the following functions make sure **frame_data can be passed to functions
|
||||||
|
def __iter__(self):
|
||||||
|
for f in fields(self):
|
||||||
|
yield f.name
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(fields(self))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def collate(cls, batch):
|
||||||
|
"""
|
||||||
|
Given a list objects `batch` of class `cls`, collates them into a batched
|
||||||
|
representation suitable for processing with deep networks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
elem = batch[0]
|
||||||
|
|
||||||
|
if isinstance(elem, cls):
|
||||||
|
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
|
||||||
|
id_to_idx = defaultdict(list)
|
||||||
|
for i, pc_id in enumerate(pointcloud_ids):
|
||||||
|
id_to_idx[pc_id].append(i)
|
||||||
|
|
||||||
|
sequence_point_cloud = []
|
||||||
|
sequence_point_cloud_idx = -np.ones((len(batch),))
|
||||||
|
for i, ind in enumerate(id_to_idx.values()):
|
||||||
|
sequence_point_cloud_idx[ind] = i
|
||||||
|
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
|
||||||
|
assert (sequence_point_cloud_idx >= 0).all()
|
||||||
|
|
||||||
|
override_fields = {
|
||||||
|
"sequence_point_cloud": sequence_point_cloud,
|
||||||
|
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
|
||||||
|
}
|
||||||
|
# note that the pre-collate value of sequence_point_cloud_idx is unused
|
||||||
|
|
||||||
|
collated = {}
|
||||||
|
for f in fields(elem):
|
||||||
|
list_values = override_fields.get(
|
||||||
|
f.name, [getattr(d, f.name) for d in batch]
|
||||||
|
)
|
||||||
|
collated[f.name] = (
|
||||||
|
cls.collate(list_values)
|
||||||
|
if all(list_value is not None for list_value in list_values)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
return cls(**collated)
|
||||||
|
|
||||||
|
elif isinstance(elem, Pointclouds):
|
||||||
|
return join_pointclouds_as_batch(batch)
|
||||||
|
|
||||||
|
elif isinstance(elem, CamerasBase):
|
||||||
|
# TODO: don't store K; enforce working in NDC space
|
||||||
|
return join_cameras_as_batch(batch)
|
||||||
|
else:
|
||||||
|
return torch.utils.data._utils.collate.default_collate(batch)
|
||||||
|
|
||||||
|
|
||||||
|
class _GenericWorkaround:
|
||||||
|
"""
|
||||||
|
OmegaConf.structured has a weirdness when you try to apply
|
||||||
|
it to a dataclass whose first base class is a Generic which is not
|
||||||
|
Dict. The issue is with a function called get_dict_key_value_types
|
||||||
|
in omegaconf/_utils.py.
|
||||||
|
For example this fails:
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class D(torch.utils.data.Dataset[int]):
|
||||||
|
a: int = 3
|
||||||
|
|
||||||
|
OmegaConf.structured(D)
|
||||||
|
|
||||||
|
We avoid the problem by adding this class as an extra base class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
|
||||||
|
"""
|
||||||
|
Base class to describe a dataset to be used with Implicitron.
|
||||||
|
|
||||||
|
The dataset is made up of frames, and the frames are grouped into sequences.
|
||||||
|
Each sequence has a name (a string).
|
||||||
|
(A sequence could be a video, or a set of images of one scene.)
|
||||||
|
|
||||||
|
This means they have a __getitem__ which returns an instance of a FrameData,
|
||||||
|
which will describe one frame in one sequence.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# _seq_to_idx is a member which implementations can define.
|
||||||
|
# It maps sequence name to the sequence's global frame indices.
|
||||||
|
# It is used for the default implementations of some functions in this class.
|
||||||
|
# Implementations which override them are free to ignore it.
|
||||||
|
# _seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_frame_numbers_and_timestamps(
|
||||||
|
self, idxs: Sequence[int]
|
||||||
|
) -> List[Tuple[int, float]]:
|
||||||
|
"""
|
||||||
|
If the sequences in the dataset are videos rather than
|
||||||
|
unordered views, then the dataset should override this method to
|
||||||
|
return the index and timestamp in their videos of the frames whose
|
||||||
|
indices are given in `idxs`. In addition,
|
||||||
|
the values in _seq_to_idx should be in ascending order.
|
||||||
|
If timestamps are absent, they should be replaced with a constant.
|
||||||
|
|
||||||
|
This is used for letting SceneBatchSampler identify consecutive
|
||||||
|
frames.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx: frame index in self
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple of
|
||||||
|
- frame index in video
|
||||||
|
- timestamp of frame in video
|
||||||
|
"""
|
||||||
|
raise ValueError("This dataset does not contain videos.")
|
||||||
|
|
||||||
|
def get_eval_batches(self) -> Optional[List[List[int]]]:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def sequence_names(self) -> Iterable[str]:
|
||||||
|
"""Returns an iterator over sequence names in the dataset."""
|
||||||
|
# pyre-ignore[16]
|
||||||
|
return self._seq_to_idx.keys()
|
||||||
|
|
||||||
|
def sequence_frames_in_order(
|
||||||
|
self, seq_name: str
|
||||||
|
) -> Iterator[Tuple[float, int, int]]:
|
||||||
|
"""Returns an iterator over the frame indices in a given sequence.
|
||||||
|
We attempt to first sort by timestamp (if they are available),
|
||||||
|
then by frame number.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq_name: the name of the sequence.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
|
||||||
|
where `frame_no` is the index within the sequence, and
|
||||||
|
`dataset_idx` is the index within the dataset.
|
||||||
|
`None` timestamps are replaced with 0s.
|
||||||
|
"""
|
||||||
|
# pyre-ignore[16]
|
||||||
|
seq_frame_indices = self._seq_to_idx[seq_name]
|
||||||
|
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
||||||
|
|
||||||
|
yield from sorted(
|
||||||
|
[
|
||||||
|
(timestamp, frame_no, idx)
|
||||||
|
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
|
||||||
|
"""Same as `sequence_frames_in_order` but returns the iterator over
|
||||||
|
only dataset indices.
|
||||||
|
"""
|
||||||
|
for _, _, idx in self.sequence_frames_in_order(seq_name):
|
||||||
|
yield idx
|
||||||
|
|
||||||
|
# frame_data_type is the actual type of frames returned by the dataset.
|
||||||
|
# Collation uses its classmethod `collate`
|
||||||
|
frame_data_type: ClassVar[Type[FrameData]] = FrameData
|
||||||
120
pytorch3d/implicitron/dataset/dataset_map_provider.py
Normal file
120
pytorch3d/implicitron/dataset/dataset_map_provider.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Iterator, Optional
|
||||||
|
|
||||||
|
from iopath.common.file_io import PathManager
|
||||||
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
|
from .dataset_base import DatasetBase
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DatasetMap:
|
||||||
|
"""
|
||||||
|
A collection of datasets for implicitron.
|
||||||
|
|
||||||
|
Members:
|
||||||
|
|
||||||
|
train: a dataset for training
|
||||||
|
val: a dataset for validating during training
|
||||||
|
test: a dataset for final evaluation
|
||||||
|
"""
|
||||||
|
|
||||||
|
train: Optional[DatasetBase]
|
||||||
|
val: Optional[DatasetBase]
|
||||||
|
test: Optional[DatasetBase]
|
||||||
|
|
||||||
|
def __getitem__(self, split: str) -> Optional[DatasetBase]:
|
||||||
|
"""
|
||||||
|
Get one of the datasets by key (name of data split)
|
||||||
|
"""
|
||||||
|
if split not in ["train", "val", "test"]:
|
||||||
|
raise ValueError(f"{split} was not a valid split name (train/val/test)")
|
||||||
|
return getattr(self, split)
|
||||||
|
|
||||||
|
def iter_datasets(self) -> Iterator[DatasetBase]:
|
||||||
|
"""
|
||||||
|
Iterator over all datasets.
|
||||||
|
"""
|
||||||
|
if self.train is not None:
|
||||||
|
yield self.train
|
||||||
|
if self.val is not None:
|
||||||
|
yield self.val
|
||||||
|
if self.test is not None:
|
||||||
|
yield self.test
|
||||||
|
|
||||||
|
|
||||||
|
class Task(Enum):
|
||||||
|
SINGLE_SEQUENCE = "singlesequence"
|
||||||
|
MULTI_SEQUENCE = "multisequence"
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetMapProviderBase(ReplaceableBase):
|
||||||
|
"""
|
||||||
|
Base class for a provider of training / validation and testing
|
||||||
|
dataset objects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_dataset_map(self) -> DatasetMap:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
An object containing the torch.Dataset objects in train/val/test fields.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_task(self) -> Task:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
|
"""
|
||||||
|
If the data is all for a single scene, returns a list
|
||||||
|
of the known training cameras for that scene, which is
|
||||||
|
used for evaluating the difficulty of the unknown
|
||||||
|
cameras. Otherwise return None.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class PathManagerFactory(ReplaceableBase):
|
||||||
|
"""
|
||||||
|
Base class and default implementation of a tool which dataset_map_provider implementations
|
||||||
|
may use to construct a path manager if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
silence_logs: Whether to reduce log output from iopath library.
|
||||||
|
"""
|
||||||
|
|
||||||
|
silence_logs: bool = True
|
||||||
|
|
||||||
|
def get(self) -> Optional[PathManager]:
|
||||||
|
"""
|
||||||
|
Makes a PathManager if needed.
|
||||||
|
For open source users, this function should always return None.
|
||||||
|
Internally, this allows manifold access.
|
||||||
|
"""
|
||||||
|
if os.environ.get("INSIDE_RE_WORKER", False):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from iopath.fb.manifold import ManifoldPathHandler
|
||||||
|
except ImportError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if self.silence_logs:
|
||||||
|
logging.getLogger("iopath.fb.manifold").setLevel(logging.CRITICAL)
|
||||||
|
logging.getLogger("iopath.common.file_io").setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
|
path_manager = PathManager()
|
||||||
|
path_manager.register_handler(ManifoldPathHandler())
|
||||||
|
|
||||||
|
return path_manager
|
||||||
@@ -1,263 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the BSD-style license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
from typing import Any, Dict, List, Optional, Sequence
|
|
||||||
|
|
||||||
from iopath.common.file_io import PathManager
|
|
||||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
|
||||||
|
|
||||||
from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase
|
|
||||||
from .utils import (
|
|
||||||
DATASET_TYPE_KNOWN,
|
|
||||||
DATASET_TYPE_TEST,
|
|
||||||
DATASET_TYPE_TRAIN,
|
|
||||||
DATASET_TYPE_UNKNOWN,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# TODO from dataset.dataset_configs import DATASET_CONFIGS
|
|
||||||
DATASET_CONFIGS: Dict[str, Dict[str, Any]] = {
|
|
||||||
"default": {
|
|
||||||
"box_crop": True,
|
|
||||||
"box_crop_context": 0.3,
|
|
||||||
"image_width": 800,
|
|
||||||
"image_height": 800,
|
|
||||||
"remove_empty_masks": True,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
CO3D_CATEGORIES: List[str] = list(reversed([
|
|
||||||
"baseballbat", "banana", "bicycle", "microwave", "tv",
|
|
||||||
"cellphone", "toilet", "hairdryer", "couch", "kite", "pizza",
|
|
||||||
"umbrella", "wineglass", "laptop",
|
|
||||||
"hotdog", "stopsign", "frisbee", "baseballglove",
|
|
||||||
"cup", "parkingmeter", "backpack", "toyplane", "toybus",
|
|
||||||
"handbag", "chair", "keyboard", "car", "motorcycle",
|
|
||||||
"carrot", "bottle", "sandwich", "remote", "bowl", "skateboard",
|
|
||||||
"toaster", "mouse", "toytrain", "book", "toytruck",
|
|
||||||
"orange", "broccoli", "plant", "teddybear",
|
|
||||||
"suitcase", "bench", "ball", "cake",
|
|
||||||
"vase", "hydrant", "apple", "donut",
|
|
||||||
]))
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
|
|
||||||
|
|
||||||
|
|
||||||
def dataset_zoo(
|
|
||||||
dataset_name: str = "co3d_singlesequence",
|
|
||||||
dataset_root: str = _CO3D_DATASET_ROOT,
|
|
||||||
category: str = "DEFAULT",
|
|
||||||
limit_to: int = -1,
|
|
||||||
limit_sequences_to: int = -1,
|
|
||||||
n_frames_per_sequence: int = -1,
|
|
||||||
test_on_train: bool = False,
|
|
||||||
load_point_clouds: bool = False,
|
|
||||||
mask_images: bool = False,
|
|
||||||
mask_depths: bool = False,
|
|
||||||
restrict_sequence_name: Sequence[str] = (),
|
|
||||||
test_restrict_sequence_id: int = -1,
|
|
||||||
assert_single_seq: bool = False,
|
|
||||||
only_test_set: bool = False,
|
|
||||||
aux_dataset_kwargs: dict = DATASET_CONFIGS["default"],
|
|
||||||
path_manager: Optional[PathManager] = None,
|
|
||||||
) -> Dict[str, ImplicitronDatasetBase]:
|
|
||||||
"""
|
|
||||||
Generates the training / validation and testing dataset objects.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset_name: The name of the returned dataset.
|
|
||||||
dataset_root: The root folder of the dataset.
|
|
||||||
category: The object category of the dataset.
|
|
||||||
limit_to: Limit the dataset to the first #limit_to frames.
|
|
||||||
limit_sequences_to: Limit the dataset to the first
|
|
||||||
#limit_sequences_to sequences.
|
|
||||||
n_frames_per_sequence: Randomly sample #n_frames_per_sequence frames
|
|
||||||
in each sequence.
|
|
||||||
test_on_train: Construct validation and test datasets from
|
|
||||||
the training subset.
|
|
||||||
load_point_clouds: Enable returning scene point clouds from the dataset.
|
|
||||||
mask_images: Mask the loaded images with segmentation masks.
|
|
||||||
mask_depths: Mask the loaded depths with segmentation masks.
|
|
||||||
restrict_sequence_name: Restrict the dataset sequences to the ones
|
|
||||||
present in the given list of names.
|
|
||||||
test_restrict_sequence_id: The ID of the loaded sequence.
|
|
||||||
Active for dataset_name='co3d_singlesequence'.
|
|
||||||
assert_single_seq: Assert that only frames from a single sequence
|
|
||||||
are present in all generated datasets.
|
|
||||||
only_test_set: Load only the test set.
|
|
||||||
aux_dataset_kwargs: Specifies additional arguments to the
|
|
||||||
ImplicitronDataset constructor call.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
datasets: A dictionary containing the
|
|
||||||
`"dataset_subset_name": torch_dataset_object` key, value pairs.
|
|
||||||
"""
|
|
||||||
datasets = {}
|
|
||||||
|
|
||||||
# TODO:
|
|
||||||
# - implement loading multiple categories
|
|
||||||
|
|
||||||
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
|
|
||||||
# This maps the common names of the dataset subsets ("train"/"val"/"test")
|
|
||||||
# to the names of the subsets in the CO3D dataset.
|
|
||||||
set_names_mapping = _get_co3d_set_names_mapping(
|
|
||||||
dataset_name,
|
|
||||||
test_on_train,
|
|
||||||
only_test_set,
|
|
||||||
)
|
|
||||||
|
|
||||||
# load the evaluation batches
|
|
||||||
task = dataset_name.split("_")[-1]
|
|
||||||
batch_indices_path = os.path.join(
|
|
||||||
dataset_root,
|
|
||||||
category,
|
|
||||||
f"eval_batches_{task}.json",
|
|
||||||
)
|
|
||||||
if not os.path.isfile(batch_indices_path):
|
|
||||||
# The batch indices file does not exist.
|
|
||||||
# Most probably the user has not specified the root folder.
|
|
||||||
raise ValueError("Please specify a correct dataset_root folder.")
|
|
||||||
|
|
||||||
with open(batch_indices_path, "r") as f:
|
|
||||||
eval_batch_index = json.load(f)
|
|
||||||
|
|
||||||
if task == "singlesequence":
|
|
||||||
assert (
|
|
||||||
test_restrict_sequence_id is not None and test_restrict_sequence_id >= 0
|
|
||||||
), (
|
|
||||||
"Please specify an integer id 'test_restrict_sequence_id'"
|
|
||||||
+ " of the sequence considered for 'singlesequence'"
|
|
||||||
+ " training and evaluation."
|
|
||||||
)
|
|
||||||
assert len(restrict_sequence_name) == 0, (
|
|
||||||
"For the 'singlesequence' task, the restrict_sequence_name has"
|
|
||||||
" to be unset while test_restrict_sequence_id has to be set to an"
|
|
||||||
" integer defining the order of the evaluation sequence."
|
|
||||||
)
|
|
||||||
# a sort-stable set() equivalent:
|
|
||||||
eval_batches_sequence_names = list(
|
|
||||||
{b[0][0]: None for b in eval_batch_index}.keys()
|
|
||||||
)
|
|
||||||
eval_sequence_name = eval_batches_sequence_names[test_restrict_sequence_id]
|
|
||||||
eval_batch_index = [
|
|
||||||
b for b in eval_batch_index if b[0][0] == eval_sequence_name
|
|
||||||
]
|
|
||||||
# overwrite the restrict_sequence_name
|
|
||||||
restrict_sequence_name = [eval_sequence_name]
|
|
||||||
|
|
||||||
for dataset, subsets in set_names_mapping.items():
|
|
||||||
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
|
|
||||||
assert os.path.isfile(frame_file)
|
|
||||||
|
|
||||||
sequence_file = os.path.join(
|
|
||||||
dataset_root, category, "sequence_annotations.jgz"
|
|
||||||
)
|
|
||||||
assert os.path.isfile(sequence_file)
|
|
||||||
|
|
||||||
subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
|
|
||||||
assert os.path.isfile(subset_lists_file)
|
|
||||||
|
|
||||||
# TODO: maybe directly in param list
|
|
||||||
params = {
|
|
||||||
**copy.deepcopy(aux_dataset_kwargs),
|
|
||||||
"frame_annotations_file": frame_file,
|
|
||||||
"sequence_annotations_file": sequence_file,
|
|
||||||
"subset_lists_file": subset_lists_file,
|
|
||||||
"dataset_root": dataset_root,
|
|
||||||
"limit_to": limit_to,
|
|
||||||
"limit_sequences_to": limit_sequences_to,
|
|
||||||
"n_frames_per_sequence": n_frames_per_sequence
|
|
||||||
if dataset == "train"
|
|
||||||
else -1,
|
|
||||||
"subsets": subsets,
|
|
||||||
"load_point_clouds": load_point_clouds,
|
|
||||||
"mask_images": mask_images,
|
|
||||||
"mask_depths": mask_depths,
|
|
||||||
"pick_sequence": restrict_sequence_name,
|
|
||||||
"path_manager": path_manager,
|
|
||||||
}
|
|
||||||
|
|
||||||
datasets[dataset] = ImplicitronDataset(**params)
|
|
||||||
if dataset == "test":
|
|
||||||
if len(restrict_sequence_name) > 0:
|
|
||||||
eval_batch_index = [
|
|
||||||
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
|
|
||||||
]
|
|
||||||
|
|
||||||
datasets[dataset].eval_batches = datasets[
|
|
||||||
dataset
|
|
||||||
].seq_frame_index_to_dataset_index(eval_batch_index)
|
|
||||||
|
|
||||||
if assert_single_seq:
|
|
||||||
# check theres only one sequence in all datasets
|
|
||||||
assert (
|
|
||||||
len(
|
|
||||||
{
|
|
||||||
e["frame_annotation"].sequence_name
|
|
||||||
for dset in datasets.values()
|
|
||||||
for e in dset.frame_annots
|
|
||||||
}
|
|
||||||
)
|
|
||||||
<= 1
|
|
||||||
), "Multiple sequences loaded but expected one"
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported dataset: {dataset_name}")
|
|
||||||
|
|
||||||
if test_on_train:
|
|
||||||
datasets["val"] = datasets["train"]
|
|
||||||
datasets["test"] = datasets["train"]
|
|
||||||
|
|
||||||
return datasets
|
|
||||||
|
|
||||||
|
|
||||||
enable_get_default_args(dataset_zoo)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_co3d_set_names_mapping(
|
|
||||||
dataset_name: str,
|
|
||||||
test_on_train: bool,
|
|
||||||
only_test: bool,
|
|
||||||
) -> Dict[str, List[str]]:
|
|
||||||
"""
|
|
||||||
Returns the mapping of the common dataset subset names ("train"/"val"/"test")
|
|
||||||
to the names of the corresponding subsets in the CO3D dataset
|
|
||||||
("test_known"/"test_unseen"/"train_known"/"train_unseen").
|
|
||||||
"""
|
|
||||||
single_seq = dataset_name == "co3d_singlesequence"
|
|
||||||
|
|
||||||
if only_test:
|
|
||||||
set_names_mapping = {}
|
|
||||||
else:
|
|
||||||
set_names_mapping = {
|
|
||||||
"train": [
|
|
||||||
(DATASET_TYPE_TEST if single_seq else DATASET_TYPE_TRAIN)
|
|
||||||
+ "_"
|
|
||||||
+ DATASET_TYPE_KNOWN
|
|
||||||
]
|
|
||||||
}
|
|
||||||
if not test_on_train:
|
|
||||||
prefixes = [DATASET_TYPE_TEST]
|
|
||||||
if not single_seq:
|
|
||||||
prefixes.append(DATASET_TYPE_TRAIN)
|
|
||||||
set_names_mapping.update(
|
|
||||||
{
|
|
||||||
dset: [
|
|
||||||
p + "_" + t
|
|
||||||
for p in prefixes
|
|
||||||
for t in [DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN]
|
|
||||||
]
|
|
||||||
for dset in ["val", "test"]
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return set_names_mapping
|
|
||||||
@@ -4,6 +4,7 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import copy
|
||||||
import functools
|
import functools
|
||||||
import gzip
|
import gzip
|
||||||
import hashlib
|
import hashlib
|
||||||
@@ -13,17 +14,12 @@ import os
|
|||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field, fields
|
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Dict,
|
|
||||||
Iterable,
|
|
||||||
Iterator,
|
|
||||||
List,
|
List,
|
||||||
Mapping,
|
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
@@ -34,270 +30,31 @@ from typing import (
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from iopath.common.file_io import PathManager
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
from pytorch3d.io import IO
|
from pytorch3d.io import IO
|
||||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
from pytorch3d.structures.pointclouds import Pointclouds
|
||||||
|
|
||||||
from . import types
|
from . import types
|
||||||
|
from .dataset_base import DatasetBase, FrameData
|
||||||
|
from .utils import is_known_frame_scalar
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FrameData(Mapping[str, Any]):
|
|
||||||
"""
|
|
||||||
A type of the elements returned by indexing the dataset object.
|
|
||||||
It can represent both individual frames and batches of thereof;
|
|
||||||
in this documentation, the sizes of tensors refer to single frames;
|
|
||||||
add the first batch dimension for the collation result.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
frame_number: The number of the frame within its sequence.
|
|
||||||
0-based continuous integers.
|
|
||||||
frame_timestamp: The time elapsed since the start of a sequence in sec.
|
|
||||||
sequence_name: The unique name of the frame's sequence.
|
|
||||||
sequence_category: The object category of the sequence.
|
|
||||||
image_size_hw: The size of the image in pixels; (height, width) tuple.
|
|
||||||
image_path: The qualified path to the loaded image (with dataset_root).
|
|
||||||
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
|
|
||||||
of the frame; elements are floats in [0, 1].
|
|
||||||
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
|
|
||||||
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
|
|
||||||
are a result of zero-padding of the image after cropping around
|
|
||||||
the object bounding box; elements are floats in {0.0, 1.0}.
|
|
||||||
depth_path: The qualified path to the frame's depth map.
|
|
||||||
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
|
|
||||||
of the frame; values correspond to distances from the camera;
|
|
||||||
use `depth_mask` and `mask_crop` to filter for valid pixels.
|
|
||||||
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
|
|
||||||
depth map that are valid for evaluation, they have been checked for
|
|
||||||
consistency across views; elements are floats in {0.0, 1.0}.
|
|
||||||
mask_path: A qualified path to the foreground probability mask.
|
|
||||||
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
|
|
||||||
pixels belonging to the captured object; elements are floats
|
|
||||||
in [0, 1].
|
|
||||||
bbox_xywh: The bounding box capturing the object in the
|
|
||||||
format (x0, y0, width, height).
|
|
||||||
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
|
|
||||||
corrected for cropping if it happened.
|
|
||||||
camera_quality_score: The score proportional to the confidence of the
|
|
||||||
frame's camera estimation (the higher the more accurate).
|
|
||||||
point_cloud_quality_score: The score proportional to the accuracy of the
|
|
||||||
frame's sequence point cloud (the higher the more accurate).
|
|
||||||
sequence_point_cloud_path: The path to the sequence's point cloud.
|
|
||||||
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
|
|
||||||
point cloud corresponding to the frame's sequence. When the object
|
|
||||||
represents a batch of frames, point clouds may be deduplicated;
|
|
||||||
see `sequence_point_cloud_idx`.
|
|
||||||
sequence_point_cloud_idx: Integer indices mapping frame indices to the
|
|
||||||
corresponding point clouds in `sequence_point_cloud`; to get the
|
|
||||||
corresponding point cloud to `image_rgb[i]`, use
|
|
||||||
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
|
|
||||||
frame_type: The type of the loaded frame specified in
|
|
||||||
`subset_lists_file`, if provided.
|
|
||||||
meta: A dict for storing additional frame information.
|
|
||||||
"""
|
|
||||||
|
|
||||||
frame_number: Optional[torch.LongTensor]
|
|
||||||
frame_timestamp: Optional[torch.Tensor]
|
|
||||||
sequence_name: Union[str, List[str]]
|
|
||||||
sequence_category: Union[str, List[str]]
|
|
||||||
image_size_hw: Optional[torch.Tensor] = None
|
|
||||||
image_path: Union[str, List[str], None] = None
|
|
||||||
image_rgb: Optional[torch.Tensor] = None
|
|
||||||
# masks out padding added due to cropping the square bit
|
|
||||||
mask_crop: Optional[torch.Tensor] = None
|
|
||||||
depth_path: Union[str, List[str], None] = None
|
|
||||||
depth_map: Optional[torch.Tensor] = None
|
|
||||||
depth_mask: Optional[torch.Tensor] = None
|
|
||||||
mask_path: Union[str, List[str], None] = None
|
|
||||||
fg_probability: Optional[torch.Tensor] = None
|
|
||||||
bbox_xywh: Optional[torch.Tensor] = None
|
|
||||||
camera: Optional[PerspectiveCameras] = None
|
|
||||||
camera_quality_score: Optional[torch.Tensor] = None
|
|
||||||
point_cloud_quality_score: Optional[torch.Tensor] = None
|
|
||||||
sequence_point_cloud_path: Union[str, List[str], None] = None
|
|
||||||
sequence_point_cloud: Optional[Pointclouds] = None
|
|
||||||
sequence_point_cloud_idx: Optional[torch.Tensor] = None
|
|
||||||
frame_type: Union[str, List[str], None] = None # seen | unseen
|
|
||||||
meta: dict = field(default_factory=lambda: {})
|
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
|
||||||
new_params = {}
|
|
||||||
for f in fields(self):
|
|
||||||
value = getattr(self, f.name)
|
|
||||||
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
|
||||||
new_params[f.name] = value.to(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
new_params[f.name] = value
|
|
||||||
return type(self)(**new_params)
|
|
||||||
|
|
||||||
def cpu(self):
|
|
||||||
return self.to(device=torch.device("cpu"))
|
|
||||||
|
|
||||||
def cuda(self):
|
|
||||||
return self.to(device=torch.device("cuda"))
|
|
||||||
|
|
||||||
# the following functions make sure **frame_data can be passed to functions
|
|
||||||
def __iter__(self):
|
|
||||||
for f in fields(self):
|
|
||||||
yield f.name
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return getattr(self, key)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(fields(self))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def collate(cls, batch):
|
|
||||||
"""
|
|
||||||
Given a list objects `batch` of class `cls`, collates them into a batched
|
|
||||||
representation suitable for processing with deep networks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
elem = batch[0]
|
|
||||||
|
|
||||||
if isinstance(elem, cls):
|
|
||||||
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
|
|
||||||
id_to_idx = defaultdict(list)
|
|
||||||
for i, pc_id in enumerate(pointcloud_ids):
|
|
||||||
id_to_idx[pc_id].append(i)
|
|
||||||
|
|
||||||
sequence_point_cloud = []
|
|
||||||
sequence_point_cloud_idx = -np.ones((len(batch),))
|
|
||||||
for i, ind in enumerate(id_to_idx.values()):
|
|
||||||
sequence_point_cloud_idx[ind] = i
|
|
||||||
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
|
|
||||||
assert (sequence_point_cloud_idx >= 0).all()
|
|
||||||
|
|
||||||
override_fields = {
|
|
||||||
"sequence_point_cloud": sequence_point_cloud,
|
|
||||||
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
|
|
||||||
}
|
|
||||||
# note that the pre-collate value of sequence_point_cloud_idx is unused
|
|
||||||
|
|
||||||
collated = {}
|
|
||||||
for f in fields(elem):
|
|
||||||
list_values = override_fields.get(
|
|
||||||
f.name, [getattr(d, f.name) for d in batch]
|
|
||||||
)
|
|
||||||
collated[f.name] = (
|
|
||||||
cls.collate(list_values)
|
|
||||||
if all(list_value is not None for list_value in list_values)
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
return cls(**collated)
|
|
||||||
|
|
||||||
elif isinstance(elem, Pointclouds):
|
|
||||||
return join_pointclouds_as_batch(batch)
|
|
||||||
|
|
||||||
elif isinstance(elem, CamerasBase):
|
|
||||||
# TODO: don't store K; enforce working in NDC space
|
|
||||||
return join_cameras_as_batch(batch)
|
|
||||||
else:
|
|
||||||
return torch.utils.data._utils.collate.default_collate(batch)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False)
|
|
||||||
class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
|
||||||
"""
|
|
||||||
Base class to describe a dataset to be used with Implicitron.
|
|
||||||
|
|
||||||
The dataset is made up of frames, and the frames are grouped into sequences.
|
|
||||||
Each sequence has a name (a string).
|
|
||||||
(A sequence could be a video, or a set of images of one scene.)
|
|
||||||
|
|
||||||
This means they have a __getitem__ which returns an instance of a FrameData,
|
|
||||||
which will describe one frame in one sequence.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Maps sequence name to the sequence's global frame indices.
|
|
||||||
# It is used for the default implementations of some functions in this class.
|
|
||||||
# Implementations which override them are free to ignore this member.
|
|
||||||
_seq_to_idx: Dict[str, List[int]] = field(init=False)
|
|
||||||
|
|
||||||
def __len__(self) -> int:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def get_frame_numbers_and_timestamps(
|
|
||||||
self, idxs: Sequence[int]
|
|
||||||
) -> List[Tuple[int, float]]:
|
|
||||||
"""
|
|
||||||
If the sequences in the dataset are videos rather than
|
|
||||||
unordered views, then the dataset should override this method to
|
|
||||||
return the index and timestamp in their videos of the frames whose
|
|
||||||
indices are given in `idxs`. In addition,
|
|
||||||
the values in _seq_to_idx should be in ascending order.
|
|
||||||
If timestamps are absent, they should be replaced with a constant.
|
|
||||||
|
|
||||||
This is used for letting SceneBatchSampler identify consecutive
|
|
||||||
frames.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
idx: frame index in self
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple of
|
|
||||||
- frame index in video
|
|
||||||
- timestamp of frame in video
|
|
||||||
"""
|
|
||||||
raise ValueError("This dataset does not contain videos.")
|
|
||||||
|
|
||||||
def get_eval_batches(self) -> Optional[List[List[int]]]:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def sequence_names(self) -> Iterable[str]:
|
|
||||||
"""Returns an iterator over sequence names in the dataset."""
|
|
||||||
return self._seq_to_idx.keys()
|
|
||||||
|
|
||||||
def sequence_frames_in_order(
|
|
||||||
self, seq_name: str
|
|
||||||
) -> Iterator[Tuple[float, int, int]]:
|
|
||||||
"""Returns an iterator over the frame indices in a given sequence.
|
|
||||||
We attempt to first sort by timestamp (if they are available),
|
|
||||||
then by frame number.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seq_name: the name of the sequence.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
|
|
||||||
where `frame_no` is the index within the sequence, and
|
|
||||||
`dataset_idx` is the index within the dataset.
|
|
||||||
`None` timestamps are replaced with 0s.
|
|
||||||
"""
|
|
||||||
seq_frame_indices = self._seq_to_idx[seq_name]
|
|
||||||
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
|
||||||
|
|
||||||
yield from sorted(
|
|
||||||
[
|
|
||||||
(timestamp, frame_no, idx)
|
|
||||||
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
|
|
||||||
"""Same as `sequence_frames_in_order` but returns the iterator over
|
|
||||||
only dataset indices.
|
|
||||||
"""
|
|
||||||
for _, _, idx in self.sequence_frames_in_order(seq_name):
|
|
||||||
yield idx
|
|
||||||
|
|
||||||
|
|
||||||
class FrameAnnotsEntry(TypedDict):
|
class FrameAnnotsEntry(TypedDict):
|
||||||
subset: Optional[str]
|
subset: Optional[str]
|
||||||
frame_annotation: types.FrameAnnotation
|
frame_annotation: types.FrameAnnotation
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False)
|
@registry.register
|
||||||
class ImplicitronDataset(ImplicitronDatasetBase):
|
class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||||
"""
|
"""
|
||||||
A class for the Common Objects in 3D (CO3D) dataset.
|
A dataset with annotations in json files like the Common Objects in 3D
|
||||||
|
(CO3D) dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
frame_annotations_file: A zipped json file containing metadata of the
|
frame_annotations_file: A zipped json file containing metadata of the
|
||||||
@@ -361,16 +118,16 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
Type[types.FrameAnnotation]
|
Type[types.FrameAnnotation]
|
||||||
] = types.FrameAnnotation
|
] = types.FrameAnnotation
|
||||||
|
|
||||||
path_manager: Optional[PathManager] = None
|
path_manager: Any = None
|
||||||
frame_annotations_file: str = ""
|
frame_annotations_file: str = ""
|
||||||
sequence_annotations_file: str = ""
|
sequence_annotations_file: str = ""
|
||||||
subset_lists_file: str = ""
|
subset_lists_file: str = ""
|
||||||
subsets: Optional[List[str]] = None
|
subsets: Optional[List[str]] = None
|
||||||
limit_to: int = 0
|
limit_to: int = 0
|
||||||
limit_sequences_to: int = 0
|
limit_sequences_to: int = 0
|
||||||
pick_sequence: Sequence[str] = ()
|
pick_sequence: Tuple[str, ...] = ()
|
||||||
exclude_sequence: Sequence[str] = ()
|
exclude_sequence: Tuple[str, ...] = ()
|
||||||
limit_category_to: Sequence[int] = ()
|
limit_category_to: Tuple[int, ...] = ()
|
||||||
dataset_root: str = ""
|
dataset_root: str = ""
|
||||||
load_images: bool = True
|
load_images: bool = True
|
||||||
load_depths: bool = True
|
load_depths: bool = True
|
||||||
@@ -380,21 +137,21 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
max_points: int = 0
|
max_points: int = 0
|
||||||
mask_images: bool = False
|
mask_images: bool = False
|
||||||
mask_depths: bool = False
|
mask_depths: bool = False
|
||||||
image_height: Optional[int] = 256
|
image_height: Optional[int] = 800
|
||||||
image_width: Optional[int] = 256
|
image_width: Optional[int] = 800
|
||||||
box_crop: bool = False
|
box_crop: bool = True
|
||||||
box_crop_mask_thr: float = 0.4
|
box_crop_mask_thr: float = 0.4
|
||||||
box_crop_context: float = 1.0
|
box_crop_context: float = 0.3
|
||||||
remove_empty_masks: bool = False
|
remove_empty_masks: bool = True
|
||||||
n_frames_per_sequence: int = -1
|
n_frames_per_sequence: int = -1
|
||||||
seed: int = 0
|
seed: int = 0
|
||||||
sort_frames: bool = False
|
sort_frames: bool = False
|
||||||
eval_batches: Optional[List[List[int]]] = None
|
eval_batches: Any = None
|
||||||
frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
# frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
||||||
seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
# seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# pyre-fixme[16]: `ImplicitronDataset` has no attribute `subset_to_image_path`.
|
# pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
|
||||||
self.subset_to_image_path = None
|
self.subset_to_image_path = None
|
||||||
self._load_frames()
|
self._load_frames()
|
||||||
self._load_sequences()
|
self._load_sequences()
|
||||||
@@ -404,54 +161,174 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
self._filter_db() # also computes sequence indices
|
self._filter_db() # also computes sequence indices
|
||||||
logger.info(str(self))
|
logger.info(str(self))
|
||||||
|
|
||||||
|
def is_filtered(self):
|
||||||
|
"""
|
||||||
|
Returns `True` in case the dataset has been filtered and thus some frame annotations
|
||||||
|
stored on the disk might be missing in the dataset object.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
is_filtered: `True` if the dataset has been filtered, else `False`.
|
||||||
|
"""
|
||||||
|
return (
|
||||||
|
self.remove_empty_masks
|
||||||
|
or self.limit_to > 0
|
||||||
|
or self.limit_sequences_to > 0
|
||||||
|
or len(self.pick_sequence) > 0
|
||||||
|
or len(self.exclude_sequence) > 0
|
||||||
|
or len(self.limit_category_to) > 0
|
||||||
|
or self.n_frames_per_sequence > 0
|
||||||
|
)
|
||||||
|
|
||||||
def seq_frame_index_to_dataset_index(
|
def seq_frame_index_to_dataset_index(
|
||||||
self,
|
self,
|
||||||
seq_frame_index: Union[
|
seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
|
||||||
List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
|
allow_missing_indices: bool = False,
|
||||||
],
|
remove_missing_indices: bool = False,
|
||||||
) -> List[List[int]]:
|
) -> List[List[Union[Optional[int], int]]]:
|
||||||
"""
|
"""
|
||||||
Obtain indices into the dataset object given a list of frames specified as
|
Obtain indices into the dataset object given a list of frame ids.
|
||||||
`seq_frame_index = List[List[Tuple[sequence_name:str, frame_number:int]]]`.
|
|
||||||
|
Args:
|
||||||
|
seq_frame_index: The list of frame ids specified as
|
||||||
|
`List[List[Tuple[sequence_name:str, frame_number:int]]]`. Optionally,
|
||||||
|
Image paths relative to the dataset_root can be stored specified as well:
|
||||||
|
`List[List[Tuple[sequence_name:str, frame_number:int, image_path:str]]]`
|
||||||
|
allow_missing_indices: If `False`, throws an IndexError upon reaching the first
|
||||||
|
entry from `seq_frame_index` which is missing in the dataset.
|
||||||
|
Otherwise, depending on `remove_missing_indices`, either returns `None`
|
||||||
|
in place of missing entries or removes the indices of missing entries.
|
||||||
|
remove_missing_indices: Active when `allow_missing_indices=True`.
|
||||||
|
If `False`, returns `None` in place of `seq_frame_index` entries that
|
||||||
|
are not present in the dataset.
|
||||||
|
If `True` removes missing indices from the returned indices.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dataset_idx: Indices of dataset entries corresponding to`seq_frame_index`.
|
||||||
"""
|
"""
|
||||||
# TODO: check the frame numbers are unique
|
|
||||||
_dataset_seq_frame_n_index = {
|
_dataset_seq_frame_n_index = {
|
||||||
seq: {
|
seq: {
|
||||||
|
# pyre-ignore[16]
|
||||||
self.frame_annots[idx]["frame_annotation"].frame_number: idx
|
self.frame_annots[idx]["frame_annotation"].frame_number: idx
|
||||||
for idx in seq_idx
|
for idx in seq_idx
|
||||||
}
|
}
|
||||||
|
# pyre-ignore[16]
|
||||||
for seq, seq_idx in self._seq_to_idx.items()
|
for seq, seq_idx in self._seq_to_idx.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def _get_batch_idx(seq_name, frame_no, path=None) -> int:
|
def _get_dataset_idx(
|
||||||
idx = _dataset_seq_frame_n_index[seq_name][frame_no]
|
seq_name: str, frame_no: int, path: Optional[str] = None
|
||||||
|
) -> Optional[int]:
|
||||||
|
idx_seq = _dataset_seq_frame_n_index.get(seq_name, None)
|
||||||
|
idx = idx_seq.get(frame_no, None) if idx_seq is not None else None
|
||||||
|
if idx is None:
|
||||||
|
msg = (
|
||||||
|
f"sequence_name={seq_name} / frame_number={frame_no}"
|
||||||
|
" not in the dataset!"
|
||||||
|
)
|
||||||
|
if not allow_missing_indices:
|
||||||
|
raise IndexError(msg)
|
||||||
|
warnings.warn(msg)
|
||||||
|
return idx
|
||||||
if path is not None:
|
if path is not None:
|
||||||
# Check that the loaded frame path is consistent
|
# Check that the loaded frame path is consistent
|
||||||
# with the one stored in self.frame_annots.
|
# with the one stored in self.frame_annots.
|
||||||
assert os.path.normpath(
|
assert os.path.normpath(
|
||||||
|
# pyre-ignore[16]
|
||||||
self.frame_annots[idx]["frame_annotation"].image.path
|
self.frame_annots[idx]["frame_annotation"].image.path
|
||||||
) == os.path.normpath(
|
) == os.path.normpath(
|
||||||
path
|
path
|
||||||
), f"Inconsistent batch {seq_name, frame_no, path}."
|
), f"Inconsistent frame indices {seq_name, frame_no, path}."
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
batches_idx = [[_get_batch_idx(*b) for b in batch] for batch in seq_frame_index]
|
dataset_idx = [
|
||||||
return batches_idx
|
[_get_dataset_idx(*b) for b in batch] # pyre-ignore [6]
|
||||||
|
for batch in seq_frame_index
|
||||||
|
]
|
||||||
|
|
||||||
|
if allow_missing_indices and remove_missing_indices:
|
||||||
|
# remove all None indices, and also batches with only None entries
|
||||||
|
valid_dataset_idx = [
|
||||||
|
[b for b in batch if b is not None] for batch in dataset_idx
|
||||||
|
]
|
||||||
|
return [ # pyre-ignore[7]
|
||||||
|
batch for batch in valid_dataset_idx if len(batch) > 0
|
||||||
|
]
|
||||||
|
|
||||||
|
return dataset_idx
|
||||||
|
|
||||||
|
def subset_from_frame_index(
|
||||||
|
self,
|
||||||
|
frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
|
||||||
|
allow_missing_indices: bool = True,
|
||||||
|
) -> "JsonIndexDataset":
|
||||||
|
# Get the indices into the frame annots.
|
||||||
|
dataset_indices = self.seq_frame_index_to_dataset_index(
|
||||||
|
[frame_index],
|
||||||
|
allow_missing_indices=self.is_filtered() and allow_missing_indices,
|
||||||
|
)[0]
|
||||||
|
valid_dataset_indices = [i for i in dataset_indices if i is not None]
|
||||||
|
|
||||||
|
# Deep copy the whole dataset except frame_annots, which are large so we
|
||||||
|
# deep copy only the requested subset of frame_annots.
|
||||||
|
memo = {id(self.frame_annots): None} # pyre-ignore[16]
|
||||||
|
dataset_new = copy.deepcopy(self, memo)
|
||||||
|
dataset_new.frame_annots = copy.deepcopy(
|
||||||
|
[self.frame_annots[i] for i in valid_dataset_indices]
|
||||||
|
)
|
||||||
|
|
||||||
|
# This will kill all unneeded sequence annotations.
|
||||||
|
dataset_new._invalidate_indexes(filter_seq_annots=True)
|
||||||
|
|
||||||
|
# Finally annotate the frame annotations with the name of the subset
|
||||||
|
# stored in meta.
|
||||||
|
for frame_annot in dataset_new.frame_annots:
|
||||||
|
frame_annotation = frame_annot["frame_annotation"]
|
||||||
|
if frame_annotation.meta is not None:
|
||||||
|
frame_annot["subset"] = frame_annotation.meta.get("frame_type", None)
|
||||||
|
|
||||||
|
# A sanity check - this will crash in case some entries from frame_index are missing
|
||||||
|
# in dataset_new.
|
||||||
|
valid_frame_index = [
|
||||||
|
fi for fi, di in zip(frame_index, dataset_indices) if di is not None
|
||||||
|
]
|
||||||
|
dataset_new.seq_frame_index_to_dataset_index(
|
||||||
|
[valid_frame_index], allow_missing_indices=False
|
||||||
|
)
|
||||||
|
|
||||||
|
return dataset_new
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return f"ImplicitronDataset #frames={len(self.frame_annots)}"
|
# pyre-ignore[16]
|
||||||
|
return f"JsonIndexDataset #frames={len(self.frame_annots)}"
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
# pyre-ignore[16]
|
||||||
return len(self.frame_annots)
|
return len(self.frame_annots)
|
||||||
|
|
||||||
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
|
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
|
||||||
return entry["subset"]
|
return entry["subset"]
|
||||||
|
|
||||||
|
def get_all_train_cameras(self) -> CamerasBase:
|
||||||
|
"""
|
||||||
|
Returns the cameras corresponding to all the known frames.
|
||||||
|
"""
|
||||||
|
cameras = []
|
||||||
|
# pyre-ignore[16]
|
||||||
|
for frame_idx, frame_annot in enumerate(self.frame_annots):
|
||||||
|
frame_type = self._get_frame_type(frame_annot)
|
||||||
|
if frame_type is None:
|
||||||
|
raise ValueError("subsets not loaded")
|
||||||
|
if is_known_frame_scalar(frame_type):
|
||||||
|
cameras.append(self[frame_idx].camera)
|
||||||
|
return join_cameras_as_batch(cameras)
|
||||||
|
|
||||||
def __getitem__(self, index) -> FrameData:
|
def __getitem__(self, index) -> FrameData:
|
||||||
|
# pyre-ignore[16]
|
||||||
if index >= len(self.frame_annots):
|
if index >= len(self.frame_annots):
|
||||||
raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
|
raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
|
||||||
|
|
||||||
entry = self.frame_annots[index]["frame_annotation"]
|
entry = self.frame_annots[index]["frame_annotation"]
|
||||||
|
# pyre-ignore[16]
|
||||||
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
|
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
|
||||||
frame_data = FrameData(
|
frame_data = FrameData(
|
||||||
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
|
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
|
||||||
@@ -477,6 +354,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
frame_data.mask_path,
|
frame_data.mask_path,
|
||||||
frame_data.bbox_xywh,
|
frame_data.bbox_xywh,
|
||||||
clamp_bbox_xyxy,
|
clamp_bbox_xyxy,
|
||||||
|
frame_data.crop_bbox_xywh,
|
||||||
) = self._load_crop_fg_probability(entry)
|
) = self._load_crop_fg_probability(entry)
|
||||||
|
|
||||||
scale = 1.0
|
scale = 1.0
|
||||||
@@ -524,13 +402,14 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
Optional[str],
|
Optional[str],
|
||||||
Optional[torch.Tensor],
|
Optional[torch.Tensor],
|
||||||
Optional[torch.Tensor],
|
Optional[torch.Tensor],
|
||||||
|
Optional[torch.Tensor],
|
||||||
]:
|
]:
|
||||||
fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy = (
|
fg_probability = None
|
||||||
None,
|
full_path = None
|
||||||
None,
|
bbox_xywh = None
|
||||||
None,
|
clamp_bbox_xyxy = None
|
||||||
None,
|
crop_box_xywh = None
|
||||||
)
|
|
||||||
if (self.load_masks or self.box_crop) and entry.mask is not None:
|
if (self.load_masks or self.box_crop) and entry.mask is not None:
|
||||||
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
||||||
mask = _load_mask(self._local_path(full_path))
|
mask = _load_mask(self._local_path(full_path))
|
||||||
@@ -543,11 +422,21 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
|
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
|
||||||
|
|
||||||
if self.box_crop:
|
if self.box_crop:
|
||||||
clamp_bbox_xyxy = _get_clamp_bbox(bbox_xywh, self.box_crop_context)
|
clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
|
||||||
|
_get_clamp_bbox(
|
||||||
|
bbox_xywh,
|
||||||
|
image_path=entry.image.path,
|
||||||
|
box_crop_context=self.box_crop_context,
|
||||||
|
),
|
||||||
|
image_size_hw=tuple(mask.shape[-2:]),
|
||||||
|
)
|
||||||
|
crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
|
||||||
|
|
||||||
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
|
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
|
||||||
|
|
||||||
fg_probability, _, _ = self._resize_image(mask, mode="nearest")
|
fg_probability, _, _ = self._resize_image(mask, mode="nearest")
|
||||||
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy
|
|
||||||
|
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
|
||||||
|
|
||||||
def _load_crop_images(
|
def _load_crop_images(
|
||||||
self,
|
self,
|
||||||
@@ -686,6 +575,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
)
|
)
|
||||||
if not frame_annots_list:
|
if not frame_annots_list:
|
||||||
raise ValueError("Empty dataset!")
|
raise ValueError("Empty dataset!")
|
||||||
|
# pyre-ignore[16]
|
||||||
self.frame_annots = [
|
self.frame_annots = [
|
||||||
FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
|
FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
|
||||||
]
|
]
|
||||||
@@ -697,6 +587,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
|
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
|
||||||
if not seq_annots:
|
if not seq_annots:
|
||||||
raise ValueError("Empty sequences file!")
|
raise ValueError("Empty sequences file!")
|
||||||
|
# pyre-ignore[16]
|
||||||
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
|
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
|
||||||
|
|
||||||
def _load_subset_lists(self) -> None:
|
def _load_subset_lists(self) -> None:
|
||||||
@@ -712,7 +603,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
for subset, frames in subset_to_seq_frame.items()
|
for subset, frames in subset_to_seq_frame.items()
|
||||||
for _, _, path in frames
|
for _, _, path in frames
|
||||||
}
|
}
|
||||||
|
# pyre-ignore[16]
|
||||||
for frame in self.frame_annots:
|
for frame in self.frame_annots:
|
||||||
frame["subset"] = frame_path_to_subset.get(
|
frame["subset"] = frame_path_to_subset.get(
|
||||||
frame["frame_annotation"].image.path, None
|
frame["frame_annotation"].image.path, None
|
||||||
@@ -725,6 +616,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
|
|
||||||
def _sort_frames(self) -> None:
|
def _sort_frames(self) -> None:
|
||||||
# Sort frames to have them grouped by sequence, ordered by timestamp
|
# Sort frames to have them grouped by sequence, ordered by timestamp
|
||||||
|
# pyre-ignore[16]
|
||||||
self.frame_annots = sorted(
|
self.frame_annots = sorted(
|
||||||
self.frame_annots,
|
self.frame_annots,
|
||||||
key=lambda f: (
|
key=lambda f: (
|
||||||
@@ -736,6 +628,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
def _filter_db(self) -> None:
|
def _filter_db(self) -> None:
|
||||||
if self.remove_empty_masks:
|
if self.remove_empty_masks:
|
||||||
logger.info("Removing images with empty masks.")
|
logger.info("Removing images with empty masks.")
|
||||||
|
# pyre-ignore[16]
|
||||||
old_len = len(self.frame_annots)
|
old_len = len(self.frame_annots)
|
||||||
|
|
||||||
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
|
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
|
||||||
@@ -776,6 +669,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
|
|
||||||
if len(self.limit_category_to) > 0:
|
if len(self.limit_category_to) > 0:
|
||||||
logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
|
logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
|
||||||
|
# pyre-ignore[16]
|
||||||
self.seq_annots = {
|
self.seq_annots = {
|
||||||
name: entry
|
name: entry
|
||||||
for name, entry in self.seq_annots.items()
|
for name, entry in self.seq_annots.items()
|
||||||
@@ -813,6 +707,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
if self.n_frames_per_sequence > 0:
|
if self.n_frames_per_sequence > 0:
|
||||||
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
|
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
|
||||||
keep_idx = []
|
keep_idx = []
|
||||||
|
# pyre-ignore[16]
|
||||||
for seq, seq_indices in self._seq_to_idx.items():
|
for seq, seq_indices in self._seq_to_idx.items():
|
||||||
# infer the seed from the sequence name, this is reproducible
|
# infer the seed from the sequence name, this is reproducible
|
||||||
# and makes the selection differ for different sequences
|
# and makes the selection differ for different sequences
|
||||||
@@ -842,14 +737,20 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
self._invalidate_seq_to_idx()
|
self._invalidate_seq_to_idx()
|
||||||
|
|
||||||
if filter_seq_annots:
|
if filter_seq_annots:
|
||||||
|
# pyre-ignore[16]
|
||||||
self.seq_annots = {
|
self.seq_annots = {
|
||||||
k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx
|
k: v
|
||||||
|
for k, v in self.seq_annots.items()
|
||||||
|
# pyre-ignore[16]
|
||||||
|
if k in self._seq_to_idx
|
||||||
}
|
}
|
||||||
|
|
||||||
def _invalidate_seq_to_idx(self) -> None:
|
def _invalidate_seq_to_idx(self) -> None:
|
||||||
seq_to_idx = defaultdict(list)
|
seq_to_idx = defaultdict(list)
|
||||||
|
# pyre-ignore[16]
|
||||||
for idx, entry in enumerate(self.frame_annots):
|
for idx, entry in enumerate(self.frame_annots):
|
||||||
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
|
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
|
||||||
|
# pyre-ignore[16]
|
||||||
self._seq_to_idx = seq_to_idx
|
self._seq_to_idx = seq_to_idx
|
||||||
|
|
||||||
def _resize_image(
|
def _resize_image(
|
||||||
@@ -867,16 +768,18 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
)
|
)
|
||||||
imre = torch.nn.functional.interpolate(
|
imre = torch.nn.functional.interpolate(
|
||||||
torch.from_numpy(image)[None],
|
torch.from_numpy(image)[None],
|
||||||
# pyre-ignore[6]
|
|
||||||
scale_factor=minscale,
|
scale_factor=minscale,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
align_corners=False if mode == "bilinear" else None,
|
align_corners=False if mode == "bilinear" else None,
|
||||||
recompute_scale_factor=True,
|
recompute_scale_factor=True,
|
||||||
)[0]
|
)[0]
|
||||||
|
# pyre-fixme[19]: Expected 1 positional argument.
|
||||||
imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
|
imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
|
||||||
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
|
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
|
||||||
|
# pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
|
||||||
|
# pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
|
||||||
mask = torch.zeros(1, self.image_height, self.image_width)
|
mask = torch.zeros(1, self.image_height, self.image_width)
|
||||||
mask[:, 0 : imre.shape[1] - 1, 0 : imre.shape[2] - 1] = 1.0
|
mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
|
||||||
return imre_, minscale, mask
|
return imre_, minscale, mask
|
||||||
|
|
||||||
def _local_path(self, path: str) -> str:
|
def _local_path(self, path: str) -> str:
|
||||||
@@ -889,6 +792,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
|||||||
) -> List[Tuple[int, float]]:
|
) -> List[Tuple[int, float]]:
|
||||||
out: List[Tuple[int, float]] = []
|
out: List[Tuple[int, float]] = []
|
||||||
for idx in idxs:
|
for idx in idxs:
|
||||||
|
# pyre-ignore[16]
|
||||||
frame_annotation = self.frame_annots[idx]["frame_annotation"]
|
frame_annotation = self.frame_annots[idx]["frame_annotation"]
|
||||||
out.append(
|
out.append(
|
||||||
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
|
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
|
||||||
@@ -929,7 +833,7 @@ def _load_1bit_png_mask(file: str) -> np.ndarray:
|
|||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
def _load_depth_mask(path) -> np.ndarray:
|
def _load_depth_mask(path: str) -> np.ndarray:
|
||||||
if not path.lower().endswith(".png"):
|
if not path.lower().endswith(".png"):
|
||||||
raise ValueError('unsupported depth mask file name "%s"' % path)
|
raise ValueError('unsupported depth mask file name "%s"' % path)
|
||||||
m = _load_1bit_png_mask(path)
|
m = _load_1bit_png_mask(path)
|
||||||
@@ -954,7 +858,7 @@ def _load_mask(path) -> np.ndarray:
|
|||||||
|
|
||||||
def _get_1d_bounds(arr) -> Tuple[int, int]:
|
def _get_1d_bounds(arr) -> Tuple[int, int]:
|
||||||
nz = np.flatnonzero(arr)
|
nz = np.flatnonzero(arr)
|
||||||
return nz[0], nz[-1]
|
return nz[0], nz[-1] + 1
|
||||||
|
|
||||||
|
|
||||||
def _get_bbox_from_mask(
|
def _get_bbox_from_mask(
|
||||||
@@ -975,11 +879,15 @@ def _get_bbox_from_mask(
|
|||||||
|
|
||||||
|
|
||||||
def _get_clamp_bbox(
|
def _get_clamp_bbox(
|
||||||
bbox: torch.Tensor, box_crop_context: float = 0.0, impath: str = ""
|
bbox: torch.Tensor,
|
||||||
|
box_crop_context: float = 0.0,
|
||||||
|
image_path: str = "",
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# box_crop_context: rate of expansion for bbox
|
# box_crop_context: rate of expansion for bbox
|
||||||
# returns possibly expanded bbox xyxy as float
|
# returns possibly expanded bbox xyxy as float
|
||||||
|
|
||||||
|
bbox = bbox.clone() # do not edit bbox in place
|
||||||
|
|
||||||
# increase box size
|
# increase box size
|
||||||
if box_crop_context > 0.0:
|
if box_crop_context > 0.0:
|
||||||
c = box_crop_context
|
c = box_crop_context
|
||||||
@@ -991,27 +899,38 @@ def _get_clamp_bbox(
|
|||||||
|
|
||||||
if (bbox[2:] <= 1.0).any():
|
if (bbox[2:] <= 1.0).any():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"squashed image {impath}!! The bounding box contains no pixels."
|
f"squashed image {image_path}!! The bounding box contains no pixels."
|
||||||
)
|
)
|
||||||
|
|
||||||
bbox[2:] = torch.clamp(bbox[2:], 2)
|
bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes
|
||||||
bbox[2:] += bbox[0:2] + 1 # convert to [xmin, ymin, xmax, ymax]
|
bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
|
||||||
# +1 because upper bound is not inclusive
|
|
||||||
|
|
||||||
return bbox
|
return bbox_xyxy
|
||||||
|
|
||||||
|
|
||||||
def _crop_around_box(tensor, bbox, impath: str = ""):
|
def _crop_around_box(tensor, bbox, impath: str = ""):
|
||||||
# bbox is xyxy, where the upper bound is corrected with +1
|
# bbox is xyxy, where the upper bound is corrected with +1
|
||||||
bbox[[0, 2]] = torch.clamp(bbox[[0, 2]], 0.0, tensor.shape[-1])
|
bbox = _clamp_box_to_image_bounds_and_round(
|
||||||
bbox[[1, 3]] = torch.clamp(bbox[[1, 3]], 0.0, tensor.shape[-2])
|
bbox,
|
||||||
bbox = bbox.round().long()
|
image_size_hw=tensor.shape[-2:],
|
||||||
|
)
|
||||||
tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
|
tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
|
||||||
assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
|
assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
|
||||||
|
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
|
def _clamp_box_to_image_bounds_and_round(
|
||||||
|
bbox_xyxy: torch.Tensor,
|
||||||
|
image_size_hw: Tuple[int, int],
|
||||||
|
) -> torch.LongTensor:
|
||||||
|
bbox_xyxy = bbox_xyxy.clone()
|
||||||
|
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
|
||||||
|
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
|
||||||
|
if not isinstance(bbox_xyxy, torch.LongTensor):
|
||||||
|
bbox_xyxy = bbox_xyxy.round().long()
|
||||||
|
return bbox_xyxy # pyre-ignore [7]
|
||||||
|
|
||||||
|
|
||||||
def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
|
def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
|
||||||
assert bbox is not None
|
assert bbox is not None
|
||||||
assert np.prod(orig_res) > 1e-8
|
assert np.prod(orig_res) > 1e-8
|
||||||
@@ -1020,6 +939,22 @@ def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
|
|||||||
return bbox * rel_size
|
return bbox * rel_size
|
||||||
|
|
||||||
|
|
||||||
|
def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
|
||||||
|
wh = xyxy[2:] - xyxy[:2]
|
||||||
|
xywh = torch.cat([xyxy[:2], wh])
|
||||||
|
return xywh
|
||||||
|
|
||||||
|
|
||||||
|
def _bbox_xywh_to_xyxy(
|
||||||
|
xywh: torch.Tensor, clamp_size: Optional[int] = None
|
||||||
|
) -> torch.Tensor:
|
||||||
|
xyxy = xywh.clone()
|
||||||
|
if clamp_size is not None:
|
||||||
|
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
|
||||||
|
xyxy[2:] += xyxy[:2]
|
||||||
|
return xyxy
|
||||||
|
|
||||||
|
|
||||||
def _safe_as_tensor(data, dtype):
|
def _safe_as_tensor(data, dtype):
|
||||||
if data is None:
|
if data is None:
|
||||||
return None
|
return None
|
||||||
326
pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py
Normal file
326
pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py
Normal file
@@ -0,0 +1,326 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from typing import Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
from omegaconf import DictConfig, open_dict
|
||||||
|
from pytorch3d.implicitron.tools.config import (
|
||||||
|
expand_args_fields,
|
||||||
|
registry,
|
||||||
|
run_auto_creation,
|
||||||
|
)
|
||||||
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
|
from .dataset_map_provider import (
|
||||||
|
DatasetMap,
|
||||||
|
DatasetMapProviderBase,
|
||||||
|
PathManagerFactory,
|
||||||
|
Task,
|
||||||
|
)
|
||||||
|
from .json_index_dataset import JsonIndexDataset
|
||||||
|
|
||||||
|
from .utils import (
|
||||||
|
DATASET_TYPE_KNOWN,
|
||||||
|
DATASET_TYPE_TEST,
|
||||||
|
DATASET_TYPE_TRAIN,
|
||||||
|
DATASET_TYPE_UNKNOWN,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
CO3D_CATEGORIES: List[str] = list(reversed([
|
||||||
|
"baseballbat", "banana", "bicycle", "microwave", "tv",
|
||||||
|
"cellphone", "toilet", "hairdryer", "couch", "kite", "pizza",
|
||||||
|
"umbrella", "wineglass", "laptop",
|
||||||
|
"hotdog", "stopsign", "frisbee", "baseballglove",
|
||||||
|
"cup", "parkingmeter", "backpack", "toyplane", "toybus",
|
||||||
|
"handbag", "chair", "keyboard", "car", "motorcycle",
|
||||||
|
"carrot", "bottle", "sandwich", "remote", "bowl", "skateboard",
|
||||||
|
"toaster", "mouse", "toytrain", "book", "toytruck",
|
||||||
|
"orange", "broccoli", "plant", "teddybear",
|
||||||
|
"suitcase", "bench", "ball", "cake",
|
||||||
|
"vase", "hydrant", "apple", "donut",
|
||||||
|
]))
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
|
||||||
|
|
||||||
|
# _NEED_CONTROL is a list of those elements of JsonIndexDataset which
|
||||||
|
# are not directly specified for it in the config but come from the
|
||||||
|
# DatasetMapProvider.
|
||||||
|
_NEED_CONTROL: Tuple[str, ...] = (
|
||||||
|
"dataset_root",
|
||||||
|
"eval_batches",
|
||||||
|
"n_frames_per_sequence",
|
||||||
|
"path_manager",
|
||||||
|
"pick_sequence",
|
||||||
|
"subsets",
|
||||||
|
"frame_annotations_file",
|
||||||
|
"sequence_annotations_file",
|
||||||
|
"subset_lists_file",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||||
|
"""
|
||||||
|
Generates the training / validation and testing dataset objects for
|
||||||
|
a dataset laid out on disk like Co3D, with annotations in json files.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: The object category of the dataset.
|
||||||
|
task_str: "multisequence" or "singlesequence".
|
||||||
|
dataset_root: The root folder of the dataset.
|
||||||
|
n_frames_per_sequence: Randomly sample #n_frames_per_sequence frames
|
||||||
|
in each sequence.
|
||||||
|
test_on_train: Construct validation and test datasets from
|
||||||
|
the training subset.
|
||||||
|
restrict_sequence_name: Restrict the dataset sequences to the ones
|
||||||
|
present in the given list of names.
|
||||||
|
test_restrict_sequence_id: The ID of the loaded sequence.
|
||||||
|
Active for task_str='singlesequence'.
|
||||||
|
assert_single_seq: Assert that only frames from a single sequence
|
||||||
|
are present in all generated datasets.
|
||||||
|
only_test_set: Load only the test set.
|
||||||
|
dataset_class_type: name of class (JsonIndexDataset or a subclass)
|
||||||
|
to use for the dataset.
|
||||||
|
dataset_X_args (e.g. dataset_JsonIndexDataset_args): arguments passed
|
||||||
|
to all the dataset constructors.
|
||||||
|
path_manager_factory: (Optional) An object that generates an instance of
|
||||||
|
PathManager that can translate provided file paths.
|
||||||
|
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
category: str
|
||||||
|
task_str: str = "singlesequence"
|
||||||
|
dataset_root: str = _CO3D_DATASET_ROOT
|
||||||
|
n_frames_per_sequence: int = -1
|
||||||
|
test_on_train: bool = False
|
||||||
|
restrict_sequence_name: Tuple[str, ...] = ()
|
||||||
|
test_restrict_sequence_id: int = -1
|
||||||
|
assert_single_seq: bool = False
|
||||||
|
only_test_set: bool = False
|
||||||
|
dataset: JsonIndexDataset
|
||||||
|
dataset_class_type: str = "JsonIndexDataset"
|
||||||
|
path_manager_factory: PathManagerFactory
|
||||||
|
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
|
||||||
|
"""
|
||||||
|
Called by get_default_args(JsonIndexDatasetMapProvider) to
|
||||||
|
not expose certain fields of each dataset class.
|
||||||
|
"""
|
||||||
|
with open_dict(args):
|
||||||
|
for key in _NEED_CONTROL:
|
||||||
|
del args[key]
|
||||||
|
|
||||||
|
def create_dataset(self):
|
||||||
|
"""
|
||||||
|
Prevent the member named dataset from being created.
|
||||||
|
"""
|
||||||
|
return
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__init__()
|
||||||
|
run_auto_creation(self)
|
||||||
|
if self.only_test_set and self.test_on_train:
|
||||||
|
raise ValueError("Cannot have only_test_set and test_on_train")
|
||||||
|
|
||||||
|
path_manager = self.path_manager_factory.get()
|
||||||
|
|
||||||
|
# TODO:
|
||||||
|
# - implement loading multiple categories
|
||||||
|
|
||||||
|
frame_file = os.path.join(
|
||||||
|
self.dataset_root, self.category, "frame_annotations.jgz"
|
||||||
|
)
|
||||||
|
sequence_file = os.path.join(
|
||||||
|
self.dataset_root, self.category, "sequence_annotations.jgz"
|
||||||
|
)
|
||||||
|
subset_lists_file = os.path.join(
|
||||||
|
self.dataset_root, self.category, "set_lists.json"
|
||||||
|
)
|
||||||
|
common_kwargs = {
|
||||||
|
"dataset_root": self.dataset_root,
|
||||||
|
"path_manager": path_manager,
|
||||||
|
"frame_annotations_file": frame_file,
|
||||||
|
"sequence_annotations_file": sequence_file,
|
||||||
|
"subset_lists_file": subset_lists_file,
|
||||||
|
**getattr(self, f"dataset_{self.dataset_class_type}_args"),
|
||||||
|
}
|
||||||
|
|
||||||
|
# This maps the common names of the dataset subsets ("train"/"val"/"test")
|
||||||
|
# to the names of the subsets in the CO3D dataset.
|
||||||
|
set_names_mapping = _get_co3d_set_names_mapping(
|
||||||
|
self.get_task(),
|
||||||
|
self.test_on_train,
|
||||||
|
self.only_test_set,
|
||||||
|
)
|
||||||
|
|
||||||
|
# load the evaluation batches
|
||||||
|
batch_indices_path = os.path.join(
|
||||||
|
self.dataset_root,
|
||||||
|
self.category,
|
||||||
|
f"eval_batches_{self.task_str}.json",
|
||||||
|
)
|
||||||
|
if path_manager is not None:
|
||||||
|
batch_indices_path = path_manager.get_local_path(batch_indices_path)
|
||||||
|
if not os.path.isfile(batch_indices_path):
|
||||||
|
# The batch indices file does not exist.
|
||||||
|
# Most probably the user has not specified the root folder.
|
||||||
|
raise ValueError(
|
||||||
|
f"Looking for batch indices in {batch_indices_path}. "
|
||||||
|
+ "Please specify a correct dataset_root folder."
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(batch_indices_path, "r") as f:
|
||||||
|
eval_batch_index = json.load(f)
|
||||||
|
restrict_sequence_name = self.restrict_sequence_name
|
||||||
|
|
||||||
|
if self.get_task() == Task.SINGLE_SEQUENCE:
|
||||||
|
if (
|
||||||
|
self.test_restrict_sequence_id is None
|
||||||
|
or self.test_restrict_sequence_id < 0
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Please specify an integer id 'test_restrict_sequence_id'"
|
||||||
|
+ " of the sequence considered for 'singlesequence'"
|
||||||
|
+ " training and evaluation."
|
||||||
|
)
|
||||||
|
if len(self.restrict_sequence_name) > 0:
|
||||||
|
raise ValueError(
|
||||||
|
"For the 'singlesequence' task, the restrict_sequence_name has"
|
||||||
|
" to be unset while test_restrict_sequence_id has to be set to an"
|
||||||
|
" integer defining the order of the evaluation sequence."
|
||||||
|
)
|
||||||
|
# a sort-stable set() equivalent:
|
||||||
|
eval_batches_sequence_names = list(
|
||||||
|
{b[0][0]: None for b in eval_batch_index}.keys()
|
||||||
|
)
|
||||||
|
eval_sequence_name = eval_batches_sequence_names[
|
||||||
|
self.test_restrict_sequence_id
|
||||||
|
]
|
||||||
|
eval_batch_index = [
|
||||||
|
b for b in eval_batch_index if b[0][0] == eval_sequence_name
|
||||||
|
]
|
||||||
|
# overwrite the restrict_sequence_name
|
||||||
|
restrict_sequence_name = [eval_sequence_name]
|
||||||
|
|
||||||
|
dataset_type: Type[JsonIndexDataset] = registry.get(
|
||||||
|
JsonIndexDataset, self.dataset_class_type
|
||||||
|
)
|
||||||
|
expand_args_fields(dataset_type)
|
||||||
|
train_dataset = None
|
||||||
|
if not self.only_test_set:
|
||||||
|
train_dataset = dataset_type(
|
||||||
|
n_frames_per_sequence=self.n_frames_per_sequence,
|
||||||
|
subsets=set_names_mapping["train"],
|
||||||
|
pick_sequence=restrict_sequence_name,
|
||||||
|
**common_kwargs,
|
||||||
|
)
|
||||||
|
if self.test_on_train:
|
||||||
|
assert train_dataset is not None
|
||||||
|
val_dataset = test_dataset = train_dataset
|
||||||
|
else:
|
||||||
|
val_dataset = dataset_type(
|
||||||
|
n_frames_per_sequence=-1,
|
||||||
|
subsets=set_names_mapping["val"],
|
||||||
|
pick_sequence=restrict_sequence_name,
|
||||||
|
**common_kwargs,
|
||||||
|
)
|
||||||
|
test_dataset = dataset_type(
|
||||||
|
n_frames_per_sequence=-1,
|
||||||
|
subsets=set_names_mapping["test"],
|
||||||
|
pick_sequence=restrict_sequence_name,
|
||||||
|
**common_kwargs,
|
||||||
|
)
|
||||||
|
if len(restrict_sequence_name) > 0:
|
||||||
|
eval_batch_index = [
|
||||||
|
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
|
||||||
|
]
|
||||||
|
test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
|
||||||
|
eval_batch_index
|
||||||
|
)
|
||||||
|
dataset_map = DatasetMap(
|
||||||
|
train=train_dataset, val=val_dataset, test=test_dataset
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.assert_single_seq:
|
||||||
|
# check there's only one sequence in all datasets
|
||||||
|
sequence_names = {
|
||||||
|
sequence_name
|
||||||
|
for dset in dataset_map.iter_datasets()
|
||||||
|
for sequence_name in dset.sequence_names()
|
||||||
|
}
|
||||||
|
if len(sequence_names) > 1:
|
||||||
|
raise ValueError("Multiple sequences loaded but expected one")
|
||||||
|
|
||||||
|
self.dataset_map = dataset_map
|
||||||
|
|
||||||
|
def get_dataset_map(self) -> DatasetMap:
|
||||||
|
# pyre-ignore[16]
|
||||||
|
return self.dataset_map
|
||||||
|
|
||||||
|
def get_task(self) -> Task:
|
||||||
|
return Task(self.task_str)
|
||||||
|
|
||||||
|
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
|
if Task(self.task_str) == Task.MULTI_SEQUENCE:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# pyre-ignore[16]
|
||||||
|
train_dataset = self.dataset_map.train
|
||||||
|
assert isinstance(train_dataset, JsonIndexDataset)
|
||||||
|
return train_dataset.get_all_train_cameras()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_co3d_set_names_mapping(
|
||||||
|
task: Task,
|
||||||
|
test_on_train: bool,
|
||||||
|
only_test: bool,
|
||||||
|
) -> Dict[str, List[str]]:
|
||||||
|
"""
|
||||||
|
Returns the mapping of the common dataset subset names ("train"/"val"/"test")
|
||||||
|
to the names of the corresponding subsets in the CO3D dataset
|
||||||
|
("test_known"/"test_unseen"/"train_known"/"train_unseen").
|
||||||
|
|
||||||
|
The keys returned will be
|
||||||
|
- train (if not only_test)
|
||||||
|
- val (if not test_on_train)
|
||||||
|
- test (if not test_on_train)
|
||||||
|
"""
|
||||||
|
single_seq = task == Task.SINGLE_SEQUENCE
|
||||||
|
|
||||||
|
if only_test:
|
||||||
|
set_names_mapping = {}
|
||||||
|
else:
|
||||||
|
set_names_mapping = {
|
||||||
|
"train": [
|
||||||
|
(DATASET_TYPE_TEST if single_seq else DATASET_TYPE_TRAIN)
|
||||||
|
+ "_"
|
||||||
|
+ DATASET_TYPE_KNOWN
|
||||||
|
]
|
||||||
|
}
|
||||||
|
if not test_on_train:
|
||||||
|
prefixes = [DATASET_TYPE_TEST]
|
||||||
|
if not single_seq:
|
||||||
|
prefixes.append(DATASET_TYPE_TRAIN)
|
||||||
|
set_names_mapping.update(
|
||||||
|
{
|
||||||
|
dset: [
|
||||||
|
p + "_" + t
|
||||||
|
for p in prefixes
|
||||||
|
for t in [DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN]
|
||||||
|
]
|
||||||
|
for dset in ["val", "test"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return set_names_mapping
|
||||||
@@ -0,0 +1,358 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from typing import Dict, List, Optional, Type
|
||||||
|
|
||||||
|
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||||
|
DatasetMap,
|
||||||
|
DatasetMapProviderBase,
|
||||||
|
PathManagerFactory,
|
||||||
|
Task,
|
||||||
|
)
|
||||||
|
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||||
|
from pytorch3d.implicitron.tools.config import (
|
||||||
|
expand_args_fields,
|
||||||
|
registry,
|
||||||
|
run_auto_creation,
|
||||||
|
)
|
||||||
|
|
||||||
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
|
|
||||||
|
_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
||||||
|
"""
|
||||||
|
Generates the training, validation, and testing dataset objects for
|
||||||
|
a dataset laid out on disk like CO3Dv2, with annotations in gzipped json files.
|
||||||
|
|
||||||
|
The dataset is organized in the filesystem as follows:
|
||||||
|
```
|
||||||
|
self.dataset_root
|
||||||
|
├── <category_0>
|
||||||
|
│ ├── <sequence_name_0>
|
||||||
|
│ │ ├── depth_masks
|
||||||
|
│ │ ├── depths
|
||||||
|
│ │ ├── images
|
||||||
|
│ │ ├── masks
|
||||||
|
│ │ └── pointcloud.ply
|
||||||
|
│ ├── <sequence_name_1>
|
||||||
|
│ │ ├── depth_masks
|
||||||
|
│ │ ├── depths
|
||||||
|
│ │ ├── images
|
||||||
|
│ │ ├── masks
|
||||||
|
│ │ └── pointcloud.ply
|
||||||
|
│ ├── ...
|
||||||
|
│ ├── <sequence_name_N>
|
||||||
|
│ ├── set_lists
|
||||||
|
│ ├── set_lists_<subset_name_0>.json
|
||||||
|
│ ├── set_lists_<subset_name_1>.json
|
||||||
|
│ ├── ...
|
||||||
|
│ ├── set_lists_<subset_name_M>.json
|
||||||
|
│ ├── eval_batches
|
||||||
|
│ │ ├── eval_batches_<subset_name_0>.json
|
||||||
|
│ │ ├── eval_batches_<subset_name_1>.json
|
||||||
|
│ │ ├── ...
|
||||||
|
│ │ ├── eval_batches_<subset_name_M>.json
|
||||||
|
│ ├── frame_annotations.jgz
|
||||||
|
│ ├── sequence_annotations.jgz
|
||||||
|
├── <category_1>
|
||||||
|
├── ...
|
||||||
|
├── <category_K>
|
||||||
|
```
|
||||||
|
|
||||||
|
The dataset contains sequences named `<sequence_name_i>` from `K` categories with
|
||||||
|
names `<category_j>`. Each category comprises sequence folders
|
||||||
|
`<category_k>/<sequence_name_i>` containing the list of sequence images, depth maps,
|
||||||
|
foreground masks, and valid-depth masks `images`, `depths`, `masks`, and `depth_masks`
|
||||||
|
respectively. Furthermore, `<category_k>/<sequence_name_i>/set_lists/` stores `M`
|
||||||
|
json files `set_lists_<subset_name_l>.json`, each describing a certain sequence subset.
|
||||||
|
|
||||||
|
Users specify the loaded dataset subset by setting `self.subset_name` to one of the
|
||||||
|
available subset names `<subset_name_l>`.
|
||||||
|
|
||||||
|
`frame_annotations.jgz` and `sequence_annotations.jgz` are gzipped json files containing
|
||||||
|
the list of all frames and sequences of the given category stored as lists of
|
||||||
|
`FrameAnnotation` and `SequenceAnnotation` objects respectivelly.
|
||||||
|
|
||||||
|
Each `set_lists_<subset_name_l>.json` file contains the following dictionary:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"train": [
|
||||||
|
(sequence_name: str, frame_number: int, image_path: str),
|
||||||
|
...
|
||||||
|
],
|
||||||
|
"val": [
|
||||||
|
(sequence_name: str, frame_number: int, image_path: str),
|
||||||
|
...
|
||||||
|
],
|
||||||
|
"test": [
|
||||||
|
(sequence_name: str, frame_number: int, image_path: str),
|
||||||
|
...
|
||||||
|
],
|
||||||
|
]
|
||||||
|
```
|
||||||
|
defining the list of frames (identified with their `sequence_name` and `frame_number`)
|
||||||
|
in the "train", "val", and "test" subsets of the dataset.
|
||||||
|
Note that `frame_number` can be obtained only from `frame_annotations.jgz` and
|
||||||
|
does not necesarrily correspond to the numeric suffix of the corresponding image
|
||||||
|
file name (e.g. a file `<category_0>/<sequence_name_0>/images/frame00005.jpg` can
|
||||||
|
have its frame number set to `20`, not 5).
|
||||||
|
|
||||||
|
Each `eval_batches_<subset_name_l>.json` file contains a list of evaluation examples
|
||||||
|
in the following form:
|
||||||
|
```
|
||||||
|
[
|
||||||
|
[ # batch 1
|
||||||
|
(sequence_name: str, frame_number: int, image_path: str),
|
||||||
|
...
|
||||||
|
],
|
||||||
|
[ # batch 1
|
||||||
|
(sequence_name: str, frame_number: int, image_path: str),
|
||||||
|
...
|
||||||
|
],
|
||||||
|
]
|
||||||
|
```
|
||||||
|
Note that the evaluation examples always come from the `"test"` subset of the dataset.
|
||||||
|
(test frames can repeat across batches).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: The object category of the dataset.
|
||||||
|
subset_name: The name of the dataset subset. For CO3Dv2, these include
|
||||||
|
e.g. "manyview_dev_0", "fewview_test", ...
|
||||||
|
dataset_root: The root folder of the dataset.
|
||||||
|
test_on_train: Construct validation and test datasets from
|
||||||
|
the training subset.
|
||||||
|
only_test_set: Load only the test set. Incompatible with `test_on_train`.
|
||||||
|
load_eval_batches: Load the file containing eval batches pointing to the
|
||||||
|
test dataset.
|
||||||
|
dataset_args: Specifies additional arguments to the
|
||||||
|
JsonIndexDataset constructor call.
|
||||||
|
path_manager_factory: (Optional) An object that generates an instance of
|
||||||
|
PathManager that can translate provided file paths.
|
||||||
|
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
category: str
|
||||||
|
subset_name: str
|
||||||
|
dataset_root: str = _CO3DV2_DATASET_ROOT
|
||||||
|
|
||||||
|
test_on_train: bool = False
|
||||||
|
only_test_set: bool = False
|
||||||
|
load_eval_batches: bool = True
|
||||||
|
|
||||||
|
dataset_class_type: str = "JsonIndexDataset"
|
||||||
|
dataset: JsonIndexDataset
|
||||||
|
|
||||||
|
path_manager_factory: PathManagerFactory
|
||||||
|
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__init__()
|
||||||
|
run_auto_creation(self)
|
||||||
|
|
||||||
|
if self.only_test_set and self.test_on_train:
|
||||||
|
raise ValueError("Cannot have only_test_set and test_on_train")
|
||||||
|
|
||||||
|
frame_file = os.path.join(
|
||||||
|
self.dataset_root, self.category, "frame_annotations.jgz"
|
||||||
|
)
|
||||||
|
sequence_file = os.path.join(
|
||||||
|
self.dataset_root, self.category, "sequence_annotations.jgz"
|
||||||
|
)
|
||||||
|
|
||||||
|
path_manager = self.path_manager_factory.get()
|
||||||
|
|
||||||
|
# setup the common dataset arguments
|
||||||
|
common_dataset_kwargs = getattr(self, f"dataset_{self.dataset_class_type}_args")
|
||||||
|
common_dataset_kwargs = {
|
||||||
|
**common_dataset_kwargs,
|
||||||
|
"dataset_root": self.dataset_root,
|
||||||
|
"frame_annotations_file": frame_file,
|
||||||
|
"sequence_annotations_file": sequence_file,
|
||||||
|
"subsets": None,
|
||||||
|
"subset_lists_file": "",
|
||||||
|
"path_manager": path_manager,
|
||||||
|
}
|
||||||
|
|
||||||
|
# get the used dataset type
|
||||||
|
dataset_type: Type[JsonIndexDataset] = registry.get(
|
||||||
|
JsonIndexDataset, self.dataset_class_type
|
||||||
|
)
|
||||||
|
expand_args_fields(dataset_type)
|
||||||
|
|
||||||
|
dataset = dataset_type(**common_dataset_kwargs)
|
||||||
|
|
||||||
|
available_subset_names = self._get_available_subset_names()
|
||||||
|
logger.debug(f"Available subset names: {str(available_subset_names)}.")
|
||||||
|
if self.subset_name not in available_subset_names:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown subset name {self.subset_name}."
|
||||||
|
+ f" Choose one of available subsets: {str(available_subset_names)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# load the list of train/val/test frames
|
||||||
|
subset_mapping = self._load_annotation_json(
|
||||||
|
os.path.join(
|
||||||
|
self.category, "set_lists", f"set_lists_{self.subset_name}.json"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# load the evaluation batches
|
||||||
|
if self.load_eval_batches:
|
||||||
|
eval_batch_index = self._load_annotation_json(
|
||||||
|
os.path.join(
|
||||||
|
self.category,
|
||||||
|
"eval_batches",
|
||||||
|
f"eval_batches_{self.subset_name}.json",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset = None
|
||||||
|
if not self.only_test_set:
|
||||||
|
# load the training set
|
||||||
|
logger.debug("Extracting train dataset.")
|
||||||
|
train_dataset = dataset.subset_from_frame_index(subset_mapping["train"])
|
||||||
|
logger.info(f"Train dataset: {str(train_dataset)}")
|
||||||
|
|
||||||
|
if self.test_on_train:
|
||||||
|
assert train_dataset is not None
|
||||||
|
val_dataset = test_dataset = train_dataset
|
||||||
|
else:
|
||||||
|
# load the val and test sets
|
||||||
|
logger.debug("Extracting val dataset.")
|
||||||
|
val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
|
||||||
|
logger.info(f"Val dataset: {str(val_dataset)}")
|
||||||
|
logger.debug("Extracting test dataset.")
|
||||||
|
test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
|
||||||
|
logger.info(f"Test dataset: {str(test_dataset)}")
|
||||||
|
if self.load_eval_batches:
|
||||||
|
# load the eval batches
|
||||||
|
logger.debug("Extracting eval batches.")
|
||||||
|
try:
|
||||||
|
test_dataset.eval_batches = (
|
||||||
|
test_dataset.seq_frame_index_to_dataset_index(
|
||||||
|
eval_batch_index,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except IndexError:
|
||||||
|
warnings.warn(
|
||||||
|
"@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n"
|
||||||
|
+ "Some eval batches are missing from the test dataset.\n"
|
||||||
|
+ "The evaluation results will be incomparable to the\n"
|
||||||
|
+ "evaluation results calculated on the original dataset.\n"
|
||||||
|
+ "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
|
||||||
|
)
|
||||||
|
test_dataset.eval_batches = (
|
||||||
|
test_dataset.seq_frame_index_to_dataset_index(
|
||||||
|
eval_batch_index,
|
||||||
|
allow_missing_indices=True,
|
||||||
|
remove_missing_indices=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.info(f"# eval batches: {len(test_dataset.eval_batches)}")
|
||||||
|
|
||||||
|
self.dataset_map = DatasetMap(
|
||||||
|
train=train_dataset, val=val_dataset, test=test_dataset
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_dataset(self):
|
||||||
|
# The dataset object is created inside `self.get_dataset_map`
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_dataset_map(self) -> DatasetMap:
|
||||||
|
return self.dataset_map # pyre-ignore [16]
|
||||||
|
|
||||||
|
def get_category_to_subset_name_list(self) -> Dict[str, List[str]]:
|
||||||
|
"""
|
||||||
|
Returns a global dataset index containing the available subset names per category
|
||||||
|
as a dictionary.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
category_to_subset_name_list: A dictionary containing subset names available
|
||||||
|
per category of the following form:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
category_0: [category_0_subset_name_0, category_0_subset_name_1, ...],
|
||||||
|
category_1: [category_1_subset_name_0, category_1_subset_name_1, ...],
|
||||||
|
...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
category_to_subset_name_list_json = "category_to_subset_name_list.json"
|
||||||
|
category_to_subset_name_list = self._load_annotation_json(
|
||||||
|
category_to_subset_name_list_json
|
||||||
|
)
|
||||||
|
return category_to_subset_name_list
|
||||||
|
|
||||||
|
def get_task(self) -> Task: # TODO: we plan to get rid of tasks
|
||||||
|
return {
|
||||||
|
"manyview": Task.SINGLE_SEQUENCE,
|
||||||
|
"fewview": Task.MULTI_SEQUENCE,
|
||||||
|
}[self.subset_name.split("_")[0]]
|
||||||
|
|
||||||
|
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
|
# pyre-ignore[16]
|
||||||
|
train_dataset = self.dataset_map.train
|
||||||
|
assert isinstance(train_dataset, JsonIndexDataset)
|
||||||
|
return train_dataset.get_all_train_cameras()
|
||||||
|
|
||||||
|
def _load_annotation_json(self, json_filename: str):
|
||||||
|
full_path = os.path.join(
|
||||||
|
self.dataset_root,
|
||||||
|
json_filename,
|
||||||
|
)
|
||||||
|
logger.info(f"Loading frame index json from {full_path}.")
|
||||||
|
path_manager = self.path_manager_factory.get()
|
||||||
|
if path_manager is not None:
|
||||||
|
full_path = path_manager.get_local_path(full_path)
|
||||||
|
if not os.path.isfile(full_path):
|
||||||
|
# The batch indices file does not exist.
|
||||||
|
# Most probably the user has not specified the root folder.
|
||||||
|
raise ValueError(
|
||||||
|
f"Looking for dataset json file in {full_path}. "
|
||||||
|
+ "Please specify a correct dataset_root folder."
|
||||||
|
)
|
||||||
|
with open(full_path, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def _get_available_subset_names(self):
    """List the subset names present for ``self.category`` under the dataset root."""
    pm = self.path_manager_factory.get()
    root = self.dataset_root if pm is None else pm.get_local_path(self.dataset_root)
    return get_available_subset_names(root, self.category)
|
||||||
|
|
||||||
|
|
||||||
|
def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
    """
    Get the available subset names for a given category folder inside a root dataset
    folder `dataset_root`.

    Raises:
        ValueError: if `dataset_root/category` is not a directory.
    """
    category_dir = os.path.join(dataset_root, category)
    if not os.path.isdir(category_dir):
        raise ValueError(
            f"Looking for dataset files in {category_dir}. "
            + "Please specify a correct dataset_root folder."
        )
    names = []
    for entry in os.listdir(os.path.join(category_dir, "set_lists")):
        # Strip the "set_lists_<name>.json" scaffolding to recover <name>.
        names.append(entry.replace("set_lists_", "").replace(".json", ""))
    return names
|
||||||
63
pytorch3d/implicitron/dataset/llff_dataset_map_provider.py
Normal file
63
pytorch3d/implicitron/dataset/llff_dataset_map_provider.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.tools.config import registry
|
||||||
|
|
||||||
|
from .load_llff import load_llff_data
|
||||||
|
|
||||||
|
from .single_sequence_dataset import (
|
||||||
|
_interpret_blender_cameras,
|
||||||
|
SingleSceneDatasetMapProviderBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
    """
    Provides data for one scene from the LLFF dataset.

    Members:
        base_dir: directory holding the data for the scene.
        object_name: The name of the scene (e.g. "fern"). This is just used as a label.
            It will typically be equal to the name of the directory self.base_dir.
        path_manager_factory: Creates path manager which may be used for
            interpreting paths.
        n_known_frames_for_test: If set, training frames are included in the val
            and test datasets, and this many random training frames are added to
            each test batch. If not set, test batches each contain just a single
            testing frame.
    """

    def _load_data(self) -> None:
        # Load images and poses from disk; the trailing column of each pose
        # carries [H, W, focal] (see load_llff_data / _load_data in load_llff).
        path_manager = self.path_manager_factory.get()
        images, poses, _ = load_llff_data(
            self.base_dir, factor=8, path_manager=path_manager
        )
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]

        # Hold out every 8th frame for testing; the rest form the train split.
        i_test = np.arange(images.shape[0])[::8]
        i_test_index = set(i_test.tolist())
        i_train = np.array(
            [i for i in np.arange(images.shape[0]) if i not in i_test_index]
        )
        # NOTE: val and test reuse the same held-out indices.
        i_split = (i_train, i_test, i_test)
        H, W, focal = hwf
        # Convert the focal length from pixels to NDC units.
        focal_ndc = 2 * focal / min(H, W)
        # [N, H, W, 3] -> [N, 3, H, W]
        images = torch.from_numpy(images).permute(0, 3, 1, 2)
        poses = torch.from_numpy(poses)

        # pyre-ignore[16]
        self.poses = _interpret_blender_cameras(poses, focal_ndc)
        # pyre-ignore[16]
        self.images = images
        # pyre-ignore[16]
        self.fg_probabilities = None
        # pyre-ignore[16]
        self.i_split = i_split
|
||||||
141
pytorch3d/implicitron/dataset/load_blender.py
Normal file
141
pytorch3d/implicitron/dataset/load_blender.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
# @lint-ignore-every LICENSELINT
|
||||||
|
# Adapted from https://github.com/bmild/nerf/blob/master/load_blender.py
|
||||||
|
# Copyright (c) 2020 bmild
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def translate_by_t_along_z(t):
    """Return a 4x4 float32 homogeneous transform translating by ``t`` along z."""
    tform = np.eye(4, dtype=np.float32)
    tform[2, 3] = t
    return tform
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_by_phi_along_x(phi):
    """Return a 4x4 float32 rotation by angle ``phi`` (radians) about the x axis."""
    c = np.cos(phi)
    s = np.sin(phi)
    tform = np.eye(4, dtype=np.float32)
    tform[1, 1] = c
    tform[2, 2] = c
    tform[1, 2] = -s
    tform[2, 1] = s
    return tform
|
||||||
|
|
||||||
|
|
||||||
|
def rotate_by_theta_along_y(theta):
    """Return a 4x4 float32 rotation by angle ``theta`` (radians) about the y axis."""
    c = np.cos(theta)
    s = np.sin(theta)
    tform = np.eye(4, dtype=np.float32)
    tform[0, 0] = c
    tform[2, 2] = c
    tform[0, 2] = -s
    tform[2, 0] = s
    return tform
|
||||||
|
|
||||||
|
|
||||||
|
def pose_spherical(theta, phi, radius):
    """
    Build a camera-to-world matrix for a camera at spherical coordinates
    ``theta``/``phi`` (degrees) and distance ``radius`` from the origin.
    """
    phi_rad = phi / 180.0 * np.pi
    theta_rad = theta / 180 * np.pi
    c2w = translate_by_t_along_z(radius)
    c2w = rotate_by_phi_along_x(phi_rad) @ c2w
    c2w = rotate_by_theta_along_y(theta_rad) @ c2w
    # Axis swap/flip into the blender/NeRF coordinate convention.
    flip = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
    return flip @ c2w
|
||||||
|
|
||||||
|
|
||||||
|
def _local_path(path_manager, path):
|
||||||
|
if path_manager is None:
|
||||||
|
return path
|
||||||
|
return path_manager.get_local_path(path)
|
||||||
|
|
||||||
|
|
||||||
|
def load_blender_data(
    basedir,
    half_res=False,
    testskip=1,
    debug=False,
    path_manager=None,
    focal_length_in_screen_space=False,
):
    """
    Load a blender-format synthetic NeRF scene from ``basedir``.

    Args:
        basedir: directory containing transforms_{train,val,test}.json and images.
        half_res: if True, downsample images by 2 (see NOTE below on the resize).
        testskip: stride for subsampling val/test frames (0 means keep all).
        debug: if True, return tiny 25x25 images and exit early.
        path_manager: optional path manager used to resolve remote paths.
        focal_length_in_screen_space: if True, focal is in pixels; otherwise
            focal = 1 / tan(fov_x / 2) (an NDC-style value).

    Returns:
        (imgs, poses, render_poses, [H, W, focal], i_split) where i_split
        contains the train/val/test index ranges in concatenation order.
    """
    splits = ["train", "val", "test"]
    metas = {}
    for s in splits:
        path = os.path.join(basedir, f"transforms_{s}.json")
        with open(_local_path(path_manager, path)) as fp:
            metas[s] = json.load(fp)

    all_imgs = []
    all_poses = []
    # counts accumulates split boundaries so i_split can be built below.
    counts = [0]
    for s in splits:
        meta = metas[s]
        imgs = []
        poses = []
        # Train frames are never skipped; val/test frames honor testskip.
        if s == "train" or testskip == 0:
            skip = 1
        else:
            skip = testskip

        for frame in meta["frames"][::skip]:
            fname = os.path.join(basedir, frame["file_path"] + ".png")
            imgs.append(np.array(Image.open(_local_path(path_manager, fname))))
            poses.append(np.array(frame["transform_matrix"]))
        # Normalize pixel values to [0, 1].
        imgs = (np.array(imgs) / 255.0).astype(np.float32)
        poses = np.array(poses).astype(np.float32)
        counts.append(counts[-1] + imgs.shape[0])
        all_imgs.append(imgs)
        all_poses.append(poses)

    i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)]

    imgs = np.concatenate(all_imgs, 0)
    poses = np.concatenate(all_poses, 0)

    H, W = imgs[0].shape[:2]
    # NOTE: camera_angle_x is read from the last split's meta (all splits are
    # presumably recorded with the same camera intrinsics — verify for new data).
    camera_angle_x = float(meta["camera_angle_x"])
    if focal_length_in_screen_space:
        focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
    else:
        focal = 1 / np.tan(0.5 * camera_angle_x)

    # A ring of 40 novel camera poses circling the scene, for rendering.
    render_poses = torch.stack(
        [
            torch.from_numpy(pose_spherical(angle, -30.0, 4.0))
            for angle in np.linspace(-180, 180, 40 + 1)[:-1]
        ],
        0,
    )

    # In debug mode, return extremely tiny images
    if debug:
        import cv2

        H = H // 32
        W = W // 32
        if focal_length_in_screen_space:
            focal = focal / 32.0
        imgs = [
            torch.from_numpy(
                cv2.resize(imgs[i], dsize=(25, 25), interpolation=cv2.INTER_AREA)
            )
            for i in range(imgs.shape[0])
        ]
        imgs = torch.stack(imgs, 0)
        poses = torch.from_numpy(poses)
        return imgs, poses, render_poses, [H, W, focal], i_split

    if half_res:
        import cv2

        # TODO: resize images using INTER_AREA (cv2)
        H = H // 2
        W = W // 2
        if focal_length_in_screen_space:
            focal = focal / 2.0
        # NOTE(review): dsize is hard-coded to (400, 400), which matches the
        # standard 800x800 blender renders — confirm before using other sizes.
        imgs = [
            torch.from_numpy(
                cv2.resize(imgs[i], dsize=(400, 400), interpolation=cv2.INTER_AREA)
            )
            for i in range(imgs.shape[0])
        ]
        imgs = torch.stack(imgs, 0)

    poses = torch.from_numpy(poses)

    return imgs, poses, render_poses, [H, W, focal], i_split
|
||||||
343
pytorch3d/implicitron/dataset/load_llff.py
Normal file
343
pytorch3d/implicitron/dataset/load_llff.py
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
# @lint-ignore-every LICENSELINT
|
||||||
|
# Adapted from https://github.com/bmild/nerf/blob/master/load_llff.py
|
||||||
|
# Copyright (c) 2020 bmild
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
# Slightly modified version of LLFF data loading code
|
||||||
|
# see https://github.com/Fyusion/LLFF for original
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _minify(basedir, path_manager, factors=(), resolutions=()):
    """
    Create downscaled copies of ``basedir/images`` using ImageMagick's
    ``mogrify``, one directory per requested factor/resolution.

    Args:
        basedir: scene directory containing an ``images`` subdirectory.
        path_manager: optional path manager; must be None if any resizing is
            actually needed (the shell commands below only work locally).
        factors: iterable of integer downscale factors (output: images_{f}).
        resolutions: iterable of [height, width] pairs (output: images_{w}x{h}).

    Side effects: creates directories and shells out to cp/mogrify/rm.
    """
    # Check whether every requested output directory already exists.
    needtoload = False
    for r in factors:
        imgdir = os.path.join(basedir, "images_{}".format(r))
        if not _exists(path_manager, imgdir):
            needtoload = True
    for r in resolutions:
        imgdir = os.path.join(basedir, "images_{}x{}".format(r[1], r[0]))
        if not _exists(path_manager, imgdir):
            needtoload = True
    if not needtoload:
        return
    # Resizing shells out to local commands, so remote paths are unsupported.
    assert path_manager is None

    from subprocess import check_output

    imgdir = os.path.join(basedir, "images")
    imgs = [os.path.join(imgdir, f) for f in sorted(_ls(path_manager, imgdir))]
    imgs = [
        f
        for f in imgs
        if any([f.endswith(ex) for ex in ["JPG", "jpg", "png", "jpeg", "PNG"]])
    ]
    imgdir_orig = imgdir

    wd = os.getcwd()

    # BUGFIX: factors/resolutions may be a mix of list (caller) and tuple
    # (default), and `list + tuple` raises TypeError; normalize both first.
    for r in list(factors) + list(resolutions):
        if isinstance(r, int):
            name = "images_{}".format(r)
            resizearg = "{}%".format(100.0 / r)
        else:
            name = "images_{}x{}".format(r[1], r[0])
            resizearg = "{}x{}".format(r[1], r[0])
        imgdir = os.path.join(basedir, name)
        if os.path.exists(imgdir):
            continue

        logger.info(f"Minifying {r}, {basedir}")

        os.makedirs(imgdir)
        check_output("cp {}/* {}".format(imgdir_orig, imgdir), shell=True)

        ext = imgs[0].split(".")[-1]
        args = " ".join(
            ["mogrify", "-resize", resizearg, "-format", "png", "*.{}".format(ext)]
        )
        logger.info(args)
        # mogrify operates on the current directory, so chdir in and back out.
        os.chdir(imgdir)
        check_output(args, shell=True)
        os.chdir(wd)

        if ext != "png":
            # mogrify -format png leaves the originals behind; remove them.
            check_output("rm {}/*.{}".format(imgdir, ext), shell=True)
            logger.info("Removed duplicates")
        logger.info("Done")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_data(
    basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
):
    """
    Load LLFF poses/bounds (and optionally images) from ``basedir``.

    Reads poses_bounds.npy, optionally creates and uses a downscaled image
    directory (images_{factor} or images_{w}x{h}), and rewrites the [H, W,
    focal] column of each pose to match the loaded image size.

    Returns:
        (poses, bds) if not load_imgs, else (poses, bds, imgs) where imgs has
        shape [H, W, 3, N] with floats in [0, 1].
    """
    poses_arr = np.load(
        _local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
    )
    # Each row of poses_bounds.npy is a flattened 3x5 pose followed by 2 bounds;
    # after this reshape poses is [3, 5, N] and bds is [2, N].
    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
    bds = poses_arr[:, -2:].transpose([1, 0])

    # First image is only used to discover the original resolution.
    img0 = [
        os.path.join(basedir, "images", f)
        for f in sorted(_ls(path_manager, os.path.join(basedir, "images")))
        if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
    ][0]

    def imread(f):
        return np.array(Image.open(f))

    sh = imread(_local_path(path_manager, img0)).shape

    # sfx selects which images_* directory to read from below.
    sfx = ""

    if factor is not None:
        sfx = "_{}".format(factor)
        _minify(basedir, path_manager, factors=[factor])
        factor = factor
    elif height is not None:
        factor = sh[0] / float(height)
        width = int(sh[1] / factor)
        _minify(basedir, path_manager, resolutions=[[height, width]])
        sfx = "_{}x{}".format(width, height)
    elif width is not None:
        factor = sh[1] / float(width)
        height = int(sh[0] / factor)
        _minify(basedir, path_manager, resolutions=[[height, width]])
        sfx = "_{}x{}".format(width, height)
    else:
        factor = 1

    imgdir = os.path.join(basedir, "images" + sfx)
    if not _exists(path_manager, imgdir):
        raise ValueError(f"{imgdir} does not exist, returning")

    imgfiles = [
        _local_path(path_manager, os.path.join(imgdir, f))
        for f in sorted(_ls(path_manager, imgdir))
        if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
    ]
    if poses.shape[-1] != len(imgfiles):
        raise ValueError(
            "Mismatch between imgs {} and poses {} !!!!".format(
                len(imgfiles), poses.shape[-1]
            )
        )

    # Overwrite the stored H, W with the actual (possibly downscaled) image
    # size, and rescale the focal length by the same factor.
    sh = imread(imgfiles[0]).shape
    poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
    poses[2, 4, :] = poses[2, 4, :] * 1.0 / factor

    if not load_imgs:
        return poses, bds

    # Keep RGB only (drops any alpha channel) and normalize to [0, 1].
    imgs = imgs = [imread(f)[..., :3] / 255.0 for f in imgfiles]
    imgs = np.stack(imgs, -1)

    logger.info(f"Loaded image data, shape {imgs.shape}")
    return poses, bds, imgs
|
||||||
|
|
||||||
|
|
||||||
|
def normalize(x):
    """Return ``x`` scaled to unit Euclidean norm, warning on near-zero input."""
    norm = np.linalg.norm(x)
    if norm < 0.001:
        # Dividing by a near-zero norm is numerically unstable.
        warnings.warn("unsafe normalize()")
    return x / norm
|
||||||
|
|
||||||
|
|
||||||
|
def viewmatrix(z, up, pos):
    """
    Build a 3x4 camera matrix with columns [right, up, forward, position] from
    a viewing direction ``z``, an approximate ``up`` vector and a position.
    """
    forward = normalize(z)
    right = normalize(np.cross(up, forward))
    true_up = normalize(np.cross(forward, right))
    return np.stack([right, true_up, forward, pos], 1)
|
||||||
|
|
||||||
|
|
||||||
|
def ptstocam(pts, c2w):
    """Transform world-space points ``pts`` into the camera frame of ``c2w``."""
    centered = (pts - c2w[:3, 3])[..., np.newaxis]
    return np.matmul(c2w[:3, :3].T, centered)[..., 0]
|
||||||
|
|
||||||
|
|
||||||
|
def poses_avg(poses):
    """
    Average a stack of poses into a single 3x5 pose: mean camera center,
    summed forward/up directions, keeping the first pose's [H, W, focal] column.
    """
    hwf_column = poses[0, :3, -1:]
    mean_center = poses[:, :3, 3].mean(0)
    avg_forward = normalize(poses[:, :3, 2].sum(0))
    summed_up = poses[:, :3, 1].sum(0)
    return np.concatenate(
        [viewmatrix(avg_forward, summed_up, mean_center), hwf_column], 1
    )
|
||||||
|
|
||||||
|
|
||||||
|
def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
    """
    Generate ``N`` camera poses along a spiral around the average pose ``c2w``.

    Args:
        c2w: 3x5 average pose (last column is the [H, W, focal] block, copied
            into every output pose).
        up: world up vector.
        rads: per-axis radii of the spiral.
        focal: depth of the point all cameras look at.
        zdelta: unused (kept for call-site compatibility).
        zrate: frequency multiplier for the z oscillation.
        rots: number of full rotations.
        N: number of poses to generate.

    Returns:
        List of N 3x5 pose matrices.
    """
    render_poses = []
    # Append 1.0 so the homogeneous coordinate is left unscaled.
    rads = np.array(list(rads) + [1.0])
    hwf = c2w[:, 4:5]

    for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]:
        # Camera position on the spiral, expressed in world coordinates.
        c = np.dot(
            c2w[:3, :4],
            np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0])
            * rads,
        )
        # Look towards a point `focal` units in front of the average camera.
        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0])))
        render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
    return render_poses
|
||||||
|
|
||||||
|
|
||||||
|
def recenter_poses(poses):
    """
    Re-express all poses relative to their average pose, so the average
    camera sits at the identity. The [H, W, focal] column is preserved.
    """
    # `poses + 0` makes a copy so the hwf column of the input is kept intact.
    poses_ = poses + 0
    bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
    c2w = poses_avg(poses)
    # Homogenize the average pose to 4x4.
    c2w = np.concatenate([c2w[:3, :4], bottom], -2)
    bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
    poses = np.concatenate([poses[:, :3, :4], bottom], -2)

    # Left-multiply by the inverse average pose to recenter.
    poses = np.linalg.inv(c2w) @ poses
    # Write the recentered 3x4 blocks back, keeping the original last column.
    poses_[:, :3, :4] = poses[:, :3, :4]
    poses = poses_
    return poses
|
||||||
|
|
||||||
|
|
||||||
|
def spherify_poses(poses, bds):
    """
    Re-pose an inward-facing capture onto a sphere around the scene center.

    Finds the 3D point minimizing distance to all camera optical axes, recenters
    and rescales the poses so that point is the origin at unit radius, and also
    generates a circle of novel render poses at the cameras' mean height.

    Returns:
        (poses_reset, new_poses, bds): recentered input poses, generated render
        poses (both 3x5, hwf column copied from the first input pose), and the
        rescaled depth bounds.
    """
    def add_row_to_homogenize_transform(p):
        r"""Add the last row to homogenize 3 x 4 transformation matrices."""
        return np.concatenate(
            [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
        )

    # p34_to_44 = lambda p: np.concatenate(
    #     [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
    # )

    p34_to_44 = add_row_to_homogenize_transform

    # Optical axis direction and origin of each camera.
    rays_d = poses[:, :3, 2:3]
    rays_o = poses[:, :3, 3:4]

    def min_line_dist(rays_o, rays_d):
        # Least-squares point closest to all camera rays.
        A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
        b_i = -A_i @ rays_o
        pt_mindist = np.squeeze(
            -np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0)
        )
        return pt_mindist

    pt_mindist = min_line_dist(rays_o, rays_d)

    center = pt_mindist
    up = (poses[:, :3, 3] - center).mean(0)

    # Build an orthonormal frame with `up` as its z axis; the [0.1, 0.2, 0.3]
    # vector is an arbitrary non-parallel seed for the cross product.
    vec0 = normalize(up)
    vec1 = normalize(np.cross([0.1, 0.2, 0.3], vec0))
    vec2 = normalize(np.cross(vec0, vec1))
    pos = center
    c2w = np.stack([vec1, vec2, vec0, pos], 1)

    # Express all poses in this new frame.
    poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])

    # RMS distance of camera centers from the origin, used to rescale to
    # (approximately) the unit sphere.
    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))

    sc = 1.0 / rad
    poses_reset[:, :3, 3] *= sc
    bds *= sc
    rad *= sc

    centroid = np.mean(poses_reset[:, :3, 3], 0)
    zh = centroid[2]
    radcircle = np.sqrt(rad**2 - zh**2)
    new_poses = []

    # Render poses: a circle at height zh looking at the origin.
    for th in np.linspace(0.0, 2.0 * np.pi, 120):

        camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
        up = np.array([0, 0, -1.0])

        vec2 = normalize(camorigin)
        vec0 = normalize(np.cross(vec2, up))
        vec1 = normalize(np.cross(vec2, vec0))
        pos = camorigin
        p = np.stack([vec0, vec1, vec2, pos], 1)

        new_poses.append(p)

    new_poses = np.stack(new_poses, 0)

    # Copy the first input pose's hwf column onto all output poses.
    new_poses = np.concatenate(
        [new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1
    )
    poses_reset = np.concatenate(
        [
            poses_reset[:, :3, :4],
            np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape),
        ],
        -1,
    )

    return poses_reset, new_poses, bds
|
||||||
|
|
||||||
|
|
||||||
|
def _local_path(path_manager, path):
|
||||||
|
if path_manager is None:
|
||||||
|
return path
|
||||||
|
return path_manager.get_local_path(path)
|
||||||
|
|
||||||
|
|
||||||
|
def _ls(path_manager, path):
|
||||||
|
if path_manager is None:
|
||||||
|
return os.path.listdir(path)
|
||||||
|
return path_manager.ls(path)
|
||||||
|
|
||||||
|
|
||||||
|
def _exists(path_manager, path):
|
||||||
|
if path_manager is None:
|
||||||
|
return os.path.exists(path)
|
||||||
|
return path_manager.exists(path)
|
||||||
|
|
||||||
|
|
||||||
|
def load_llff_data(
    basedir,
    factor=8,
    recenter=True,
    bd_factor=0.75,
    spherify=False,
    path_zflat=False,
    path_manager=None,
):
    """
    Load an LLFF scene: images, camera poses and depth bounds.

    Args:
        basedir: scene directory (with images/ and poses_bounds.npy).
        factor: image downsampling factor.
        recenter: if True, recenter poses around their average.
        bd_factor: if not None, rescale so the near bound is 1/bd_factor.
        spherify: if True, re-pose onto a sphere (render poses are computed
            but discarded here).
        path_zflat: unused (kept for call-site compatibility).
        path_manager: optional path manager for resolving paths.

    Returns:
        (images, poses, bds): images [N, H, W, 3] float32 in [0, 1],
        poses [N, 3, 5] float32, bds [N, 2] float32.
    """
    poses, bds, imgs = _load_data(
        basedir, factor=factor, path_manager=path_manager
    )  # factor=8 downsamples original imgs by 8x
    logger.info(f"Loaded {basedir}, {bds.min()}, {bds.max()}")

    # Correct rotation matrix ordering and move variable dim to axis 0
    poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
    imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
    images = imgs
    bds = np.moveaxis(bds, -1, 0).astype(np.float32)

    # Rescale if bd_factor is provided
    sc = 1.0 if bd_factor is None else 1.0 / (bds.min() * bd_factor)
    poses[:, :3, 3] *= sc
    bds *= sc

    if recenter:
        poses = recenter_poses(poses)

    if spherify:
        # spherify_poses also returns render poses, which are not used here.
        poses, render_poses, bds = spherify_poses(poses, bds)

    images = images.astype(np.float32)
    poses = poses.astype(np.float32)

    return images, poses, bds
|
||||||
@@ -12,7 +12,7 @@ from typing import Iterable, Iterator, List, Sequence, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data.sampler import Sampler
|
from torch.utils.data.sampler import Sampler
|
||||||
|
|
||||||
from .implicitron_dataset import ImplicitronDatasetBase
|
from .dataset_base import DatasetBase
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False) # TODO: do we need this if not init from config?
|
@dataclass(eq=False) # TODO: do we need this if not init from config?
|
||||||
@@ -22,7 +22,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
|||||||
of sequences.
|
of sequences.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dataset: ImplicitronDatasetBase
|
dataset: DatasetBase
|
||||||
batch_size: int
|
batch_size: int
|
||||||
num_batches: int
|
num_batches: int
|
||||||
# the sampler first samples a random element k from this list and then
|
# the sampler first samples a random element k from this list and then
|
||||||
|
|||||||
205
pytorch3d/implicitron/dataset/single_sequence_dataset.py
Normal file
205
pytorch3d/implicitron/dataset/single_sequence_dataset.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
# This file defines a base class for dataset map providers which
|
||||||
|
# provide data for a single scene.
|
||||||
|
|
||||||
|
from dataclasses import field
|
||||||
|
from typing import Iterable, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.tools.config import (
|
||||||
|
Configurable,
|
||||||
|
expand_args_fields,
|
||||||
|
run_auto_creation,
|
||||||
|
)
|
||||||
|
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
|
||||||
|
|
||||||
|
from .dataset_base import DatasetBase, FrameData
|
||||||
|
from .dataset_map_provider import (
|
||||||
|
DatasetMap,
|
||||||
|
DatasetMapProviderBase,
|
||||||
|
PathManagerFactory,
|
||||||
|
Task,
|
||||||
|
)
|
||||||
|
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
|
||||||
|
|
||||||
|
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
|
||||||
|
|
||||||
|
|
||||||
|
class SingleSceneDataset(DatasetBase, Configurable):
    """
    A dataset from images from a single scene.

    All frames belong to the one sequence named by _SINGLE_SEQUENCE_NAME.
    """

    # images: per-frame image tensors (channels-first; indexed per frame).
    images: List[torch.Tensor] = field()
    # fg_probabilities: optional per-frame foreground masks, aligned with images.
    fg_probabilities: Optional[List[torch.Tensor]] = field()
    # poses: one single-camera PerspectiveCameras object per frame.
    poses: List[PerspectiveCameras] = field()
    # object_name: label used as the sequence category of every frame.
    object_name: str = field()
    # frame_types: per-frame split label (known/unseen), aligned with images.
    frame_types: List[str] = field()
    # eval_batches: optional preset evaluation batches of frame indices.
    eval_batches: Optional[List[List[int]]] = field()

    def sequence_names(self) -> Iterable[str]:
        """Return the single sequence name this dataset contains."""
        return [_SINGLE_SEQUENCE_NAME]

    def __len__(self) -> int:
        return len(self.poses)

    def __getitem__(self, index) -> FrameData:
        """Assemble the FrameData for frame ``index``."""
        if index >= len(self):
            raise IndexError(f"index {index} out of range {len(self)}")
        image = self.images[index]
        pose = self.poses[index]
        frame_type = self.frame_types[index]
        fg_probability = (
            None if self.fg_probabilities is None else self.fg_probabilities[index]
        )

        frame_data = FrameData(
            frame_number=index,
            sequence_name=_SINGLE_SEQUENCE_NAME,
            sequence_category=self.object_name,
            camera=pose,
            # image is channels-first, so shape[1:] is (H, W).
            image_size_hw=torch.tensor(image.shape[1:]),
            image_rgb=image,
            fg_probability=fg_probability,
            frame_type=frame_type,
        )
        return frame_data

    def get_eval_batches(self) -> Optional[List[List[int]]]:
        """Return the preset evaluation batches, if any."""
        return self.eval_batches
|
||||||
|
|
||||||
|
|
||||||
|
# pyre-fixme[13]: Uninitialized attribute
class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
    """
    Base for provider of data for one scene from LLFF or blender datasets.

    Members:
        base_dir: directory holding the data for the scene.
        object_name: The name of the scene (e.g. "lego"). This is just used as a label.
            It will typically be equal to the name of the directory self.base_dir.
        path_manager_factory: Creates path manager which may be used for
            interpreting paths.
        n_known_frames_for_test: If set, training frames are included in the val
            and test datasets, and this many random training frames are added to
            each test batch. If not set, test batches each contain just a single
            testing frame.
    """

    base_dir: str
    object_name: str
    path_manager_factory: PathManagerFactory
    path_manager_factory_class_type: str = "PathManagerFactory"
    n_known_frames_for_test: Optional[int] = None

    def __post_init__(self) -> None:
        run_auto_creation(self)
        # Subclasses populate images/poses/fg_probabilities/i_split here.
        self._load_data()

    def _load_data(self) -> None:
        # This must be defined by each subclass,
        # and should set the following on self.
        # - poses: a list of length-1 camera objects
        # - images: [N, 3, H, W] tensor of rgb images - floats in [0,1]
        # - fg_probabilities: None or [N, 1, H, W] of floats in [0,1]
        # - splits: List[List[int]] of indices for train/val/test subsets.
        raise NotImplementedError()

    def _get_dataset(
        self, split_idx: int, frame_type: str, set_eval_batches: bool = False
    ) -> SingleSceneDataset:
        """
        Build the SingleSceneDataset for split ``split_idx`` (0=train, 1=val,
        2=test), labeling every frame with ``frame_type``. When
        ``set_eval_batches`` is True, per-frame eval batches are attached
        (optionally augmented with known training frames).
        """
        expand_args_fields(SingleSceneDataset)
        # pyre-ignore[16]
        split = self.i_split[split_idx]
        frame_types = [frame_type] * len(split)
        fg_probabilities = (
            None
            # pyre-ignore[16]
            if self.fg_probabilities is None
            else self.fg_probabilities[split]
        )
        # By default each frame is evaluated on its own.
        eval_batches = [[i] for i in range(len(split))]
        if split_idx != 0 and self.n_known_frames_for_test is not None:
            # Append the training frames to this (non-train) split so test
            # batches can include known views.
            train_split = self.i_split[0]
            if set_eval_batches:
                generator = np.random.default_rng(seed=0)
                for batch in eval_batches:
                    # using permutation so that changes to n_known_frames_for_test
                    # result in consistent batches.
                    to_add = generator.permutation(len(train_split))[
                        : self.n_known_frames_for_test
                    ]
                    # Offset by len(split): the train frames come after the
                    # split's own frames in the concatenated index space.
                    batch.extend((to_add + len(split)).tolist())
            split = np.concatenate([split, train_split])
            frame_types.extend([DATASET_TYPE_KNOWN] * len(train_split))

        # pyre-ignore[28]
        return SingleSceneDataset(
            object_name=self.object_name,
            # pyre-ignore[16]
            images=self.images[split],
            fg_probabilities=fg_probabilities,
            # pyre-ignore[16]
            poses=[self.poses[i] for i in split],
            frame_types=frame_types,
            eval_batches=eval_batches if set_eval_batches else None,
        )

    def get_dataset_map(self) -> DatasetMap:
        """Return train/val/test datasets; only test carries eval batches."""
        return DatasetMap(
            train=self._get_dataset(0, DATASET_TYPE_KNOWN),
            val=self._get_dataset(1, DATASET_TYPE_UNKNOWN),
            test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
        )

    def get_task(self) -> Task:
        """Single-scene providers always represent a single-sequence task."""
        return Task.SINGLE_SEQUENCE

    def get_all_train_cameras(self) -> Optional[CamerasBase]:
        """Join the per-frame training cameras into one batched camera object."""
        # pyre-ignore[16]
        cameras = [self.poses[i] for i in self.i_split[0]]
        return join_cameras_as_batch(cameras)
|
||||||
|
|
||||||
|
|
||||||
|
def _interpret_blender_cameras(
    poses: torch.Tensor, focal: float
) -> List[PerspectiveCameras]:
    """
    Convert 4x4 matrices representing cameras in blender format
    to PyTorch3D format.

    Args:
        poses: N x 3 x 4 camera matrices
        focal: ndc space focal length

    Returns:
        A list of N single-camera PerspectiveCameras objects, one per pose.
    """
    pose_target_cameras = []
    for pose_target in poses:
        pose_target = pose_target[:3, :4]
        # Build the inverse (world-to-view) transform in row-major layout:
        # rotation transposed into the top-left, translation in the last row.
        mtx = torch.eye(4, dtype=pose_target.dtype)
        mtx[:3, :3] = pose_target[:3, :3].t()
        mtx[3, :3] = pose_target[:, 3]
        mtx = mtx.inverse()

        # flip the XZ coordinates.
        mtx[:, [0, 2]] *= -1.0

        # Split into a 3x3 rotation and a 1x3 translation row.
        Rpt3, Tpt3 = mtx[:, :3].split([3, 1], dim=0)

        # Square pixels and a centered principal point are assumed.
        focal_length_pt3 = torch.FloatTensor([[focal, focal]])
        principal_point_pt3 = torch.FloatTensor([[0.0, 0.0]])

        cameras = PerspectiveCameras(
            focal_length=focal_length_pt3,
            principal_point=principal_point_pt3,
            R=Rpt3[None],
            T=Tpt3,
        )
        pose_target_cameras.append(cameras)
    return pose_target_cameras
|
||||||
@@ -8,8 +8,8 @@
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import gzip
|
import gzip
|
||||||
import json
|
import json
|
||||||
from dataclasses import MISSING, Field, dataclass
|
from dataclasses import dataclass, Field, MISSING
|
||||||
from typing import IO, Any, Optional, Tuple, Type, TypeVar, Union, cast
|
from typing import Any, cast, Dict, IO, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pytorch3d.common.datatypes import get_args, get_origin
|
from pytorch3d.common.datatypes import get_args, get_origin
|
||||||
@@ -80,6 +80,7 @@ class FrameAnnotation:
|
|||||||
depth: Optional[DepthAnnotation] = None
|
depth: Optional[DepthAnnotation] = None
|
||||||
mask: Optional[MaskAnnotation] = None
|
mask: Optional[MaskAnnotation] = None
|
||||||
viewpoint: Optional[ViewpointAnnotation] = None
|
viewpoint: Optional[ViewpointAnnotation] = None
|
||||||
|
meta: Optional[Dict[str, Any]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -169,9 +170,11 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
|||||||
|
|
||||||
cls = get_origin(typeannot) or typeannot
|
cls = get_origin(typeannot) or typeannot
|
||||||
|
|
||||||
|
if typeannot is Any:
|
||||||
|
return dlist
|
||||||
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
|
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
|
||||||
return dlist
|
return dlist
|
||||||
elif any(obj is None for obj in dlist):
|
if any(obj is None for obj in dlist):
|
||||||
# filter out Nones and recurse on the resulting list
|
# filter out Nones and recurse on the resulting list
|
||||||
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
|
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
|
||||||
idx, notnone = zip(*idx_notnone)
|
idx, notnone = zip(*idx_notnone)
|
||||||
@@ -180,8 +183,13 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
|||||||
for i, obj in zip(idx, converted):
|
for i, obj in zip(idx, converted):
|
||||||
res[i] = obj
|
res[i] = obj
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
is_optional, contained_type = _resolve_optional(typeannot)
|
||||||
|
if is_optional:
|
||||||
|
return _dataclass_list_from_dict_list(dlist, contained_type)
|
||||||
|
|
||||||
# otherwise, we dispatch by the type of the provided annotation to convert to
|
# otherwise, we dispatch by the type of the provided annotation to convert to
|
||||||
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||||
# For namedtuple, call the function recursively on the lists of corresponding keys
|
# For namedtuple, call the function recursively on the lists of corresponding keys
|
||||||
types = cls._field_types.values()
|
types = cls._field_types.values()
|
||||||
dlist_T = zip(*dlist)
|
dlist_T = zip(*dlist)
|
||||||
@@ -218,7 +226,7 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
|||||||
|
|
||||||
keys = np.split(list(all_keys_res), indices[:-1])
|
keys = np.split(list(all_keys_res), indices[:-1])
|
||||||
vals = np.split(list(all_vals_res), indices[:-1])
|
vals = np.split(list(all_vals_res), indices[:-1])
|
||||||
return [cls(zip(*k, v)) for k, v in zip(keys, vals)]
|
return [cls(zip(k, v)) for k, v in zip(keys, vals)]
|
||||||
elif not dataclasses.is_dataclass(typeannot):
|
elif not dataclasses.is_dataclass(typeannot):
|
||||||
return dlist
|
return dlist
|
||||||
|
|
||||||
@@ -240,10 +248,15 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
|||||||
|
|
||||||
|
|
||||||
def _dataclass_from_dict(d, typeannot):
|
def _dataclass_from_dict(d, typeannot):
|
||||||
cls = get_origin(typeannot) or typeannot
|
if d is None or typeannot is Any:
|
||||||
if d is None:
|
|
||||||
return d
|
return d
|
||||||
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
is_optional, contained_type = _resolve_optional(typeannot)
|
||||||
|
if is_optional:
|
||||||
|
# an Optional not set to None, just use the contents of the Optional.
|
||||||
|
return _dataclass_from_dict(d, contained_type)
|
||||||
|
|
||||||
|
cls = get_origin(typeannot) or typeannot
|
||||||
|
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||||
types = cls._field_types.values()
|
types = cls._field_types.values()
|
||||||
return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)])
|
return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)])
|
||||||
elif issubclass(cls, (list, tuple)):
|
elif issubclass(cls, (list, tuple)):
|
||||||
@@ -315,3 +328,15 @@ def load_dataclass_jgzip(outfile, cls):
|
|||||||
"""
|
"""
|
||||||
with gzip.GzipFile(outfile, "rb") as f:
|
with gzip.GzipFile(outfile, "rb") as f:
|
||||||
return load_dataclass(cast(IO, f), cls, binary=True)
|
return load_dataclass(cast(IO, f), cls, binary=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
|
||||||
|
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
|
||||||
|
if get_origin(type_) is Union:
|
||||||
|
args = get_args(type_)
|
||||||
|
if len(args) == 2 and args[1] == type(None): # noqa E721
|
||||||
|
return True, args[0]
|
||||||
|
if type_ is Any:
|
||||||
|
return True, Any
|
||||||
|
|
||||||
|
return False, type_
|
||||||
|
|||||||
@@ -16,6 +16,14 @@ DATASET_TYPE_KNOWN = "known"
|
|||||||
DATASET_TYPE_UNKNOWN = "unseen"
|
DATASET_TYPE_UNKNOWN = "unseen"
|
||||||
|
|
||||||
|
|
||||||
|
def is_known_frame_scalar(frame_type: str) -> bool:
|
||||||
|
"""
|
||||||
|
Given a single frame type corresponding to a single frame, return whether
|
||||||
|
the frame is a known frame.
|
||||||
|
"""
|
||||||
|
return frame_type.endswith(DATASET_TYPE_KNOWN)
|
||||||
|
|
||||||
|
|
||||||
def is_known_frame(
|
def is_known_frame(
|
||||||
frame_type: List[str], device: Optional[str] = None
|
frame_type: List[str], device: Optional[str] = None
|
||||||
) -> torch.BoolTensor:
|
) -> torch.BoolTensor:
|
||||||
@@ -23,8 +31,9 @@ def is_known_frame(
|
|||||||
Given a list `frame_type` of frame types in a batch, return a tensor
|
Given a list `frame_type` of frame types in a batch, return a tensor
|
||||||
of boolean flags expressing whether the corresponding frame is a known frame.
|
of boolean flags expressing whether the corresponding frame is a known frame.
|
||||||
"""
|
"""
|
||||||
|
# pyre-fixme[7]: Expected `BoolTensor` but got `Tensor`.
|
||||||
return torch.tensor(
|
return torch.tensor(
|
||||||
[ft.endswith(DATASET_TYPE_KNOWN) for ft in frame_type],
|
[is_known_frame_scalar(ft) for ft in frame_type],
|
||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
@@ -37,6 +46,7 @@ def is_train_frame(
|
|||||||
Given a list `frame_type` of frame types in a batch, return a tensor
|
Given a list `frame_type` of frame types in a batch, return a tensor
|
||||||
of boolean flags expressing whether the corresponding frame is a training frame.
|
of boolean flags expressing whether the corresponding frame is a training frame.
|
||||||
"""
|
"""
|
||||||
|
# pyre-fixme[7]: Expected `BoolTensor` but got `Tensor`.
|
||||||
return torch.tensor(
|
return torch.tensor(
|
||||||
[ft.startswith(DATASET_TYPE_TRAIN) for ft in frame_type],
|
[ft.startswith(DATASET_TYPE_TRAIN) for ft in frame_type],
|
||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
|
|||||||
@@ -10,11 +10,12 @@ import torch
|
|||||||
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
|
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
|
||||||
from pytorch3d.structures import Pointclouds
|
from pytorch3d.structures import Pointclouds
|
||||||
|
|
||||||
from .implicitron_dataset import FrameData, ImplicitronDataset
|
from .dataset_base import FrameData
|
||||||
|
from .json_index_dataset import JsonIndexDataset
|
||||||
|
|
||||||
|
|
||||||
def get_implicitron_sequence_pointcloud(
|
def get_implicitron_sequence_pointcloud(
|
||||||
dataset: ImplicitronDataset,
|
dataset: JsonIndexDataset,
|
||||||
sequence_name: Optional[str] = None,
|
sequence_name: Optional[str] = None,
|
||||||
mask_points: bool = True,
|
mask_points: bool = True,
|
||||||
max_frames: int = -1,
|
max_frames: int = -1,
|
||||||
@@ -43,6 +44,7 @@ def get_implicitron_sequence_pointcloud(
|
|||||||
sequence_entries = [
|
sequence_entries = [
|
||||||
ei
|
ei
|
||||||
for ei in sequence_entries
|
for ei in sequence_entries
|
||||||
|
# pyre-ignore[16]
|
||||||
if dataset.frame_annots[ei]["frame_annotation"].sequence_name
|
if dataset.frame_annots[ei]["frame_annotation"].sequence_name
|
||||||
== sequence_name
|
== sequence_name
|
||||||
]
|
]
|
||||||
@@ -67,7 +69,7 @@ def get_implicitron_sequence_pointcloud(
|
|||||||
batch_size=len(sequence_dataset),
|
batch_size=len(sequence_dataset),
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
collate_fn=FrameData.collate,
|
collate_fn=dataset.frame_data_type.collate,
|
||||||
)
|
)
|
||||||
|
|
||||||
frame_data = next(iter(loader)) # there's only one batch
|
frame_data = next(iter(loader)) # there's only one batch
|
||||||
|
|||||||
@@ -5,21 +5,17 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
from typing import cast, Optional
|
from typing import Any, cast, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import lpips
|
import lpips
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||||
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
|
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
|
||||||
FrameData,
|
CO3D_CATEGORIES,
|
||||||
ImplicitronDataset,
|
|
||||||
ImplicitronDatasetBase,
|
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.dataset.utils import is_known_frame
|
|
||||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
||||||
aggregate_nvs_results,
|
aggregate_nvs_results,
|
||||||
eval_batch,
|
eval_batch,
|
||||||
@@ -47,10 +43,12 @@ def main() -> None:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
task_results = {}
|
task_results = {}
|
||||||
for task in ("singlesequence", "multisequence"):
|
for task in (Task.SINGLE_SEQUENCE, Task.MULTI_SEQUENCE):
|
||||||
task_results[task] = []
|
task_results[task] = []
|
||||||
for category in CO3D_CATEGORIES[: (20 if task == "singlesequence" else 10)]:
|
for category in CO3D_CATEGORIES[: (20 if task == Task.SINGLE_SEQUENCE else 10)]:
|
||||||
for single_sequence_id in (0, 1) if task == "singlesequence" else (None,):
|
for single_sequence_id in (
|
||||||
|
(0, 1) if task == Task.SINGLE_SEQUENCE else (None,)
|
||||||
|
):
|
||||||
category_result = evaluate_dbir_for_category(
|
category_result = evaluate_dbir_for_category(
|
||||||
category, task=task, single_sequence_id=single_sequence_id
|
category, task=task, single_sequence_id=single_sequence_id
|
||||||
)
|
)
|
||||||
@@ -74,9 +72,9 @@ def main() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_dbir_for_category(
|
def evaluate_dbir_for_category(
|
||||||
category: str = "apple",
|
category: str,
|
||||||
bg_color: float = 0.0,
|
task: Task,
|
||||||
task: str = "singlesequence",
|
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0),
|
||||||
single_sequence_id: Optional[int] = None,
|
single_sequence_id: Optional[int] = None,
|
||||||
num_workers: int = 16,
|
num_workers: int = 16,
|
||||||
):
|
):
|
||||||
@@ -90,6 +88,7 @@ def evaluate_dbir_for_category(
|
|||||||
task: Evaluation task. Either singlesequence or multisequence.
|
task: Evaluation task. Either singlesequence or multisequence.
|
||||||
single_sequence_id: The ID of the evaluiation sequence for the singlesequence task.
|
single_sequence_id: The ID of the evaluiation sequence for the singlesequence task.
|
||||||
num_workers: The number of workers for the employed dataloaders.
|
num_workers: The number of workers for the employed dataloaders.
|
||||||
|
path_manager: (optional) Used for interpreting paths.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
category_result: A dictionary of quantitative metrics.
|
category_result: A dictionary of quantitative metrics.
|
||||||
@@ -99,46 +98,35 @@ def evaluate_dbir_for_category(
|
|||||||
|
|
||||||
torch.manual_seed(42)
|
torch.manual_seed(42)
|
||||||
|
|
||||||
if task not in ["multisequence", "singlesequence"]:
|
dataset_map_provider_args = {
|
||||||
raise ValueError("'task' has to be either 'multisequence' or 'singlesequence'")
|
"category": category,
|
||||||
|
"dataset_root": os.environ["CO3D_DATASET_ROOT"],
|
||||||
datasets = dataset_zoo(
|
"assert_single_seq": task == Task.SINGLE_SEQUENCE,
|
||||||
category=category,
|
"task_str": task.value,
|
||||||
dataset_root=os.environ["CO3D_DATASET_ROOT"],
|
"test_on_train": False,
|
||||||
assert_single_seq=task == "singlesequence",
|
"test_restrict_sequence_id": single_sequence_id,
|
||||||
dataset_name=f"co3d_{task}",
|
"dataset_JsonIndexDataset_args": {"load_point_clouds": True},
|
||||||
test_on_train=False,
|
}
|
||||||
load_point_clouds=True,
|
data_source = ImplicitronDataSource(
|
||||||
test_restrict_sequence_id=single_sequence_id,
|
dataset_map_provider_JsonIndexDatasetMapProvider_args=dataset_map_provider_args
|
||||||
)
|
)
|
||||||
|
|
||||||
dataloaders = dataloader_zoo(
|
datasets, dataloaders = data_source.get_datasets_and_dataloaders()
|
||||||
datasets,
|
|
||||||
dataset_name=f"co3d_{task}",
|
|
||||||
)
|
|
||||||
|
|
||||||
test_dataset = datasets["test"]
|
test_dataset = datasets.test
|
||||||
test_dataloader = dataloaders["test"]
|
test_dataloader = dataloaders.test
|
||||||
|
if test_dataset is None or test_dataloader is None:
|
||||||
|
raise ValueError("must have a test dataset.")
|
||||||
|
|
||||||
if task == "singlesequence":
|
image_size = cast(JsonIndexDataset, test_dataset).image_width
|
||||||
# all_source_cameras are needed for evaluation of the
|
|
||||||
# target camera difficulty
|
|
||||||
# pyre-fixme[16]: `ImplicitronDataset` has no attribute `frame_annots`.
|
|
||||||
sequence_name = test_dataset.frame_annots[0]["frame_annotation"].sequence_name
|
|
||||||
all_source_cameras = _get_all_source_cameras(
|
|
||||||
test_dataset, sequence_name, num_workers=num_workers
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
all_source_cameras = None
|
|
||||||
|
|
||||||
image_size = cast(ImplicitronDataset, test_dataset).image_width
|
|
||||||
|
|
||||||
if image_size is None:
|
if image_size is None:
|
||||||
raise ValueError("Image size should be set in the dataset")
|
raise ValueError("Image size should be set in the dataset")
|
||||||
|
|
||||||
# init the simple DBIR model
|
# init the simple DBIR model
|
||||||
model = ModelDBIR(
|
model = ModelDBIR( # pyre-ignore[28]: c’tor implicitly overridden
|
||||||
image_size=image_size,
|
render_image_width=image_size,
|
||||||
|
render_image_height=image_size,
|
||||||
bg_color=bg_color,
|
bg_color=bg_color,
|
||||||
max_points=int(1e5),
|
max_points=int(1e5),
|
||||||
)
|
)
|
||||||
@@ -153,25 +141,31 @@ def evaluate_dbir_for_category(
|
|||||||
for frame_data in tqdm(test_dataloader):
|
for frame_data in tqdm(test_dataloader):
|
||||||
frame_data = dataclass_to_cuda_(frame_data)
|
frame_data = dataclass_to_cuda_(frame_data)
|
||||||
preds = model(**dataclasses.asdict(frame_data))
|
preds = model(**dataclasses.asdict(frame_data))
|
||||||
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
|
|
||||||
per_batch_eval_results.append(
|
per_batch_eval_results.append(
|
||||||
eval_batch(
|
eval_batch(
|
||||||
frame_data,
|
frame_data,
|
||||||
nvs_prediction,
|
preds["implicitron_render"],
|
||||||
bg_color=bg_color,
|
bg_color=bg_color,
|
||||||
lpips_model=lpips_model,
|
lpips_model=lpips_model,
|
||||||
source_cameras=all_source_cameras,
|
source_cameras=data_source.all_train_cameras,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if task == Task.SINGLE_SEQUENCE:
|
||||||
|
camera_difficulty_bin_breaks = 0.97, 0.98
|
||||||
|
else:
|
||||||
|
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
|
||||||
|
|
||||||
category_result_flat, category_result = summarize_nvs_eval_results(
|
category_result_flat, category_result = summarize_nvs_eval_results(
|
||||||
per_batch_eval_results, task
|
per_batch_eval_results, task, camera_difficulty_bin_breaks
|
||||||
)
|
)
|
||||||
|
|
||||||
return category_result["results"]
|
return category_result["results"]
|
||||||
|
|
||||||
|
|
||||||
def _print_aggregate_results(task, task_results) -> None:
|
def _print_aggregate_results(
|
||||||
|
task: Task, task_results: Dict[Task, List[List[Dict[str, Any]]]]
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Prints the aggregate metrics for a given task.
|
Prints the aggregate metrics for a given task.
|
||||||
"""
|
"""
|
||||||
@@ -182,35 +176,5 @@ def _print_aggregate_results(task, task_results) -> None:
|
|||||||
print("")
|
print("")
|
||||||
|
|
||||||
|
|
||||||
def _get_all_source_cameras(
|
|
||||||
dataset: ImplicitronDatasetBase, sequence_name: str, num_workers: int = 8
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Loads all training cameras of a given sequence.
|
|
||||||
|
|
||||||
The set of all seen cameras is needed for evaluating the viewpoint difficulty
|
|
||||||
for the singlescene evaluation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: Co3D dataset object.
|
|
||||||
sequence_name: The name of the sequence.
|
|
||||||
num_workers: The number of for the utilized dataloader.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# load all source cameras of the sequence
|
|
||||||
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
|
|
||||||
dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
|
|
||||||
(all_frame_data,) = torch.utils.data.DataLoader(
|
|
||||||
dataset_for_loader,
|
|
||||||
shuffle=False,
|
|
||||||
batch_size=len(dataset_for_loader),
|
|
||||||
num_workers=num_workers,
|
|
||||||
collate_fn=FrameData.collate,
|
|
||||||
)
|
|
||||||
is_known = is_known_frame(all_frame_data.frame_type)
|
|
||||||
source_cameras = all_frame_data.camera[torch.where(is_known)[0]]
|
|
||||||
return source_cameras
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -9,12 +9,15 @@ import copy
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
|
import torch.nn.functional as F
|
||||||
|
from pytorch3d.implicitron.dataset.data_source import Task
|
||||||
|
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||||
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
|
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
|
||||||
|
from pytorch3d.implicitron.models.base_model import ImplicitronRender
|
||||||
from pytorch3d.implicitron.tools import vis_utils
|
from pytorch3d.implicitron.tools import vis_utils
|
||||||
from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps
|
from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps
|
||||||
from pytorch3d.implicitron.tools.image_utils import mask_background
|
from pytorch3d.implicitron.tools.image_utils import mask_background
|
||||||
@@ -31,18 +34,6 @@ from visdom import Visdom
|
|||||||
EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]
|
EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class NewViewSynthesisPrediction:
|
|
||||||
"""
|
|
||||||
Holds the tensors that describe a result of synthesizing new views.
|
|
||||||
"""
|
|
||||||
|
|
||||||
depth_render: Optional[torch.Tensor] = None
|
|
||||||
image_render: Optional[torch.Tensor] = None
|
|
||||||
mask_render: Optional[torch.Tensor] = None
|
|
||||||
camera_distance: Optional[torch.Tensor] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _Visualizer:
|
class _Visualizer:
|
||||||
image_render: torch.Tensor
|
image_render: torch.Tensor
|
||||||
@@ -145,14 +136,14 @@ class _Visualizer:
|
|||||||
|
|
||||||
def eval_batch(
|
def eval_batch(
|
||||||
frame_data: FrameData,
|
frame_data: FrameData,
|
||||||
nvs_prediction: NewViewSynthesisPrediction,
|
implicitron_render: ImplicitronRender,
|
||||||
bg_color: Union[torch.Tensor, str, float] = "black",
|
bg_color: Union[torch.Tensor, Sequence, str, float] = "black",
|
||||||
mask_thr: float = 0.5,
|
mask_thr: float = 0.5,
|
||||||
lpips_model=None,
|
lpips_model=None,
|
||||||
visualize: bool = False,
|
visualize: bool = False,
|
||||||
visualize_visdom_env: str = "eval_debug",
|
visualize_visdom_env: str = "eval_debug",
|
||||||
break_after_visualising: bool = True,
|
break_after_visualising: bool = True,
|
||||||
source_cameras: Optional[List[CamerasBase]] = None,
|
source_cameras: Optional[CamerasBase] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Produce performance metrics for a single batch of new-view synthesis
|
Produce performance metrics for a single batch of new-view synthesis
|
||||||
@@ -162,14 +153,14 @@ def eval_batch(
|
|||||||
is True), a new-view synthesis method (NVS) is tasked to generate new views
|
is True), a new-view synthesis method (NVS) is tasked to generate new views
|
||||||
of the scene from the viewpoint of the target views (for which
|
of the scene from the viewpoint of the target views (for which
|
||||||
frame_data.frame_type.endswith('known') is False). The resulting
|
frame_data.frame_type.endswith('known') is False). The resulting
|
||||||
synthesized new views, stored in `nvs_prediction`, are compared to the
|
synthesized new views, stored in `implicitron_render`, are compared to the
|
||||||
target ground truth in `frame_data` in terms of geometry and appearance
|
target ground truth in `frame_data` in terms of geometry and appearance
|
||||||
resulting in a dictionary of metrics returned by the `eval_batch` function.
|
resulting in a dictionary of metrics returned by the `eval_batch` function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
frame_data: A FrameData object containing the input to the new view
|
frame_data: A FrameData object containing the input to the new view
|
||||||
synthesis method.
|
synthesis method.
|
||||||
nvs_prediction: The data describing the synthesized new views.
|
implicitron_render: The data describing the synthesized new views.
|
||||||
bg_color: The background color of the generated new views and the
|
bg_color: The background color of the generated new views and the
|
||||||
ground truth.
|
ground truth.
|
||||||
lpips_model: A pre-trained model for evaluating the LPIPS metric.
|
lpips_model: A pre-trained model for evaluating the LPIPS metric.
|
||||||
@@ -184,26 +175,39 @@ def eval_batch(
|
|||||||
ValueError if frame_data does not have frame_type, camera, or image_rgb
|
ValueError if frame_data does not have frame_type, camera, or image_rgb
|
||||||
ValueError if the batch has a mix of training and test samples
|
ValueError if the batch has a mix of training and test samples
|
||||||
ValueError if the batch frames are not [unseen, known, known, ...]
|
ValueError if the batch frames are not [unseen, known, known, ...]
|
||||||
ValueError if one of the required fields in nvs_prediction is missing
|
ValueError if one of the required fields in implicitron_render is missing
|
||||||
"""
|
"""
|
||||||
REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render", "image_render", "depth_render"]
|
|
||||||
frame_type = frame_data.frame_type
|
frame_type = frame_data.frame_type
|
||||||
if frame_type is None:
|
if frame_type is None:
|
||||||
raise ValueError("Frame type has not been set.")
|
raise ValueError("Frame type has not been set.")
|
||||||
|
|
||||||
# we check that all those fields are not None but Pyre can't infer that properly
|
# we check that all those fields are not None but Pyre can't infer that properly
|
||||||
# TODO: assign to local variables
|
# TODO: assign to local variables and simplify the code.
|
||||||
if frame_data.image_rgb is None:
|
if frame_data.image_rgb is None:
|
||||||
raise ValueError("Image is not in the evaluation batch.")
|
raise ValueError("Image is not in the evaluation batch.")
|
||||||
|
|
||||||
if frame_data.camera is None:
|
if frame_data.camera is None:
|
||||||
raise ValueError("Camera is not in the evaluation batch.")
|
raise ValueError("Camera is not in the evaluation batch.")
|
||||||
|
|
||||||
if any(not hasattr(nvs_prediction, k) for k in REQUIRED_NVS_PREDICTION_FIELDS):
|
# eval all results in the resolution of the frame_data image
|
||||||
raise ValueError("One of the required predicted fields is missing")
|
image_resol = tuple(frame_data.image_rgb.shape[2:])
|
||||||
|
|
||||||
|
# Post-process the render:
|
||||||
|
# 1) check implicitron_render for Nones,
|
||||||
|
# 2) obtain copies to make sure we dont edit the original data,
|
||||||
|
# 3) take only the 1st (target) image
|
||||||
|
# 4) resize to match ground-truth resolution
|
||||||
|
cloned_render: Dict[str, torch.Tensor] = {}
|
||||||
|
for k in ["mask_render", "image_render", "depth_render"]:
|
||||||
|
field = getattr(implicitron_render, k)
|
||||||
|
if field is None:
|
||||||
|
raise ValueError(f"A required predicted field {k} is missing")
|
||||||
|
|
||||||
|
imode = "bilinear" if k == "image_render" else "nearest"
|
||||||
|
cloned_render[k] = (
|
||||||
|
F.interpolate(field[:1], size=image_resol, mode=imode).detach().clone()
|
||||||
|
)
|
||||||
|
|
||||||
# obtain copies to make sure we dont edit the original data
|
|
||||||
nvs_prediction = copy.deepcopy(nvs_prediction)
|
|
||||||
frame_data = copy.deepcopy(frame_data)
|
frame_data = copy.deepcopy(frame_data)
|
||||||
|
|
||||||
# mask the ground truth depth in case frame_data contains the depth mask
|
# mask the ground truth depth in case frame_data contains the depth mask
|
||||||
@@ -226,9 +230,6 @@ def eval_batch(
|
|||||||
+ " a target view while the rest should be source views."
|
+ " a target view while the rest should be source views."
|
||||||
) # TODO: do we need to enforce this?
|
) # TODO: do we need to enforce this?
|
||||||
|
|
||||||
# take only the first (target image)
|
|
||||||
for k in REQUIRED_NVS_PREDICTION_FIELDS:
|
|
||||||
setattr(nvs_prediction, k, getattr(nvs_prediction, k)[:1])
|
|
||||||
for k in [
|
for k in [
|
||||||
"depth_map",
|
"depth_map",
|
||||||
"image_rgb",
|
"image_rgb",
|
||||||
@@ -242,10 +243,6 @@ def eval_batch(
|
|||||||
if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0:
|
if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0:
|
||||||
warnings.warn("Empty or missing depth map in evaluation!")
|
warnings.warn("Empty or missing depth map in evaluation!")
|
||||||
|
|
||||||
# eval all results in the resolution of the frame_data image
|
|
||||||
# pyre-fixme[16]: `Optional` has no attribute `shape`.
|
|
||||||
image_resol = list(frame_data.image_rgb.shape[2:])
|
|
||||||
|
|
||||||
# threshold the masks to make ground truth binary masks
|
# threshold the masks to make ground truth binary masks
|
||||||
mask_fg, mask_crop = [
|
mask_fg, mask_crop = [
|
||||||
(getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop")
|
(getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop")
|
||||||
@@ -258,29 +255,14 @@ def eval_batch(
|
|||||||
bg_color=bg_color,
|
bg_color=bg_color,
|
||||||
)
|
)
|
||||||
|
|
||||||
# resize to the target resolution
|
|
||||||
for k in REQUIRED_NVS_PREDICTION_FIELDS:
|
|
||||||
imode = "bilinear" if k == "image_render" else "nearest"
|
|
||||||
val = getattr(nvs_prediction, k)
|
|
||||||
setattr(
|
|
||||||
nvs_prediction,
|
|
||||||
k,
|
|
||||||
# pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
|
|
||||||
# `List[typing.Any]`.
|
|
||||||
torch.nn.functional.interpolate(val, size=image_resol, mode=imode),
|
|
||||||
)
|
|
||||||
|
|
||||||
# clamp predicted images
|
# clamp predicted images
|
||||||
# pyre-fixme[16]: `Optional` has no attribute `clamp`.
|
image_render = cloned_render["image_render"].clamp(0.0, 1.0)
|
||||||
image_render = nvs_prediction.image_render.clamp(0.0, 1.0)
|
|
||||||
|
|
||||||
if visualize:
|
if visualize:
|
||||||
visualizer = _Visualizer(
|
visualizer = _Visualizer(
|
||||||
image_render=image_render,
|
image_render=image_render,
|
||||||
image_rgb_masked=image_rgb_masked,
|
image_rgb_masked=image_rgb_masked,
|
||||||
# pyre-fixme[6]: Expected `Tensor` for 3rd param but got
|
depth_render=cloned_render["depth_render"],
|
||||||
# `Optional[torch.Tensor]`.
|
|
||||||
depth_render=nvs_prediction.depth_render,
|
|
||||||
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
|
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
|
||||||
# `Optional[torch.Tensor]`.
|
# `Optional[torch.Tensor]`.
|
||||||
depth_map=frame_data.depth_map,
|
depth_map=frame_data.depth_map,
|
||||||
@@ -292,9 +274,7 @@ def eval_batch(
|
|||||||
results: Dict[str, Any] = {}
|
results: Dict[str, Any] = {}
|
||||||
|
|
||||||
results["iou"] = iou(
|
results["iou"] = iou(
|
||||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
cloned_render["mask_render"],
|
||||||
# `Optional[torch.Tensor]`.
|
|
||||||
nvs_prediction.mask_render,
|
|
||||||
mask_fg,
|
mask_fg,
|
||||||
mask=mask_crop,
|
mask=mask_crop,
|
||||||
)
|
)
|
||||||
@@ -321,11 +301,9 @@ def eval_batch(
|
|||||||
if name_postfix == "_fg":
|
if name_postfix == "_fg":
|
||||||
# only record depth metrics for the foreground
|
# only record depth metrics for the foreground
|
||||||
_, abs_ = eval_depth(
|
_, abs_ = eval_depth(
|
||||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
cloned_render["depth_render"],
|
||||||
# `Optional[torch.Tensor]`.
|
# pyre-fixme[6]: For 2nd param expected `Tensor` but got
|
||||||
nvs_prediction.depth_render,
|
# `Optional[Tensor]`.
|
||||||
# pyre-fixme[6]: Expected `Tensor` for 2nd param but got
|
|
||||||
# `Optional[torch.Tensor]`.
|
|
||||||
frame_data.depth_map,
|
frame_data.depth_map,
|
||||||
get_best_scale=True,
|
get_best_scale=True,
|
||||||
mask=loss_mask_now,
|
mask=loss_mask_now,
|
||||||
@@ -336,14 +314,14 @@ def eval_batch(
|
|||||||
if visualize:
|
if visualize:
|
||||||
visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now)
|
visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now)
|
||||||
if break_after_visualising:
|
if break_after_visualising:
|
||||||
import pdb
|
import pdb # noqa: B602
|
||||||
|
|
||||||
pdb.set_trace()
|
pdb.set_trace()
|
||||||
|
|
||||||
if lpips_model is not None:
|
if lpips_model is not None:
|
||||||
im1, im2 = [
|
im1, im2 = [
|
||||||
2.0 * im.clamp(0.0, 1.0) - 1.0
|
2.0 * im.clamp(0.0, 1.0) - 1.0
|
||||||
for im in (image_rgb_masked, nvs_prediction.image_render)
|
for im in (image_rgb_masked, cloned_render["image_render"])
|
||||||
]
|
]
|
||||||
results["lpips"] = lpips_model.forward(im1, im2).item()
|
results["lpips"] = lpips_model.forward(im1, im2).item()
|
||||||
|
|
||||||
@@ -426,30 +404,24 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tenso
|
|||||||
Returns:
|
Returns:
|
||||||
single-element Tensor
|
single-element Tensor
|
||||||
"""
|
"""
|
||||||
# pyre-ignore[16] topk not recognized
|
|
||||||
return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
|
return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
|
||||||
|
|
||||||
|
|
||||||
def get_camera_difficulty_bin_edges(task: str):
|
def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float, float]):
|
||||||
"""
|
"""
|
||||||
Get the edges of camera difficulty bins.
|
Get the edges of camera difficulty bins.
|
||||||
"""
|
"""
|
||||||
_eps = 1e-5
|
_eps = 1e-5
|
||||||
if task == "multisequence":
|
lower, upper = camera_difficulty_bin_breaks
|
||||||
# TODO: extract those to constants
|
diff_bin_edges = torch.tensor([0.0 - _eps, lower, upper, 1.0 + _eps]).float()
|
||||||
diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
|
|
||||||
diff_bin_edges[0] = 0.0 - _eps
|
|
||||||
elif task == "singlesequence":
|
|
||||||
diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"No such eval task {task}.")
|
|
||||||
diff_bin_names = ["hard", "medium", "easy"]
|
diff_bin_names = ["hard", "medium", "easy"]
|
||||||
return diff_bin_edges, diff_bin_names
|
return diff_bin_edges, diff_bin_names
|
||||||
|
|
||||||
|
|
||||||
def summarize_nvs_eval_results(
|
def summarize_nvs_eval_results(
|
||||||
per_batch_eval_results: List[Dict[str, Any]],
|
per_batch_eval_results: List[Dict[str, Any]],
|
||||||
task: str = "singlesequence",
|
task: Task,
|
||||||
|
camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Compile the per-batch evaluation results `per_batch_eval_results` into
|
Compile the per-batch evaluation results `per_batch_eval_results` into
|
||||||
@@ -458,7 +430,8 @@ def summarize_nvs_eval_results(
|
|||||||
Args:
|
Args:
|
||||||
per_batch_eval_results: Metrics of each per-batch evaluation.
|
per_batch_eval_results: Metrics of each per-batch evaluation.
|
||||||
task: The type of the new-view synthesis task.
|
task: The type of the new-view synthesis task.
|
||||||
Either 'singlesequence' or 'multisequence'.
|
camera_difficulty_bin_breaks: edge hard-medium and medium-easy
|
||||||
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
nvs_results_flat: A flattened dict of all aggregate metrics.
|
nvs_results_flat: A flattened dict of all aggregate metrics.
|
||||||
@@ -466,10 +439,10 @@ def summarize_nvs_eval_results(
|
|||||||
"""
|
"""
|
||||||
n_batches = len(per_batch_eval_results)
|
n_batches = len(per_batch_eval_results)
|
||||||
eval_sets: List[Optional[str]] = []
|
eval_sets: List[Optional[str]] = []
|
||||||
if task == "singlesequence":
|
if task == Task.SINGLE_SEQUENCE:
|
||||||
eval_sets = [None]
|
eval_sets = [None]
|
||||||
# assert n_batches==100
|
# assert n_batches==100
|
||||||
elif task == "multisequence":
|
elif task == Task.MULTI_SEQUENCE:
|
||||||
eval_sets = ["train", "test"]
|
eval_sets = ["train", "test"]
|
||||||
# assert n_batches==1000
|
# assert n_batches==1000
|
||||||
else:
|
else:
|
||||||
@@ -485,17 +458,19 @@ def summarize_nvs_eval_results(
|
|||||||
# init the result database dict
|
# init the result database dict
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
diff_bin_edges, diff_bin_names = get_camera_difficulty_bin_edges(task)
|
diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(
|
||||||
|
camera_difficulty_bin_breaks
|
||||||
|
)
|
||||||
n_diff_edges = diff_bin_edges.numel()
|
n_diff_edges = diff_bin_edges.numel()
|
||||||
|
|
||||||
# add per set averages
|
# add per set averages
|
||||||
for SET in eval_sets:
|
for SET in eval_sets:
|
||||||
if SET is None:
|
if SET is None:
|
||||||
# task=='singlesequence'
|
assert task == Task.SINGLE_SEQUENCE
|
||||||
ok_set = torch.ones(n_batches, dtype=torch.bool)
|
ok_set = torch.ones(n_batches, dtype=torch.bool)
|
||||||
set_name = "test"
|
set_name = "test"
|
||||||
else:
|
else:
|
||||||
# task=='multisequence'
|
assert task == Task.MULTI_SEQUENCE
|
||||||
ok_set = is_train == int(SET == "train")
|
ok_set = is_train == int(SET == "train")
|
||||||
set_name = SET
|
set_name = SET
|
||||||
|
|
||||||
@@ -520,7 +495,7 @@ def summarize_nvs_eval_results(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if task == "multisequence":
|
if task == Task.MULTI_SEQUENCE:
|
||||||
# split based on n_src_views
|
# split based on n_src_views
|
||||||
n_src_views = batch_sizes - 1
|
n_src_views = batch_sizes - 1
|
||||||
for n_src in EVAL_N_SRC_VIEWS:
|
for n_src in EVAL_N_SRC_VIEWS:
|
||||||
|
|||||||
88
pytorch3d/implicitron/models/base_model.py
Normal file
88
pytorch3d/implicitron/models/base_model.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||||
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
|
from .renderer.base import EvaluationMode
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ImplicitronRender:
|
||||||
|
"""
|
||||||
|
Holds the tensors that describe a result of rendering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
depth_render: Optional[torch.Tensor] = None
|
||||||
|
image_render: Optional[torch.Tensor] = None
|
||||||
|
mask_render: Optional[torch.Tensor] = None
|
||||||
|
camera_distance: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
def clone(self) -> "ImplicitronRender":
|
||||||
|
def safe_clone(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
||||||
|
return t.detach().clone() if t is not None else None
|
||||||
|
|
||||||
|
return ImplicitronRender(
|
||||||
|
depth_render=safe_clone(self.depth_render),
|
||||||
|
image_render=safe_clone(self.image_render),
|
||||||
|
mask_render=safe_clone(self.mask_render),
|
||||||
|
camera_distance=safe_clone(self.camera_distance),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImplicitronModelBase(ReplaceableBase):
|
||||||
|
"""
|
||||||
|
Replaceable abstract base for all image generation / rendering models.
|
||||||
|
`forward()` method produces a render with a depth map.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
*, # force keyword-only arguments
|
||||||
|
image_rgb: Optional[torch.Tensor],
|
||||||
|
camera: CamerasBase,
|
||||||
|
fg_probability: Optional[torch.Tensor],
|
||||||
|
mask_crop: Optional[torch.Tensor],
|
||||||
|
depth_map: Optional[torch.Tensor],
|
||||||
|
sequence_name: Optional[List[str]],
|
||||||
|
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images;
|
||||||
|
the first `min(B, n_train_target_views)` images are considered targets and
|
||||||
|
are used to supervise the renders; the rest corresponding to the source
|
||||||
|
viewpoints from which features will be extracted.
|
||||||
|
camera: An instance of CamerasBase containing a batch of `B` cameras corresponding
|
||||||
|
to the viewpoints of target images, from which the rays will be sampled,
|
||||||
|
and source images, which will be used for intersecting with target rays.
|
||||||
|
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
|
||||||
|
foreground masks.
|
||||||
|
mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid
|
||||||
|
regions in the input images (i.e. regions that do not correspond
|
||||||
|
to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
|
||||||
|
"mask_sample", rays will be sampled in the non zero regions.
|
||||||
|
depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
|
||||||
|
sequence_name: A list of `B` strings corresponding to the sequence names
|
||||||
|
from which images `image_rgb` were extracted. They are used to match
|
||||||
|
target frames with relevant source frames.
|
||||||
|
evaluation_mode: one of EvaluationMode.TRAINING or
|
||||||
|
EvaluationMode.EVALUATION which determines the settings used for
|
||||||
|
rendering.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
preds: A dictionary containing all outputs of the forward pass. All models should
|
||||||
|
output an instance of `ImplicitronRender` in `preds["implicitron_render"]`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from .feature_extractor import FeatureExtractorBase
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Base class for an extractor of a set of features from images.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def get_feat_dims(self) -> int:
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
total number of feature dimensions of the output.
|
||||||
|
(i.e. sum_i(dim_i))
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
imgs: Optional[torch.Tensor],
|
||||||
|
masks: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict[Any, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
imgs: A batch of input images of shape `(B, 3, H, W)`.
|
||||||
|
masks: A batch of input masks of shape `(B, 3, H, W)`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
out_feats: A dict `{f_i: t_i}` keyed by predicted feature names `f_i`
|
||||||
|
and their corresponding tensors `t_i` of shape `(B, dim_i, H_i, W_i)`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
@@ -4,7 +4,6 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import copy
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
@@ -12,7 +11,9 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as Fu
|
import torch.nn.functional as Fu
|
||||||
import torchvision
|
import torchvision
|
||||||
from pytorch3d.implicitron.tools.config import Configurable
|
from pytorch3d.implicitron.tools.config import registry
|
||||||
|
|
||||||
|
from . import FeatureExtractorBase
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -32,7 +33,8 @@ _RESNET_MEAN = [0.485, 0.456, 0.406]
|
|||||||
_RESNET_STD = [0.229, 0.224, 0.225]
|
_RESNET_STD = [0.229, 0.224, 0.225]
|
||||||
|
|
||||||
|
|
||||||
class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
@registry.register
|
||||||
|
class ResNetFeatureExtractor(FeatureExtractorBase):
|
||||||
"""
|
"""
|
||||||
Implements an image feature extractor. Depending on the settings allows
|
Implements an image feature extractor. Depending on the settings allows
|
||||||
to extract:
|
to extract:
|
||||||
@@ -141,14 +143,15 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
|||||||
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
|
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
|
||||||
return (img - self._resnet_mean) / self._resnet_std
|
return (img - self._resnet_mean) / self._resnet_std
|
||||||
|
|
||||||
def get_feat_dims(self, size_dict: bool = False):
|
def get_feat_dims(self) -> int:
|
||||||
if size_dict:
|
# pyre-fixme[29]
|
||||||
return copy.deepcopy(self._feat_dim)
|
|
||||||
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.values)[[Na...
|
|
||||||
return sum(self._feat_dim.values())
|
return sum(self._feat_dim.values())
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None
|
self,
|
||||||
|
imgs: Optional[torch.Tensor],
|
||||||
|
masks: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
) -> Dict[Any, torch.Tensor]:
|
) -> Dict[Any, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -163,23 +166,22 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
|||||||
out_feats = {}
|
out_feats = {}
|
||||||
|
|
||||||
imgs_input = imgs
|
imgs_input = imgs
|
||||||
if self.image_rescale != 1.0:
|
if self.image_rescale != 1.0 and imgs_input is not None:
|
||||||
imgs_resized = Fu.interpolate(
|
imgs_resized = Fu.interpolate(
|
||||||
imgs_input,
|
imgs_input,
|
||||||
# pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but
|
|
||||||
# got `float`.
|
|
||||||
scale_factor=self.image_rescale,
|
scale_factor=self.image_rescale,
|
||||||
mode="bilinear",
|
mode="bilinear",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
imgs_resized = imgs_input
|
imgs_resized = imgs_input
|
||||||
|
|
||||||
|
if len(self.stages) > 0:
|
||||||
|
assert imgs_resized is not None
|
||||||
|
|
||||||
if self.normalize_image:
|
if self.normalize_image:
|
||||||
imgs_normed = self._resnet_normalize_image(imgs_resized)
|
imgs_normed = self._resnet_normalize_image(imgs_resized)
|
||||||
else:
|
else:
|
||||||
imgs_normed = imgs_resized
|
imgs_normed = imgs_resized
|
||||||
|
|
||||||
if len(self.stages) > 0:
|
|
||||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]`
|
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]`
|
||||||
# is not a function.
|
# is not a function.
|
||||||
feats = self.stem(imgs_normed)
|
feats = self.stem(imgs_normed)
|
||||||
@@ -206,7 +208,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
|||||||
out_feats[MASK_FEATURE_NAME] = masks
|
out_feats[MASK_FEATURE_NAME] = masks
|
||||||
|
|
||||||
if self.add_images:
|
if self.add_images:
|
||||||
assert imgs_input is not None
|
assert imgs_resized is not None
|
||||||
out_feats[IMAGE_FEATURE_NAME] = imgs_resized
|
out_feats[IMAGE_FEATURE_NAME] = imgs_resized
|
||||||
|
|
||||||
if self.feature_rescale != 1.0:
|
if self.feature_rescale != 1.0:
|
||||||
@@ -5,26 +5,39 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
# Note: The #noqa comments below are for unused imports of pluggable implementations
|
||||||
|
# which are part of implicitron. They ensure that the registry is prepopulated.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
from pytorch3d.implicitron.models.metrics import ( # noqa
|
||||||
NewViewSynthesisPrediction,
|
RegularizationMetrics,
|
||||||
|
RegularizationMetricsBase,
|
||||||
|
ViewMetrics,
|
||||||
|
ViewMetricsBase,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.tools import image_utils, vis_utils
|
from pytorch3d.implicitron.tools import image_utils, vis_utils
|
||||||
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
|
from pytorch3d.implicitron.tools.config import (
|
||||||
|
expand_args_fields,
|
||||||
|
registry,
|
||||||
|
run_auto_creation,
|
||||||
|
)
|
||||||
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
|
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
|
||||||
from pytorch3d.implicitron.tools.utils import cat_dataclass
|
from pytorch3d.implicitron.tools.utils import cat_dataclass, setattr_if_hasattr
|
||||||
from pytorch3d.renderer import RayBundle, utils as rend_utils
|
from pytorch3d.renderer import RayBundle, utils as rend_utils
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from visdom import Visdom
|
from visdom import Visdom
|
||||||
|
|
||||||
from .autodecoder import Autodecoder
|
from .base_model import ImplicitronModelBase, ImplicitronRender
|
||||||
|
from .feature_extractor import FeatureExtractorBase
|
||||||
|
from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa
|
||||||
|
from .global_encoder.global_encoder import GlobalEncoderBase
|
||||||
from .implicit_function.base import ImplicitFunctionBase
|
from .implicit_function.base import ImplicitFunctionBase
|
||||||
from .implicit_function.idr_feature_field import IdrFeatureField # noqa
|
from .implicit_function.idr_feature_field import IdrFeatureField # noqa
|
||||||
from .implicit_function.neural_radiance_field import ( # noqa
|
from .implicit_function.neural_radiance_field import ( # noqa
|
||||||
@@ -35,7 +48,7 @@ from .implicit_function.scene_representation_networks import ( # noqa
|
|||||||
SRNHyperNetImplicitFunction,
|
SRNHyperNetImplicitFunction,
|
||||||
SRNImplicitFunction,
|
SRNImplicitFunction,
|
||||||
)
|
)
|
||||||
from .metrics import ViewMetrics
|
|
||||||
from .renderer.base import (
|
from .renderer.base import (
|
||||||
BaseRenderer,
|
BaseRenderer,
|
||||||
EvaluationMode,
|
EvaluationMode,
|
||||||
@@ -45,19 +58,16 @@ from .renderer.base import (
|
|||||||
)
|
)
|
||||||
from .renderer.lstm_renderer import LSTMRenderer # noqa
|
from .renderer.lstm_renderer import LSTMRenderer # noqa
|
||||||
from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
|
from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
|
||||||
from .renderer.ray_sampler import RaySampler
|
from .renderer.ray_sampler import RaySamplerBase
|
||||||
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
|
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
|
||||||
from .resnet_feature_extractor import ResNetFeatureExtractor
|
from .view_pooler.view_pooler import ViewPooler
|
||||||
from .view_pooling.feature_aggregation import FeatureAggregatorBase
|
|
||||||
from .view_pooling.view_sampling import ViewSampler
|
|
||||||
|
|
||||||
|
|
||||||
STD_LOG_VARS = ["objective", "epoch", "sec/it"]
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# pyre-ignore: 13
|
@registry.register
|
||||||
class GenericModel(Configurable, torch.nn.Module):
|
class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||||
"""
|
"""
|
||||||
GenericModel is a wrapper for the neural implicit
|
GenericModel is a wrapper for the neural implicit
|
||||||
rendering and reconstruction pipeline which consists
|
rendering and reconstruction pipeline which consists
|
||||||
@@ -98,6 +108,7 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
------------------
|
------------------
|
||||||
Evaluate the implicit function(s) at the sampled ray points
|
Evaluate the implicit function(s) at the sampled ray points
|
||||||
(optionally pass in the aggregated image features from (4)).
|
(optionally pass in the aggregated image features from (4)).
|
||||||
|
(also optionally pass in a global encoding from global_encoder).
|
||||||
│
|
│
|
||||||
▼
|
▼
|
||||||
(6) Rendering
|
(6) Rendering
|
||||||
@@ -116,7 +127,7 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
this sequence of steps. Currently, steps 1, 3, 4, 5, 6
|
this sequence of steps. Currently, steps 1, 3, 4, 5, 6
|
||||||
can be customized by intializing a subclass of the appropriate
|
can be customized by intializing a subclass of the appropriate
|
||||||
baseclass and adding the newly created module to the registry.
|
baseclass and adding the newly created module to the registry.
|
||||||
Please see https://github.com/fairinternal/pytorch3d/blob/co3d/projects/implicitron_trainer/README.md#custom-plugins
|
Please see https://github.com/facebookresearch/pytorch3d/blob/main/projects/implicitron_trainer/README.md#custom-plugins
|
||||||
for more details on how to create and register a custom component.
|
for more details on how to create and register a custom component.
|
||||||
|
|
||||||
In the config .yaml files for experiments, the parameters below are
|
In the config .yaml files for experiments, the parameters below are
|
||||||
@@ -138,9 +149,6 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
|
output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
|
||||||
splatting onto an image grid. Default: False.
|
splatting onto an image grid. Default: False.
|
||||||
bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0)
|
bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0)
|
||||||
view_pool: If True, features are sampled from the source image(s)
|
|
||||||
at the projected 2d locations of the sampled 3d ray points from the target
|
|
||||||
view(s), i.e. this activates step (3) above.
|
|
||||||
num_passes: The specified implicit_function is initialized num_passes
|
num_passes: The specified implicit_function is initialized num_passes
|
||||||
times and run sequentially.
|
times and run sequentially.
|
||||||
chunk_size_grid: The total number of points which can be rendered
|
chunk_size_grid: The total number of points which can be rendered
|
||||||
@@ -155,37 +163,49 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
sampling_mode_training: The sampling method to use during training. Must be
|
sampling_mode_training: The sampling method to use during training. Must be
|
||||||
a value from the RenderSamplingMode Enum.
|
a value from the RenderSamplingMode Enum.
|
||||||
sampling_mode_evaluation: Same as above but for evaluation.
|
sampling_mode_evaluation: Same as above but for evaluation.
|
||||||
sequence_autodecoder: An instance of `Autodecoder`. This is used to generate an encoding
|
global_encoder_class_type: The name of the class to use for global_encoder,
|
||||||
|
which must be available in the registry. Or `None` to disable global encoder.
|
||||||
|
global_encoder: An instance of `GlobalEncoder`. This is used to generate an encoding
|
||||||
of the image (referred to as the global_code) that can be used to model aspects of
|
of the image (referred to as the global_code) that can be used to model aspects of
|
||||||
the scene such as multiple objects or morphing objects. It is up to the implicit
|
the scene such as multiple objects or morphing objects. It is up to the implicit
|
||||||
function definition how to use it, but the most typical way is to broadcast and
|
function definition how to use it, but the most typical way is to broadcast and
|
||||||
concatenate to the other inputs for the implicit function.
|
concatenate to the other inputs for the implicit function.
|
||||||
|
raysampler_class_type: The name of the raysampler class which is available
|
||||||
|
in the global registry.
|
||||||
raysampler: An instance of RaySampler which is used to emit
|
raysampler: An instance of RaySampler which is used to emit
|
||||||
rays from the target view(s).
|
rays from the target view(s).
|
||||||
renderer_class_type: The name of the renderer class which is available in the global
|
renderer_class_type: The name of the renderer class which is available in the global
|
||||||
registry.
|
registry.
|
||||||
renderer: A renderer class which inherits from BaseRenderer. This is used to
|
renderer: A renderer class which inherits from BaseRenderer. This is used to
|
||||||
generate the images from the target view(s).
|
generate the images from the target view(s).
|
||||||
|
image_feature_extractor_class_type: If a str, constructs and enables
|
||||||
|
the `image_feature_extractor` object of this type. Or None if not needed.
|
||||||
image_feature_extractor: A module for extrating features from an input image.
|
image_feature_extractor: A module for extrating features from an input image.
|
||||||
view_sampler: An instance of ViewSampler which is used for sampling of
|
view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object.
|
||||||
|
This means features are sampled from the source image(s)
|
||||||
|
at the projected 2d locations of the sampled 3d ray points from the target
|
||||||
|
view(s), i.e. this activates step (3) above.
|
||||||
|
view_pooler: An instance of ViewPooler which is used for sampling of
|
||||||
image-based features at the 2D projections of a set
|
image-based features at the 2D projections of a set
|
||||||
of 3D points.
|
of 3D points and aggregating the sampled features.
|
||||||
feature_aggregator_class_type: The name of the feature aggregator class which
|
|
||||||
is available in the global registry.
|
|
||||||
feature_aggregator: A feature aggregator class which inherits from
|
|
||||||
FeatureAggregatorBase. Typically, the aggregated features and their
|
|
||||||
masks are output by a `ViewSampler` which samples feature tensors extracted
|
|
||||||
from a set of source images. FeatureAggregator executes step (4) above.
|
|
||||||
implicit_function_class_type: The type of implicit function to use which
|
implicit_function_class_type: The type of implicit function to use which
|
||||||
is available in the global registry.
|
is available in the global registry.
|
||||||
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
|
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
|
||||||
are initialised to be in self._implicit_functions.
|
are initialised to be in self._implicit_functions.
|
||||||
|
view_metrics: An instance of ViewMetricsBase used to compute loss terms which
|
||||||
|
are independent of the model's parameters.
|
||||||
|
view_metrics_class_type: The type of view metrics to use, must be available in
|
||||||
|
the global registry.
|
||||||
|
regularization_metrics: An instance of RegularizationMetricsBase used to compute
|
||||||
|
regularization terms which can depend on the model's parameters.
|
||||||
|
regularization_metrics_class_type: The type of regularization metrics to use,
|
||||||
|
must be available in the global registry.
|
||||||
loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
|
loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
|
||||||
for `ViewMetrics` class for available loss functions.
|
for `ViewMetrics` class for available loss functions.
|
||||||
log_vars: A list of variable names which should be logged.
|
log_vars: A list of variable names which should be logged.
|
||||||
The names should correspond to a subset of the keys of the
|
The names should correspond to a subset of the keys of the
|
||||||
dict `preds` output by the `forward` function.
|
dict `preds` output by the `forward` function.
|
||||||
"""
|
""" # noqa: B950
|
||||||
|
|
||||||
mask_images: bool = True
|
mask_images: bool = True
|
||||||
mask_depths: bool = True
|
mask_depths: bool = True
|
||||||
@@ -194,7 +214,6 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
mask_threshold: float = 0.5
|
mask_threshold: float = 0.5
|
||||||
output_rasterized_mc: bool = False
|
output_rasterized_mc: bool = False
|
||||||
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||||
view_pool: bool = False
|
|
||||||
num_passes: int = 1
|
num_passes: int = 1
|
||||||
chunk_size_grid: int = 4096
|
chunk_size_grid: int = 4096
|
||||||
render_features_dimensions: int = 3
|
render_features_dimensions: int = 3
|
||||||
@@ -204,23 +223,25 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
sampling_mode_training: str = "mask_sample"
|
sampling_mode_training: str = "mask_sample"
|
||||||
sampling_mode_evaluation: str = "full_grid"
|
sampling_mode_evaluation: str = "full_grid"
|
||||||
|
|
||||||
# ---- autodecoder settings
|
# ---- global encoder settings
|
||||||
sequence_autodecoder: Autodecoder
|
global_encoder_class_type: Optional[str] = None
|
||||||
|
global_encoder: Optional[GlobalEncoderBase]
|
||||||
|
|
||||||
# ---- raysampler
|
# ---- raysampler
|
||||||
raysampler: RaySampler
|
raysampler_class_type: str = "AdaptiveRaySampler"
|
||||||
|
raysampler: RaySamplerBase
|
||||||
|
|
||||||
# ---- renderer configs
|
# ---- renderer configs
|
||||||
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
|
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
|
||||||
renderer: BaseRenderer
|
renderer: BaseRenderer
|
||||||
|
|
||||||
# ---- view sampling settings - used if view_pool=True
|
# ---- image feature extractor settings
|
||||||
# (This is only created if view_pool is False)
|
# (This is only created if view_pooler is enabled)
|
||||||
image_feature_extractor: ResNetFeatureExtractor
|
image_feature_extractor: Optional[FeatureExtractorBase]
|
||||||
view_sampler: ViewSampler
|
image_feature_extractor_class_type: Optional[str] = None
|
||||||
# ---- ---- view sampling feature aggregator settings
|
# ---- view pooler settings
|
||||||
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
|
view_pooler_enabled: bool = False
|
||||||
feature_aggregator: FeatureAggregatorBase
|
view_pooler: Optional[ViewPooler]
|
||||||
|
|
||||||
# ---- implicit function settings
|
# ---- implicit function settings
|
||||||
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
|
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
|
||||||
@@ -228,6 +249,13 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
# The actual implicit functions live in self._implicit_functions
|
# The actual implicit functions live in self._implicit_functions
|
||||||
implicit_function: ImplicitFunctionBase
|
implicit_function: ImplicitFunctionBase
|
||||||
|
|
||||||
|
# ----- metrics
|
||||||
|
view_metrics: ViewMetricsBase
|
||||||
|
view_metrics_class_type: str = "ViewMetrics"
|
||||||
|
|
||||||
|
regularization_metrics: RegularizationMetricsBase
|
||||||
|
regularization_metrics_class_type: str = "RegularizationMetrics"
|
||||||
|
|
||||||
# ---- loss weights
|
# ---- loss weights
|
||||||
loss_weights: Dict[str, float] = field(
|
loss_weights: Dict[str, float] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
@@ -259,19 +287,21 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
"loss_prev_stage_rgb_psnr_fg",
|
"loss_prev_stage_rgb_psnr_fg",
|
||||||
"loss_prev_stage_rgb_psnr",
|
"loss_prev_stage_rgb_psnr",
|
||||||
"loss_prev_stage_mask_bce",
|
"loss_prev_stage_mask_bce",
|
||||||
*STD_LOG_VARS,
|
# basic metrics
|
||||||
|
"objective",
|
||||||
|
"epoch",
|
||||||
|
"sec/it",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.view_metrics = ViewMetrics()
|
|
||||||
|
|
||||||
self._check_and_preprocess_renderer_configs()
|
if self.view_pooler_enabled:
|
||||||
self.raysampler_args["sampling_mode_training"] = self.sampling_mode_training
|
if self.image_feature_extractor_class_type is None:
|
||||||
self.raysampler_args["sampling_mode_evaluation"] = self.sampling_mode_evaluation
|
raise ValueError(
|
||||||
self.raysampler_args["image_width"] = self.render_image_width
|
"image_feature_extractor must be present for view pooling."
|
||||||
self.raysampler_args["image_height"] = self.render_image_height
|
)
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
self._implicit_functions = self._construct_implicit_functions()
|
self._implicit_functions = self._construct_implicit_functions()
|
||||||
@@ -283,10 +313,11 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
*, # force keyword-only arguments
|
*, # force keyword-only arguments
|
||||||
image_rgb: Optional[torch.Tensor],
|
image_rgb: Optional[torch.Tensor],
|
||||||
camera: CamerasBase,
|
camera: CamerasBase,
|
||||||
fg_probability: Optional[torch.Tensor],
|
fg_probability: Optional[torch.Tensor] = None,
|
||||||
mask_crop: Optional[torch.Tensor],
|
mask_crop: Optional[torch.Tensor] = None,
|
||||||
depth_map: Optional[torch.Tensor],
|
depth_map: Optional[torch.Tensor] = None,
|
||||||
sequence_name: Optional[List[str]],
|
sequence_name: Optional[List[str]] = None,
|
||||||
|
frame_timestamp: Optional[torch.Tensor] = None,
|
||||||
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
@@ -309,6 +340,8 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
sequence_name: A list of `B` strings corresponding to the sequence names
|
sequence_name: A list of `B` strings corresponding to the sequence names
|
||||||
from which images `image_rgb` were extracted. They are used to match
|
from which images `image_rgb` were extracted. They are used to match
|
||||||
target frames with relevant source frames.
|
target frames with relevant source frames.
|
||||||
|
frame_timestamp: Optionally a tensor of shape `(B,)` containing a batch
|
||||||
|
of frame timestamps.
|
||||||
evaluation_mode: one of EvaluationMode.TRAINING or
|
evaluation_mode: one of EvaluationMode.TRAINING or
|
||||||
EvaluationMode.EVALUATION which determines the settings used for
|
EvaluationMode.EVALUATION which determines the settings used for
|
||||||
rendering.
|
rendering.
|
||||||
@@ -333,6 +366,13 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
else min(self.n_train_target_views, batch_size)
|
else min(self.n_train_target_views, batch_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# A helper function for selecting n_target first elements from the input
|
||||||
|
# where the latter can be None.
|
||||||
|
def safe_slice_targets(
|
||||||
|
tensor: Optional[Union[torch.Tensor, List[str]]],
|
||||||
|
) -> Optional[Union[torch.Tensor, List[str]]]:
|
||||||
|
return None if tensor is None else tensor[:n_targets]
|
||||||
|
|
||||||
# Select the target cameras.
|
# Select the target cameras.
|
||||||
target_cameras = camera[list(range(n_targets))]
|
target_cameras = camera[list(range(n_targets))]
|
||||||
|
|
||||||
@@ -344,7 +384,7 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# (1) Sample rendering rays with the ray sampler.
|
# (1) Sample rendering rays with the ray sampler.
|
||||||
ray_bundle: RayBundle = self.raysampler(
|
ray_bundle: RayBundle = self.raysampler( # pyre-fixme[29]
|
||||||
target_cameras,
|
target_cameras,
|
||||||
evaluation_mode,
|
evaluation_mode,
|
||||||
mask=mask_crop[:n_targets]
|
mask=mask_crop[:n_targets]
|
||||||
@@ -355,38 +395,37 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
# custom_args hold additional arguments to the implicit function.
|
# custom_args hold additional arguments to the implicit function.
|
||||||
custom_args = {}
|
custom_args = {}
|
||||||
|
|
||||||
if self.view_pool:
|
if self.image_feature_extractor is not None:
|
||||||
if sequence_name is None:
|
|
||||||
raise ValueError("sequence_name must be provided for view pooling")
|
|
||||||
# (2) Extract features for the image
|
# (2) Extract features for the image
|
||||||
img_feats = self.image_feature_extractor(image_rgb, fg_probability)
|
img_feats = self.image_feature_extractor(image_rgb, fg_probability)
|
||||||
|
else:
|
||||||
|
img_feats = None
|
||||||
|
|
||||||
# (3) Sample features and masks at the ray points
|
if self.view_pooler_enabled:
|
||||||
curried_view_sampler = lambda pts: self.view_sampler( # noqa: E731
|
if sequence_name is None:
|
||||||
|
raise ValueError("sequence_name must be provided for view pooling")
|
||||||
|
assert img_feats is not None
|
||||||
|
|
||||||
|
# (3-4) Sample features and masks at the ray points.
|
||||||
|
# Aggregate features from multiple views.
|
||||||
|
def curried_viewpooler(pts):
|
||||||
|
return self.view_pooler(
|
||||||
pts=pts,
|
pts=pts,
|
||||||
seq_id_pts=sequence_name[:n_targets],
|
seq_id_pts=sequence_name[:n_targets],
|
||||||
camera=camera,
|
camera=camera,
|
||||||
seq_id_camera=sequence_name,
|
seq_id_camera=sequence_name,
|
||||||
feats=img_feats,
|
feats=img_feats,
|
||||||
masks=mask_crop,
|
masks=mask_crop,
|
||||||
) # returns feats_sampled, masks_sampled
|
)
|
||||||
|
|
||||||
# (4) Aggregate features from multiple views
|
custom_args["fun_viewpool"] = curried_viewpooler
|
||||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
|
||||||
curried_view_pool = lambda pts: self.feature_aggregator( # noqa: E731
|
|
||||||
*curried_view_sampler(pts=pts),
|
|
||||||
pts=pts,
|
|
||||||
camera=camera,
|
|
||||||
) # TODO: do we need to pass a callback rather than compute here?
|
|
||||||
# precomputing will be faster for 2 passes
|
|
||||||
# -> but this is important for non-nerf
|
|
||||||
custom_args["fun_viewpool"] = curried_view_pool
|
|
||||||
|
|
||||||
global_code = None
|
global_code = None
|
||||||
if self.sequence_autodecoder.n_instances > 0:
|
if self.global_encoder is not None:
|
||||||
if sequence_name is None:
|
global_code = self.global_encoder( # pyre-fixme[29]
|
||||||
raise ValueError("sequence_name must be provided for autodecoder.")
|
sequence_name=safe_slice_targets(sequence_name),
|
||||||
global_code = self.sequence_autodecoder(sequence_name[:n_targets])
|
frame_timestamp=safe_slice_targets(frame_timestamp),
|
||||||
|
)
|
||||||
custom_args["global_code"] = global_code
|
custom_args["global_code"] = global_code
|
||||||
|
|
||||||
# pyre-fixme[29]:
|
# pyre-fixme[29]:
|
||||||
@@ -422,15 +461,26 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
for func in self._implicit_functions:
|
for func in self._implicit_functions:
|
||||||
func.unbind_args()
|
func.unbind_args()
|
||||||
|
|
||||||
preds = self._get_view_metrics(
|
# A dict to store losses as well as rendering results.
|
||||||
|
preds: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
preds.update(
|
||||||
|
self.view_metrics(
|
||||||
|
results=preds,
|
||||||
raymarched=rendered,
|
raymarched=rendered,
|
||||||
xys=ray_bundle.xys,
|
xys=ray_bundle.xys,
|
||||||
image_rgb=None if image_rgb is None else image_rgb[:n_targets],
|
image_rgb=safe_slice_targets(image_rgb),
|
||||||
depth_map=None if depth_map is None else depth_map[:n_targets],
|
depth_map=safe_slice_targets(depth_map),
|
||||||
fg_probability=None
|
fg_probability=safe_slice_targets(fg_probability),
|
||||||
if fg_probability is None
|
mask_crop=safe_slice_targets(mask_crop),
|
||||||
else fg_probability[:n_targets],
|
)
|
||||||
mask_crop=None if mask_crop is None else mask_crop[:n_targets],
|
)
|
||||||
|
|
||||||
|
preds.update(
|
||||||
|
self.regularization_metrics(
|
||||||
|
results=preds,
|
||||||
|
model=self,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if sampling_mode == RenderSamplingMode.MASK_SAMPLE:
|
if sampling_mode == RenderSamplingMode.MASK_SAMPLE:
|
||||||
@@ -452,7 +502,7 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2)
|
preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2)
|
||||||
preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2)
|
preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2)
|
||||||
|
|
||||||
preds["nvs_prediction"] = NewViewSynthesisPrediction(
|
preds["implicitron_render"] = ImplicitronRender(
|
||||||
image_render=preds["images_render"],
|
image_render=preds["images_render"],
|
||||||
depth_render=preds["depths_render"],
|
depth_render=preds["depths_render"],
|
||||||
mask_render=preds["masks_render"],
|
mask_render=preds["masks_render"],
|
||||||
@@ -460,11 +510,6 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
else:
|
else:
|
||||||
raise AssertionError("Unreachable state")
|
raise AssertionError("Unreachable state")
|
||||||
|
|
||||||
# calc the AD penalty, returns None if autodecoder is not active
|
|
||||||
ad_penalty = self.sequence_autodecoder.calc_squared_encoding_norm()
|
|
||||||
if ad_penalty is not None:
|
|
||||||
preds["loss_autodecoder_norm"] = ad_penalty
|
|
||||||
|
|
||||||
# (7) Compute losses
|
# (7) Compute losses
|
||||||
# finally get the optimization objective using self.loss_weights
|
# finally get the optimization objective using self.loss_weights
|
||||||
objective = self._get_objective(preds)
|
objective = self._get_objective(preds)
|
||||||
@@ -559,37 +604,64 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_viewpooled_feature_dim(self):
|
def _get_global_encoder_encoding_dim(self) -> int:
|
||||||
return (
|
if self.global_encoder is None:
|
||||||
self.feature_aggregator.get_aggregated_feature_dim(
|
return 0
|
||||||
|
return self.global_encoder.get_encoding_dim()
|
||||||
|
|
||||||
|
def _get_viewpooled_feature_dim(self) -> int:
|
||||||
|
if self.view_pooler is None:
|
||||||
|
return 0
|
||||||
|
assert self.image_feature_extractor is not None
|
||||||
|
return self.view_pooler.get_aggregated_feature_dim(
|
||||||
self.image_feature_extractor.get_feat_dims()
|
self.image_feature_extractor.get_feat_dims()
|
||||||
)
|
)
|
||||||
if self.view_pool
|
|
||||||
else 0
|
def create_raysampler(self):
|
||||||
|
raysampler_args = getattr(
|
||||||
|
self, "raysampler_" + self.raysampler_class_type + "_args"
|
||||||
|
)
|
||||||
|
setattr_if_hasattr(
|
||||||
|
raysampler_args, "sampling_mode_training", self.sampling_mode_training
|
||||||
|
)
|
||||||
|
setattr_if_hasattr(
|
||||||
|
raysampler_args, "sampling_mode_evaluation", self.sampling_mode_evaluation
|
||||||
|
)
|
||||||
|
setattr_if_hasattr(raysampler_args, "image_width", self.render_image_width)
|
||||||
|
setattr_if_hasattr(raysampler_args, "image_height", self.render_image_height)
|
||||||
|
self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
|
||||||
|
**raysampler_args
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_and_preprocess_renderer_configs(self):
|
def create_renderer(self):
|
||||||
|
raysampler_args = getattr(
|
||||||
|
self, "raysampler_" + self.raysampler_class_type + "_args"
|
||||||
|
)
|
||||||
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
|
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
|
||||||
"stratified_sampling_coarse_training"
|
"stratified_sampling_coarse_training"
|
||||||
] = self.raysampler_args["stratified_point_sampling_training"]
|
] = raysampler_args["stratified_point_sampling_training"]
|
||||||
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
|
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
|
||||||
"stratified_sampling_coarse_evaluation"
|
"stratified_sampling_coarse_evaluation"
|
||||||
] = self.raysampler_args["stratified_point_sampling_evaluation"]
|
] = raysampler_args["stratified_point_sampling_evaluation"]
|
||||||
self.renderer_SignedDistanceFunctionRenderer_args[
|
self.renderer_SignedDistanceFunctionRenderer_args[
|
||||||
"render_features_dimensions"
|
"render_features_dimensions"
|
||||||
] = self.render_features_dimensions
|
] = self.render_features_dimensions
|
||||||
|
|
||||||
|
if self.renderer_class_type == "SignedDistanceFunctionRenderer":
|
||||||
|
if "scene_extent" not in raysampler_args:
|
||||||
|
raise ValueError(
|
||||||
|
"SignedDistanceFunctionRenderer requires"
|
||||||
|
+ " a raysampler that defines the 'scene_extent' field"
|
||||||
|
+ " (this field is supported by, e.g., the adaptive raysampler - "
|
||||||
|
+ " self.raysampler_class_type='AdaptiveRaySampler')."
|
||||||
|
)
|
||||||
self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
|
self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
|
||||||
"object_bounding_sphere"
|
"object_bounding_sphere"
|
||||||
] = self.raysampler_args["scene_extent"]
|
] = self.raysampler_AdaptiveRaySampler_args["scene_extent"]
|
||||||
|
|
||||||
def create_image_feature_extractor(self):
|
renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
|
||||||
"""
|
self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
|
||||||
Custom creation function called by run_auto_creation so that the
|
**renderer_args
|
||||||
image_feature_extractor is not created if it is not be needed.
|
|
||||||
"""
|
|
||||||
if self.view_pool:
|
|
||||||
self.image_feature_extractor = ResNetFeatureExtractor(
|
|
||||||
**self.image_feature_extractor_args
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_implicit_function(self) -> None:
|
def create_implicit_function(self) -> None:
|
||||||
@@ -613,8 +685,7 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
|
nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
|
||||||
nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
|
nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
|
||||||
nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
|
nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
|
||||||
self._get_viewpooled_feature_dim()
|
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
|
||||||
+ self.sequence_autodecoder.get_encoding_dim()
|
|
||||||
)
|
)
|
||||||
nerf_args["color_dim"] = nerformer_args[
|
nerf_args["color_dim"] = nerformer_args[
|
||||||
"color_dim"
|
"color_dim"
|
||||||
@@ -623,27 +694,25 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
# idr preprocessing
|
# idr preprocessing
|
||||||
idr = self.implicit_function_IdrFeatureField_args
|
idr = self.implicit_function_IdrFeatureField_args
|
||||||
idr["feature_vector_size"] = self.render_features_dimensions
|
idr["feature_vector_size"] = self.render_features_dimensions
|
||||||
idr["encoding_dim"] = self.sequence_autodecoder.get_encoding_dim()
|
idr["encoding_dim"] = self._get_global_encoder_encoding_dim()
|
||||||
|
|
||||||
# srn preprocessing
|
# srn preprocessing
|
||||||
srn = self.implicit_function_SRNImplicitFunction_args
|
srn = self.implicit_function_SRNImplicitFunction_args
|
||||||
srn.raymarch_function_args.latent_dim = (
|
srn.raymarch_function_args.latent_dim = (
|
||||||
self._get_viewpooled_feature_dim()
|
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
|
||||||
+ self.sequence_autodecoder.get_encoding_dim()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# srn_hypernet preprocessing
|
# srn_hypernet preprocessing
|
||||||
srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
|
srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
|
||||||
srn_hypernet_args = srn_hypernet.hypernet_args
|
srn_hypernet_args = srn_hypernet.hypernet_args
|
||||||
srn_hypernet_args.latent_dim_hypernet = (
|
srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim()
|
||||||
self.sequence_autodecoder.get_encoding_dim()
|
|
||||||
)
|
|
||||||
srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
|
srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
|
||||||
|
|
||||||
# check that for srn, srn_hypernet, idr we have self.num_passes=1
|
# check that for srn, srn_hypernet, idr we have self.num_passes=1
|
||||||
implicit_function_type = registry.get(
|
implicit_function_type = registry.get(
|
||||||
ImplicitFunctionBase, self.implicit_function_class_type
|
ImplicitFunctionBase, self.implicit_function_class_type
|
||||||
)
|
)
|
||||||
|
expand_args_fields(implicit_function_type)
|
||||||
if self.num_passes != 1 and not implicit_function_type.allows_multiple_passes():
|
if self.num_passes != 1 and not implicit_function_type.allows_multiple_passes():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
self.implicit_function_class_type
|
self.implicit_function_class_type
|
||||||
@@ -651,10 +720,9 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if implicit_function_type.requires_pooling_without_aggregation():
|
if implicit_function_type.requires_pooling_without_aggregation():
|
||||||
has_aggregation = hasattr(self.feature_aggregator, "reduction_functions")
|
if self.view_pooler_enabled and self.view_pooler.has_aggregation():
|
||||||
if not self.view_pool or has_aggregation:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Chosen implicit function requires view pooling without aggregation."
|
"The chosen implicit function requires view pooling without aggregation."
|
||||||
)
|
)
|
||||||
config_name = f"implicit_function_{self.implicit_function_class_type}_args"
|
config_name = f"implicit_function_{self.implicit_function_class_type}_args"
|
||||||
config = getattr(self, config_name, None)
|
config = getattr(self, config_name, None)
|
||||||
@@ -697,6 +765,17 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Modified image_rgb, fg_mask, depth_map
|
Modified image_rgb, fg_mask, depth_map
|
||||||
"""
|
"""
|
||||||
|
if image_rgb is not None and image_rgb.ndim == 3:
|
||||||
|
# The FrameData object is used for both frames and batches of frames,
|
||||||
|
# and a user might get this error if those were confused.
|
||||||
|
# Perhaps a user has a FrameData `fd` representing a single frame and
|
||||||
|
# wrote something like `model(**fd)` instead of
|
||||||
|
# `model(**fd.collate([fd]))`.
|
||||||
|
raise ValueError(
|
||||||
|
"Model received unbatched inputs. "
|
||||||
|
+ "Perhaps they came from a FrameData which had not been collated."
|
||||||
|
)
|
||||||
|
|
||||||
fg_mask = fg_probability
|
fg_mask = fg_probability
|
||||||
if fg_mask is not None and self.mask_threshold > 0.0:
|
if fg_mask is not None and self.mask_threshold > 0.0:
|
||||||
# threshold masks
|
# threshold masks
|
||||||
@@ -720,45 +799,6 @@ class GenericModel(Configurable, torch.nn.Module):
|
|||||||
|
|
||||||
return image_rgb, fg_mask, depth_map
|
return image_rgb, fg_mask, depth_map
|
||||||
|
|
||||||
def _get_view_metrics(
|
|
||||||
self,
|
|
||||||
raymarched: RendererOutput,
|
|
||||||
xys: torch.Tensor,
|
|
||||||
image_rgb: Optional[torch.Tensor] = None,
|
|
||||||
depth_map: Optional[torch.Tensor] = None,
|
|
||||||
fg_probability: Optional[torch.Tensor] = None,
|
|
||||||
mask_crop: Optional[torch.Tensor] = None,
|
|
||||||
keys_prefix: str = "loss_",
|
|
||||||
):
|
|
||||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
|
||||||
metrics = self.view_metrics(
|
|
||||||
image_sampling_grid=xys,
|
|
||||||
images_pred=raymarched.features,
|
|
||||||
images=image_rgb,
|
|
||||||
depths_pred=raymarched.depths,
|
|
||||||
depths=depth_map,
|
|
||||||
masks_pred=raymarched.masks,
|
|
||||||
masks=fg_probability,
|
|
||||||
masks_crop=mask_crop,
|
|
||||||
keys_prefix=keys_prefix,
|
|
||||||
**raymarched.aux,
|
|
||||||
)
|
|
||||||
|
|
||||||
if raymarched.prev_stage:
|
|
||||||
metrics.update(
|
|
||||||
self._get_view_metrics(
|
|
||||||
raymarched.prev_stage,
|
|
||||||
xys,
|
|
||||||
image_rgb,
|
|
||||||
depth_map,
|
|
||||||
fg_probability,
|
|
||||||
mask_crop,
|
|
||||||
keys_prefix=(keys_prefix + "prev_stage_"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return metrics
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def _rasterize_mc_samples(
|
def _rasterize_mc_samples(
|
||||||
self,
|
self,
|
||||||
5
pytorch3d/implicitron/models/global_encoder/__init__.py
Normal file
5
pytorch3d/implicitron/models/global_encoder/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
@@ -12,10 +12,9 @@ import torch
|
|||||||
from pytorch3d.implicitron.tools.config import Configurable
|
from pytorch3d.implicitron.tools.config import Configurable
|
||||||
|
|
||||||
|
|
||||||
# TODO: probabilistic embeddings?
|
|
||||||
class Autodecoder(Configurable, torch.nn.Module):
|
class Autodecoder(Configurable, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Autodecoder module
|
Autodecoder which maps a list of integer or string keys to optimizable embeddings.
|
||||||
|
|
||||||
Settings:
|
Settings:
|
||||||
encoding_dim: Embedding dimension for the decoder.
|
encoding_dim: Embedding dimension for the decoder.
|
||||||
@@ -43,37 +42,37 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
# weight has been initialised from Normal(0, 1)
|
# weight has been initialised from Normal(0, 1)
|
||||||
self._autodecoder_codes.weight *= self.init_scale
|
self._autodecoder_codes.weight *= self.init_scale
|
||||||
|
|
||||||
self._sequence_map = self._build_sequence_map()
|
self._key_map = self._build_key_map()
|
||||||
# Make sure to register hooks for correct handling of saving/loading
|
# Make sure to register hooks for correct handling of saving/loading
|
||||||
# the module's _sequence_map.
|
# the module's _key_map.
|
||||||
self._register_load_state_dict_pre_hook(self._load_sequence_map_hook)
|
self._register_load_state_dict_pre_hook(self._load_key_map_hook)
|
||||||
self._register_state_dict_hook(_save_sequence_map_hook)
|
self._register_state_dict_hook(_save_key_map_hook)
|
||||||
|
|
||||||
def _build_sequence_map(
|
def _build_key_map(
|
||||||
self, sequence_map_dict: Optional[Dict[str, int]] = None
|
self, key_map_dict: Optional[Dict[str, int]] = None
|
||||||
) -> Dict[str, int]:
|
) -> Dict[str, int]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
sequence_map_dict: A dictionary used to initialize the sequence_map.
|
key_map_dict: A dictionary used to initialize the key_map.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
sequence_map: a dictionary of key: id pairs.
|
key_map: a dictionary of key: id pairs.
|
||||||
"""
|
"""
|
||||||
# increments the counter when asked for a new value
|
# increments the counter when asked for a new value
|
||||||
sequence_map = defaultdict(iter(range(self.n_instances)).__next__)
|
key_map = defaultdict(iter(range(self.n_instances)).__next__)
|
||||||
if sequence_map_dict is not None:
|
if key_map_dict is not None:
|
||||||
# Assign all keys from the loaded sequence_map_dict to self._sequence_map.
|
# Assign all keys from the loaded key_map_dict to self._key_map.
|
||||||
# Since this is done in the original order, it should generate
|
# Since this is done in the original order, it should generate
|
||||||
# the same set of key:id pairs. We check this with an assert to be sure.
|
# the same set of key:id pairs. We check this with an assert to be sure.
|
||||||
for x, x_id in sequence_map_dict.items():
|
for x, x_id in key_map_dict.items():
|
||||||
x_id_ = sequence_map[x]
|
x_id_ = key_map[x]
|
||||||
assert x_id == x_id_
|
assert x_id == x_id_
|
||||||
return sequence_map
|
return key_map
|
||||||
|
|
||||||
def calc_squared_encoding_norm(self):
|
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||||
if self.n_instances <= 0:
|
if self.n_instances <= 0:
|
||||||
return None
|
return None
|
||||||
return (self._autodecoder_codes.weight ** 2).mean()
|
return (self._autodecoder_codes.weight**2).mean() # pyre-ignore[16]
|
||||||
|
|
||||||
def get_encoding_dim(self) -> int:
|
def get_encoding_dim(self) -> int:
|
||||||
if self.n_instances <= 0:
|
if self.n_instances <= 0:
|
||||||
@@ -83,13 +82,13 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
|
def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x: A batch of `N` sequence identifiers. Either a long tensor of size
|
x: A batch of `N` identifiers. Either a long tensor of size
|
||||||
`(N,)` keys in [0, n_instances), or a list of `N` string keys that
|
`(N,)` keys in [0, n_instances), or a list of `N` string keys that
|
||||||
are hashed to codes (without collisions).
|
are hashed to codes (without collisions).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
codes: A tensor of shape `(N, self.encoding_dim)` containing the
|
codes: A tensor of shape `(N, self.encoding_dim)` containing the
|
||||||
sequence-specific autodecoder codes.
|
key-specific autodecoder codes.
|
||||||
"""
|
"""
|
||||||
if self.n_instances == 0:
|
if self.n_instances == 0:
|
||||||
return None
|
return None
|
||||||
@@ -99,19 +98,21 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
|
|
||||||
if isinstance(x[0], str):
|
if isinstance(x[0], str):
|
||||||
try:
|
try:
|
||||||
|
# pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
|
||||||
|
# `Tensor`.
|
||||||
x = torch.tensor(
|
x = torch.tensor(
|
||||||
# pyre-ignore[29]
|
# pyre-ignore[29]
|
||||||
[self._sequence_map[elem] for elem in x],
|
[self._key_map[elem] for elem in x],
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
device=next(self.parameters()).device,
|
device=next(self.parameters()).device,
|
||||||
)
|
)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
raise ValueError("Not enough n_instances in the autodecoder")
|
raise ValueError("Not enough n_instances in the autodecoder") from None
|
||||||
|
|
||||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||||
return self._autodecoder_codes(x)
|
return self._autodecoder_codes(x)
|
||||||
|
|
||||||
def _load_sequence_map_hook(
|
def _load_key_map_hook(
|
||||||
self,
|
self,
|
||||||
state_dict,
|
state_dict,
|
||||||
prefix,
|
prefix,
|
||||||
@@ -140,20 +141,18 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
:meth:`~torch.nn.Module.load_state_dict`
|
:meth:`~torch.nn.Module.load_state_dict`
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Constructed sequence_map if it exists in the state_dict
|
Constructed key_map if it exists in the state_dict
|
||||||
else raises a warning only.
|
else raises a warning only.
|
||||||
"""
|
"""
|
||||||
sequence_map_key = prefix + "_sequence_map"
|
key_map_key = prefix + "_key_map"
|
||||||
if sequence_map_key in state_dict:
|
if key_map_key in state_dict:
|
||||||
sequence_map_dict = state_dict.pop(sequence_map_key)
|
key_map_dict = state_dict.pop(key_map_key)
|
||||||
self._sequence_map = self._build_sequence_map(
|
self._key_map = self._build_key_map(key_map_dict=key_map_dict)
|
||||||
sequence_map_dict=sequence_map_dict
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
warnings.warn("No sequence map in Autodecoder state dict!")
|
warnings.warn("No key map in Autodecoder state dict!")
|
||||||
|
|
||||||
|
|
||||||
def _save_sequence_map_hook(
|
def _save_key_map_hook(
|
||||||
self,
|
self,
|
||||||
state_dict,
|
state_dict,
|
||||||
prefix,
|
prefix,
|
||||||
@@ -167,6 +166,6 @@ def _save_sequence_map_hook(
|
|||||||
module
|
module
|
||||||
local_metadata (dict): a dict containing the metadata for this module.
|
local_metadata (dict): a dict containing the metadata for this module.
|
||||||
"""
|
"""
|
||||||
sequence_map_key = prefix + "_sequence_map"
|
key_map_key = prefix + "_key_map"
|
||||||
sequence_map_dict = dict(self._sequence_map.items())
|
key_map_dict = dict(self._key_map.items())
|
||||||
state_dict[sequence_map_key] = sequence_map_dict
|
state_dict[key_map_key] = key_map_dict
|
||||||
111
pytorch3d/implicitron/models/global_encoder/global_encoder.py
Normal file
111
pytorch3d/implicitron/models/global_encoder/global_encoder.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.tools.config import (
|
||||||
|
registry,
|
||||||
|
ReplaceableBase,
|
||||||
|
run_auto_creation,
|
||||||
|
)
|
||||||
|
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||||
|
|
||||||
|
from .autodecoder import Autodecoder
|
||||||
|
|
||||||
|
|
||||||
|
class GlobalEncoderBase(ReplaceableBase):
|
||||||
|
"""
|
||||||
|
A base class for implementing encoders of global frame-specific quantities.
|
||||||
|
|
||||||
|
The latter includes e.g. the harmonic encoding of a frame timestamp
|
||||||
|
(`HarmonicTimeEncoder`), or an autodecoder encoding of the frame's sequence
|
||||||
|
(`SequenceAutodecoder`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def get_encoding_dim(self):
|
||||||
|
"""
|
||||||
|
Returns the dimensionality of the returned encoding.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Calculates the squared norm of the encoding to report as the
|
||||||
|
`autodecoder_norm` loss of the model, as a zero dimensional tensor.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def forward(self, **kwargs) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Given a set of inputs to encode, generates a tensor containing the encoding.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
encoding: The tensor containing the global encoding.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: probabilistic embeddings?
|
||||||
|
@registry.register
|
||||||
|
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 13
|
||||||
|
"""
|
||||||
|
A global encoder implementation which provides an autodecoder encoding
|
||||||
|
of the frame's sequence identifier.
|
||||||
|
"""
|
||||||
|
|
||||||
|
autodecoder: Autodecoder
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__init__()
|
||||||
|
run_auto_creation(self)
|
||||||
|
|
||||||
|
def get_encoding_dim(self):
|
||||||
|
return self.autodecoder.get_encoding_dim()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, sequence_name: Union[torch.LongTensor, List[str]], **kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
# run dtype checks and pass sequence_name to self.autodecoder
|
||||||
|
return self.autodecoder(sequence_name)
|
||||||
|
|
||||||
|
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||||
|
return self.autodecoder.calculate_squared_encoding_norm()
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
|
||||||
|
"""
|
||||||
|
A global encoder implementation which provides harmonic embeddings
|
||||||
|
of each frame's timestamp.
|
||||||
|
"""
|
||||||
|
|
||||||
|
n_harmonic_functions: int = 10
|
||||||
|
append_input: bool = True
|
||||||
|
time_divisor: float = 1.0
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self._harmonic_embedding = HarmonicEmbedding(
|
||||||
|
n_harmonic_functions=self.n_harmonic_functions,
|
||||||
|
append_input=self.append_input,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_encoding_dim(self):
|
||||||
|
return self._harmonic_embedding.get_output_dim(1)
|
||||||
|
|
||||||
|
def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||||
|
if frame_timestamp.shape[-1] != 1:
|
||||||
|
raise ValueError("Frame timestamp's last dimensions should be one.")
|
||||||
|
time = frame_timestamp / self.time_divisor
|
||||||
|
return self._harmonic_embedding(time) # pyre-ignore: 29
|
||||||
|
|
||||||
|
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||||
|
return None
|
||||||
@@ -3,7 +3,7 @@
|
|||||||
# implicit_differentiable_renderer.py
|
# implicit_differentiable_renderer.py
|
||||||
# Copyright (c) 2020 Lior Yariv
|
# Copyright (c) 2020 Lior Yariv
|
||||||
import math
|
import math
|
||||||
from typing import Sequence
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools.config import registry
|
from pytorch3d.implicitron.tools.config import registry
|
||||||
@@ -15,13 +15,48 @@ from .base import ImplicitFunctionBase
|
|||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Implicit function as used in http://github.com/lioryariv/idr.
|
||||||
|
|
||||||
|
Members:
|
||||||
|
d_in: dimension of the input point.
|
||||||
|
n_harmonic_functions_xyz: If -1, do not embed the point.
|
||||||
|
If >=0, use a harmonic embedding with this number of
|
||||||
|
harmonic functions. (The harmonic embedding includes the input
|
||||||
|
itself, so a value of 0 means the point is used but without
|
||||||
|
any harmonic functions.)
|
||||||
|
d_out and feature_vector_size: Sum of these is the output
|
||||||
|
dimension. This implicit function thus returns a concatenation
|
||||||
|
of `d_out` signed distance function values and `feature_vector_size`
|
||||||
|
features (such as colors). When used in `GenericModel`,
|
||||||
|
`feature_vector_size` corresponds is automatically set to
|
||||||
|
`render_features_dimensions`.
|
||||||
|
dims: list of hidden layer sizes.
|
||||||
|
geometric_init: whether to use custom weight initialization
|
||||||
|
in linear layers. If False, pytorch default (uniform sampling)
|
||||||
|
is used.
|
||||||
|
bias: if geometric_init=True, initial value for bias subtracted
|
||||||
|
in the last layer.
|
||||||
|
skip_in: List of indices of layers that receive as input the initial
|
||||||
|
value concatenated with the output of the previous layers.
|
||||||
|
weight_norm: whether to apply weight normalization to each layer.
|
||||||
|
pooled_feature_dim: If view pooling is in use (provided as
|
||||||
|
fun_viewpool to forward()) this must be its number of features.
|
||||||
|
Otherwise this must be set to 0. (If used from GenericModel,
|
||||||
|
this config value will be overridden automatically.)
|
||||||
|
encoding_dim: If global coding is in use (provided as global_code
|
||||||
|
to forward()) this must be its number of featuress.
|
||||||
|
Otherwise this must be set to 0. (If used from GenericModel,
|
||||||
|
this config value will be overridden automatically.)
|
||||||
|
"""
|
||||||
|
|
||||||
feature_vector_size: int = 3
|
feature_vector_size: int = 3
|
||||||
d_in: int = 3
|
d_in: int = 3
|
||||||
d_out: int = 1
|
d_out: int = 1
|
||||||
dims: Sequence[int] = (512, 512, 512, 512, 512, 512, 512, 512)
|
dims: Tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512, 512)
|
||||||
geometric_init: bool = True
|
geometric_init: bool = True
|
||||||
bias: float = 1.0
|
bias: float = 1.0
|
||||||
skip_in: Sequence[int] = ()
|
skip_in: Tuple[int, ...] = ()
|
||||||
weight_norm: bool = True
|
weight_norm: bool = True
|
||||||
n_harmonic_functions_xyz: int = 0
|
n_harmonic_functions_xyz: int = 0
|
||||||
pooled_feature_dim: int = 0
|
pooled_feature_dim: int = 0
|
||||||
@@ -33,7 +68,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
|
dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
|
||||||
|
|
||||||
self.embed_fn = None
|
self.embed_fn = None
|
||||||
if self.n_harmonic_functions_xyz > 0:
|
if self.n_harmonic_functions_xyz >= 0:
|
||||||
self.embed_fn = HarmonicEmbedding(
|
self.embed_fn = HarmonicEmbedding(
|
||||||
self.n_harmonic_functions_xyz, append_input=True
|
self.n_harmonic_functions_xyz, append_input=True
|
||||||
)
|
)
|
||||||
@@ -59,23 +94,23 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
if layer_idx == self.num_layers - 2:
|
if layer_idx == self.num_layers - 2:
|
||||||
torch.nn.init.normal_(
|
torch.nn.init.normal_(
|
||||||
lin.weight,
|
lin.weight,
|
||||||
mean=math.pi ** 0.5 / dims[layer_idx] ** 0.5,
|
mean=math.pi**0.5 / dims[layer_idx] ** 0.5,
|
||||||
std=0.0001,
|
std=0.0001,
|
||||||
)
|
)
|
||||||
torch.nn.init.constant_(lin.bias, -self.bias)
|
torch.nn.init.constant_(lin.bias, -self.bias)
|
||||||
elif self.n_harmonic_functions_xyz > 0 and layer_idx == 0:
|
elif self.n_harmonic_functions_xyz >= 0 and layer_idx == 0:
|
||||||
torch.nn.init.constant_(lin.bias, 0.0)
|
torch.nn.init.constant_(lin.bias, 0.0)
|
||||||
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
|
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
|
||||||
torch.nn.init.normal_(
|
torch.nn.init.normal_(
|
||||||
lin.weight[:, :3], 0.0, 2 ** 0.5 / out_dim ** 0.5
|
lin.weight[:, :3], 0.0, 2**0.5 / out_dim**0.5
|
||||||
)
|
)
|
||||||
elif self.n_harmonic_functions_xyz > 0 and layer_idx in self.skip_in:
|
elif self.n_harmonic_functions_xyz >= 0 and layer_idx in self.skip_in:
|
||||||
torch.nn.init.constant_(lin.bias, 0.0)
|
torch.nn.init.constant_(lin.bias, 0.0)
|
||||||
torch.nn.init.normal_(lin.weight, 0.0, 2 ** 0.5 / out_dim ** 0.5)
|
torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5)
|
||||||
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
|
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
|
||||||
else:
|
else:
|
||||||
torch.nn.init.constant_(lin.bias, 0.0)
|
torch.nn.init.constant_(lin.bias, 0.0)
|
||||||
torch.nn.init.normal_(lin.weight, 0.0, 2 ** 0.5 / out_dim ** 0.5)
|
torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5)
|
||||||
|
|
||||||
if self.weight_norm:
|
if self.weight_norm:
|
||||||
lin = nn.utils.weight_norm(lin)
|
lin = nn.utils.weight_norm(lin)
|
||||||
@@ -103,38 +138,43 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
self.embed_fn is None and fun_viewpool is None and global_code is None
|
self.embed_fn is None and fun_viewpool is None and global_code is None
|
||||||
):
|
):
|
||||||
return torch.tensor(
|
return torch.tensor(
|
||||||
[], device=rays_points_world.device, dtype=rays_points_world.dtype
|
[],
|
||||||
|
device=rays_points_world.device,
|
||||||
|
dtype=rays_points_world.dtype
|
||||||
|
# pyre-fixme[6]: For 2nd param expected `int` but got `Union[Module,
|
||||||
|
# Tensor]`.
|
||||||
).view(0, self.out_dim)
|
).view(0, self.out_dim)
|
||||||
|
|
||||||
embedding = None
|
embeddings = []
|
||||||
if self.embed_fn is not None:
|
if self.embed_fn is not None:
|
||||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||||
embedding = self.embed_fn(rays_points_world)
|
embeddings.append(self.embed_fn(rays_points_world))
|
||||||
|
|
||||||
if fun_viewpool is not None:
|
if fun_viewpool is not None:
|
||||||
assert rays_points_world.ndim == 2
|
assert rays_points_world.ndim == 2
|
||||||
pooled_feature = fun_viewpool(rays_points_world[None])
|
pooled_feature = fun_viewpool(rays_points_world[None])
|
||||||
# TODO: pooled features are 4D!
|
# TODO: pooled features are 4D!
|
||||||
embedding = torch.cat((embedding, pooled_feature), dim=-1)
|
embeddings.append(pooled_feature)
|
||||||
|
|
||||||
if global_code is not None:
|
if global_code is not None:
|
||||||
assert embedding.ndim == 2
|
|
||||||
assert global_code.shape[0] == 1 # TODO: generalize to batches!
|
assert global_code.shape[0] == 1 # TODO: generalize to batches!
|
||||||
# This will require changing raytracer code
|
# This will require changing raytracer code
|
||||||
# embedding = embedding[None].expand(global_code.shape[0], *embedding.shape)
|
# embedding = embedding[None].expand(global_code.shape[0], *embedding.shape)
|
||||||
embedding = torch.cat(
|
embeddings.append(
|
||||||
(embedding, global_code[0, None, :].expand(*embedding.shape[:-1], -1)),
|
global_code[0, None, :].expand(rays_points_world.shape[0], -1)
|
||||||
dim=-1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
embedding = torch.cat(embeddings, dim=-1)
|
||||||
x = embedding
|
x = embedding
|
||||||
|
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase.__s...
|
||||||
for layer_idx in range(self.num_layers - 1):
|
for layer_idx in range(self.num_layers - 1):
|
||||||
if layer_idx in self.skip_in:
|
if layer_idx in self.skip_in:
|
||||||
x = torch.cat([x, embedding], dim=-1) / 2 ** 0.5
|
x = torch.cat([x, embedding], dim=-1) / 2**0.5
|
||||||
|
|
||||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||||
x = self.linear_layers[layer_idx](x)
|
x = self.linear_layers[layer_idx](x)
|
||||||
|
|
||||||
|
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase...
|
||||||
if layer_idx < self.num_layers - 2:
|
if layer_idx < self.num_layers - 2:
|
||||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||||
x = self.softplus(x)
|
x = self.softplus(x)
|
||||||
|
|||||||
@@ -5,8 +5,7 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import field
|
from typing import Optional, Tuple
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
||||||
@@ -206,7 +205,7 @@ class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
|
|||||||
transformer_dim_down_factor: float = 1.0
|
transformer_dim_down_factor: float = 1.0
|
||||||
n_hidden_neurons_xyz: int = 256
|
n_hidden_neurons_xyz: int = 256
|
||||||
n_layers_xyz: int = 8
|
n_layers_xyz: int = 8
|
||||||
append_xyz: List[int] = field(default_factory=lambda: [5])
|
append_xyz: Tuple[int, ...] = (5,)
|
||||||
|
|
||||||
def _construct_xyz_encoder(self, input_dim: int):
|
def _construct_xyz_encoder(self, input_dim: int):
|
||||||
return MLPWithInputSkips(
|
return MLPWithInputSkips(
|
||||||
@@ -224,7 +223,7 @@ class NeRFormerImplicitFunction(NeuralRadianceFieldBase):
|
|||||||
transformer_dim_down_factor: float = 2.0
|
transformer_dim_down_factor: float = 2.0
|
||||||
n_hidden_neurons_xyz: int = 80
|
n_hidden_neurons_xyz: int = 80
|
||||||
n_layers_xyz: int = 2
|
n_layers_xyz: int = 2
|
||||||
append_xyz: List[int] = field(default_factory=lambda: [1])
|
append_xyz: Tuple[int, ...] = (1,)
|
||||||
|
|
||||||
def _construct_xyz_encoder(self, input_dim: int):
|
def _construct_xyz_encoder(self, input_dim: int):
|
||||||
return TransformerWithInputSkips(
|
return TransformerWithInputSkips(
|
||||||
@@ -286,7 +285,7 @@ class MLPWithInputSkips(torch.nn.Module):
|
|||||||
output_dim: int = 256,
|
output_dim: int = 256,
|
||||||
skip_dim: int = 39,
|
skip_dim: int = 39,
|
||||||
hidden_dim: int = 256,
|
hidden_dim: int = 256,
|
||||||
input_skips: List[int] = [5],
|
input_skips: Tuple[int, ...] = (5,),
|
||||||
skip_affine_trans: bool = False,
|
skip_affine_trans: bool = False,
|
||||||
no_last_relu=False,
|
no_last_relu=False,
|
||||||
):
|
):
|
||||||
@@ -362,7 +361,7 @@ class TransformerWithInputSkips(torch.nn.Module):
|
|||||||
output_dim: int = 256,
|
output_dim: int = 256,
|
||||||
skip_dim: int = 39,
|
skip_dim: int = 39,
|
||||||
hidden_dim: int = 64,
|
hidden_dim: int = 64,
|
||||||
input_skips: List[int] = [5],
|
input_skips: Tuple[int, ...] = (5,),
|
||||||
dim_down_factor: float = 1,
|
dim_down_factor: float = 1,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -386,7 +385,7 @@ class TransformerWithInputSkips(torch.nn.Module):
|
|||||||
layers_pool, layers_ray = [], []
|
layers_pool, layers_ray = [], []
|
||||||
dimout = 0
|
dimout = 0
|
||||||
for layeri in range(n_layers):
|
for layeri in range(n_layers):
|
||||||
dimin = int(round(hidden_dim / (dim_down_factor ** layeri)))
|
dimin = int(round(hidden_dim / (dim_down_factor**layeri)))
|
||||||
dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
|
dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
|
||||||
logger.info(f"Tr: {dimin} -> {dimout}")
|
logger.info(f"Tr: {dimin} -> {dimout}")
|
||||||
for _i, l in enumerate((layers_pool, layers_ray)):
|
for _i, l in enumerate((layers_pool, layers_ray)):
|
||||||
@@ -406,6 +405,8 @@ class TransformerWithInputSkips(torch.nn.Module):
|
|||||||
self.last = torch.nn.Linear(dimout, output_dim)
|
self.last = torch.nn.Linear(dimout, output_dim)
|
||||||
_xavier_init(self.last)
|
_xavier_init(self.last)
|
||||||
|
|
||||||
|
# pyre-fixme[8]: Attribute has type `Tuple[ModuleList, ModuleList]`; used as
|
||||||
|
# `ModuleList`.
|
||||||
self.layers_pool, self.layers_ray = (
|
self.layers_pool, self.layers_ray = (
|
||||||
torch.nn.ModuleList(layers_pool),
|
torch.nn.ModuleList(layers_pool),
|
||||||
torch.nn.ModuleList(layers_ray),
|
torch.nn.ModuleList(layers_ray),
|
||||||
|
|||||||
@@ -6,63 +6,180 @@
|
|||||||
|
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools import metric_utils as utils
|
from pytorch3d.implicitron.tools import metric_utils as utils
|
||||||
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
from pytorch3d.renderer import utils as rend_utils
|
from pytorch3d.renderer import utils as rend_utils
|
||||||
|
|
||||||
|
from .renderer.base import RendererOutput
|
||||||
|
|
||||||
|
|
||||||
|
class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Replaceable abstract base for regularization metrics.
|
||||||
|
`forward()` method produces regularization metrics and (unlike ViewMetrics) can
|
||||||
|
depend on the model's parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self, model: Any, keys_prefix: str = "loss_", **kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Calculates various regularization terms useful for supervising differentiable
|
||||||
|
rendering pipelines.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: A model instance. Useful, for example, to implement
|
||||||
|
weights-based regularization.
|
||||||
|
keys_prefix: A common prefix for all keys in the output dictionary
|
||||||
|
containing all regularization metrics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the resulting regularization metrics. The items
|
||||||
|
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||||
|
names of the output metrics `metric_name_i` with their corresponding
|
||||||
|
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Replaceable abstract base for model metrics.
|
||||||
|
`forward()` method produces losses and other metrics.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
class ViewMetrics(torch.nn.Module):
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
image_sampling_grid: torch.Tensor,
|
raymarched: RendererOutput,
|
||||||
images: Optional[torch.Tensor] = None,
|
xys: torch.Tensor,
|
||||||
images_pred: Optional[torch.Tensor] = None,
|
image_rgb: Optional[torch.Tensor] = None,
|
||||||
depths: Optional[torch.Tensor] = None,
|
depth_map: Optional[torch.Tensor] = None,
|
||||||
depths_pred: Optional[torch.Tensor] = None,
|
fg_probability: Optional[torch.Tensor] = None,
|
||||||
masks: Optional[torch.Tensor] = None,
|
mask_crop: Optional[torch.Tensor] = None,
|
||||||
masks_pred: Optional[torch.Tensor] = None,
|
|
||||||
masks_crop: Optional[torch.Tensor] = None,
|
|
||||||
grad_theta: Optional[torch.Tensor] = None,
|
|
||||||
density_grid: Optional[torch.Tensor] = None,
|
|
||||||
keys_prefix: str = "loss_",
|
keys_prefix: str = "loss_",
|
||||||
mask_renders_by_pred: bool = False,
|
**kwargs,
|
||||||
) -> Dict[str, torch.Tensor]:
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Calculates various metrics and loss functions useful for supervising
|
||||||
|
differentiable rendering pipelines. Any additional parameters can be passed
|
||||||
|
in the `raymarched.aux` dictionary.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: A dictionary with the resulting view metrics. The items
|
||||||
|
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||||
|
names of the output metrics `metric_name_i` with their corresponding
|
||||||
|
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||||
|
raymarched: Output of the renderer.
|
||||||
|
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
|
||||||
|
the predictions are defined. All ground truth inputs are sampled at
|
||||||
|
these locations in order to extract values that correspond to the
|
||||||
|
predictions.
|
||||||
|
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
|
||||||
|
values.
|
||||||
|
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
|
||||||
|
values.
|
||||||
|
fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
|
||||||
|
foreground masks.
|
||||||
|
keys_prefix: A common prefix for all keys in the output dictionary
|
||||||
|
containing all view metrics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the resulting view metrics. The items
|
||||||
|
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||||
|
names of the output metrics `metric_name_i` with their corresponding
|
||||||
|
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class RegularizationMetrics(RegularizationMetricsBase):
|
||||||
|
def forward(
|
||||||
|
self, model: Any, keys_prefix: str = "loss_", **kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Calculates the AD penalty, or returns an empty dict if the model's autoencoder
|
||||||
|
is inactive.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: A model instance.
|
||||||
|
keys_prefix: A common prefix for all keys in the output dictionary
|
||||||
|
containing all regularization metrics.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary with the resulting regularization metrics. The items
|
||||||
|
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||||
|
names of the output metrics `metric_name_i` with their corresponding
|
||||||
|
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||||
|
|
||||||
|
The calculated metric is:
|
||||||
|
autoencoder_norm: Autoencoder weight norm regularization term.
|
||||||
|
"""
|
||||||
|
metrics = {}
|
||||||
|
if getattr(model, "sequence_autodecoder", None) is not None:
|
||||||
|
ad_penalty = model.sequence_autodecoder.calculate_squared_encoding_norm()
|
||||||
|
if ad_penalty is not None:
|
||||||
|
metrics["autodecoder_norm"] = ad_penalty
|
||||||
|
|
||||||
|
if keys_prefix is not None:
|
||||||
|
metrics = {(keys_prefix + k): v for k, v in metrics.items()}
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class ViewMetrics(ViewMetricsBase):
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
raymarched: RendererOutput,
|
||||||
|
xys: torch.Tensor,
|
||||||
|
image_rgb: Optional[torch.Tensor] = None,
|
||||||
|
depth_map: Optional[torch.Tensor] = None,
|
||||||
|
fg_probability: Optional[torch.Tensor] = None,
|
||||||
|
mask_crop: Optional[torch.Tensor] = None,
|
||||||
|
keys_prefix: str = "loss_",
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Calculates various differentiable metrics useful for supervising
|
Calculates various differentiable metrics useful for supervising
|
||||||
differentiable rendering pipelines.
|
differentiable rendering pipelines.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_sampling_grid: A tensor of shape `(B, ..., 2)` containing 2D
|
results: A dict to store the results in.
|
||||||
image locations at which the predictions are defined.
|
raymarched.features: Predicted rgb or feature values.
|
||||||
All ground truth inputs are sampled at these
|
raymarched.depths: A tensor of shape `(B, ..., 1)` containing
|
||||||
locations in order to extract values that correspond
|
predicted depth values.
|
||||||
to the predictions.
|
raymarched.masks: A tensor of shape `(B, ..., 1)` containing
|
||||||
images: A tensor of shape `(B, H, W, 3)` containing ground truth
|
predicted foreground masks.
|
||||||
rgb values.
|
raymarched.aux["grad_theta"]: A tensor of shape `(B, ..., 3)` containing an
|
||||||
images_pred: A tensor of shape `(B, ..., 3)` containing predicted
|
evaluation of a gradient of a signed distance function w.r.t.
|
||||||
rgb values.
|
|
||||||
depths: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth
|
|
||||||
depth values.
|
|
||||||
depths_pred: A tensor of shape `(B, ..., 1)` containing predicted
|
|
||||||
depth values.
|
|
||||||
masks: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
|
|
||||||
foreground masks.
|
|
||||||
masks_pred: A tensor of shape `(B, ..., 1)` containing predicted
|
|
||||||
foreground masks.
|
|
||||||
grad_theta: A tensor of shape `(B, ..., 3)` containing an evaluation
|
|
||||||
of a gradient of a signed distance function w.r.t.
|
|
||||||
input 3D coordinates used to compute the eikonal loss.
|
input 3D coordinates used to compute the eikonal loss.
|
||||||
density_grid: A tensor of shape `(B, Hg, Wg, Dg, 1)` containing a
|
raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
|
||||||
`Hg x Wg x Dg` voxel grid of density values.
|
containing a `Hg x Wg x Dg` voxel grid of density values.
|
||||||
|
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
|
||||||
|
the predictions are defined. All ground truth inputs are sampled at
|
||||||
|
these locations in order to extract values that correspond to the
|
||||||
|
predictions.
|
||||||
|
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
|
||||||
|
values.
|
||||||
|
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
|
||||||
|
values.
|
||||||
|
fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
|
||||||
|
foreground masks.
|
||||||
keys_prefix: A common prefix for all keys in the output dictionary
|
keys_prefix: A common prefix for all keys in the output dictionary
|
||||||
containing all metrics.
|
containing all view metrics.
|
||||||
mask_renders_by_pred: If `True`, masks rendered images by the predicted
|
|
||||||
`masks_pred` prior to computing all rgb metrics.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
metrics: A dictionary `{metric_name_i: metric_value_i}` keyed by the
|
A dictionary `{metric_name_i: metric_value_i}` keyed by the
|
||||||
names of the output metrics `metric_name_i` with their corresponding
|
names of the output metrics `metric_name_i` with their corresponding
|
||||||
values `metric_value_i` represented as 0-dimensional float tensors.
|
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||||
|
|
||||||
@@ -90,109 +207,142 @@ class ViewMetrics(torch.nn.Module):
|
|||||||
depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
|
depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
|
||||||
predicted depth values.
|
predicted depth values.
|
||||||
"""
|
"""
|
||||||
|
metrics = self._calculate_stage(
|
||||||
|
raymarched,
|
||||||
|
xys,
|
||||||
|
image_rgb,
|
||||||
|
depth_map,
|
||||||
|
fg_probability,
|
||||||
|
mask_crop,
|
||||||
|
keys_prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
if raymarched.prev_stage:
|
||||||
|
metrics.update(
|
||||||
|
self(
|
||||||
|
raymarched.prev_stage,
|
||||||
|
xys,
|
||||||
|
image_rgb,
|
||||||
|
depth_map,
|
||||||
|
fg_probability,
|
||||||
|
mask_crop,
|
||||||
|
keys_prefix=(keys_prefix + "prev_stage_"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def _calculate_stage(
|
||||||
|
self,
|
||||||
|
raymarched: RendererOutput,
|
||||||
|
xys: torch.Tensor,
|
||||||
|
image_rgb: Optional[torch.Tensor] = None,
|
||||||
|
depth_map: Optional[torch.Tensor] = None,
|
||||||
|
fg_probability: Optional[torch.Tensor] = None,
|
||||||
|
mask_crop: Optional[torch.Tensor] = None,
|
||||||
|
keys_prefix: str = "loss_",
|
||||||
|
**kwargs,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Calculate metrics for the current stage.
|
||||||
|
"""
|
||||||
# TODO: extract functions
|
# TODO: extract functions
|
||||||
|
|
||||||
# reshape from B x ... x DIM to B x DIM x -1 x 1
|
# reshape from B x ... x DIM to B x DIM x -1 x 1
|
||||||
images_pred, masks_pred, depths_pred = [
|
image_rgb_pred, fg_probability_pred, depth_map_pred = [
|
||||||
_reshape_nongrid_var(x) for x in [images_pred, masks_pred, depths_pred]
|
_reshape_nongrid_var(x)
|
||||||
|
for x in [raymarched.features, raymarched.masks, raymarched.depths]
|
||||||
]
|
]
|
||||||
# reshape the sampling grid as well
|
# reshape the sampling grid as well
|
||||||
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
|
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
|
||||||
# now that we use rend_utils.ndc_grid_sample
|
# now that we use rend_utils.ndc_grid_sample
|
||||||
image_sampling_grid = image_sampling_grid.reshape(
|
xys = xys.reshape(xys.shape[0], -1, 1, 2)
|
||||||
image_sampling_grid.shape[0], -1, 1, 2
|
|
||||||
)
|
|
||||||
|
|
||||||
# closure with the given image_sampling_grid
|
# closure with the given xys
|
||||||
def sample(tensor, mode):
|
def sample(tensor, mode):
|
||||||
if tensor is None:
|
if tensor is None:
|
||||||
return tensor
|
return tensor
|
||||||
return rend_utils.ndc_grid_sample(tensor, image_sampling_grid, mode=mode)
|
return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
|
||||||
|
|
||||||
# eval all results in this size
|
# eval all results in this size
|
||||||
images = sample(images, mode="bilinear")
|
image_rgb = sample(image_rgb, mode="bilinear")
|
||||||
depths = sample(depths, mode="nearest")
|
depth_map = sample(depth_map, mode="nearest")
|
||||||
masks = sample(masks, mode="nearest")
|
fg_probability = sample(fg_probability, mode="nearest")
|
||||||
masks_crop = sample(masks_crop, mode="nearest")
|
mask_crop = sample(mask_crop, mode="nearest")
|
||||||
if masks_crop is None and images_pred is not None:
|
if mask_crop is None and image_rgb_pred is not None:
|
||||||
masks_crop = torch.ones_like(images_pred[:, :1])
|
mask_crop = torch.ones_like(image_rgb_pred[:, :1])
|
||||||
if masks_crop is None and depths_pred is not None:
|
if mask_crop is None and depth_map_pred is not None:
|
||||||
masks_crop = torch.ones_like(depths_pred[:, :1])
|
mask_crop = torch.ones_like(depth_map_pred[:, :1])
|
||||||
|
|
||||||
preds = {}
|
metrics = {}
|
||||||
if images is not None and images_pred is not None:
|
if image_rgb is not None and image_rgb_pred is not None:
|
||||||
# TODO: mask_renders_by_pred is always false; simplify
|
metrics.update(
|
||||||
preds.update(
|
|
||||||
_rgb_metrics(
|
_rgb_metrics(
|
||||||
images,
|
image_rgb,
|
||||||
images_pred,
|
image_rgb_pred,
|
||||||
masks,
|
fg_probability,
|
||||||
masks_pred,
|
fg_probability_pred,
|
||||||
masks_crop,
|
mask_crop,
|
||||||
mask_renders_by_pred,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if masks_pred is not None:
|
if fg_probability_pred is not None:
|
||||||
preds["mask_beta_prior"] = utils.beta_prior(masks_pred)
|
metrics["mask_beta_prior"] = utils.beta_prior(fg_probability_pred)
|
||||||
if masks is not None and masks_pred is not None:
|
if fg_probability is not None and fg_probability_pred is not None:
|
||||||
preds["mask_neg_iou"] = utils.neg_iou_loss(
|
metrics["mask_neg_iou"] = utils.neg_iou_loss(
|
||||||
masks_pred, masks, mask=masks_crop
|
fg_probability_pred, fg_probability, mask=mask_crop
|
||||||
|
)
|
||||||
|
metrics["mask_bce"] = utils.calc_bce(
|
||||||
|
fg_probability_pred, fg_probability, mask=mask_crop
|
||||||
)
|
)
|
||||||
preds["mask_bce"] = utils.calc_bce(masks_pred, masks, mask=masks_crop)
|
|
||||||
|
|
||||||
if depths is not None and depths_pred is not None:
|
if depth_map is not None and depth_map_pred is not None:
|
||||||
assert masks_crop is not None
|
assert mask_crop is not None
|
||||||
_, abs_ = utils.eval_depth(
|
_, abs_ = utils.eval_depth(
|
||||||
depths_pred, depths, get_best_scale=True, mask=masks_crop, crop=0
|
depth_map_pred, depth_map, get_best_scale=True, mask=mask_crop, crop=0
|
||||||
)
|
)
|
||||||
preds["depth_abs"] = abs_.mean()
|
metrics["depth_abs"] = abs_.mean()
|
||||||
|
|
||||||
if masks is not None:
|
if fg_probability is not None:
|
||||||
mask = masks * masks_crop
|
mask = fg_probability * mask_crop
|
||||||
_, abs_ = utils.eval_depth(
|
_, abs_ = utils.eval_depth(
|
||||||
depths_pred, depths, get_best_scale=True, mask=mask, crop=0
|
depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
|
||||||
)
|
)
|
||||||
preds["depth_abs_fg"] = abs_.mean()
|
metrics["depth_abs_fg"] = abs_.mean()
|
||||||
|
|
||||||
# regularizers
|
# regularizers
|
||||||
|
grad_theta = raymarched.aux.get("grad_theta")
|
||||||
if grad_theta is not None:
|
if grad_theta is not None:
|
||||||
preds["eikonal"] = _get_eikonal_loss(grad_theta)
|
metrics["eikonal"] = _get_eikonal_loss(grad_theta)
|
||||||
|
|
||||||
|
density_grid = raymarched.aux.get("density_grid")
|
||||||
if density_grid is not None:
|
if density_grid is not None:
|
||||||
preds["density_tv"] = _get_grid_tv_loss(density_grid)
|
metrics["density_tv"] = _get_grid_tv_loss(density_grid)
|
||||||
|
|
||||||
if depths_pred is not None:
|
if depth_map_pred is not None:
|
||||||
preds["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depths_pred)
|
metrics["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depth_map_pred)
|
||||||
|
|
||||||
if keys_prefix is not None:
|
if keys_prefix is not None:
|
||||||
preds = {(keys_prefix + k): v for k, v in preds.items()}
|
metrics = {(keys_prefix + k): v for k, v in metrics.items()}
|
||||||
|
|
||||||
return preds
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
def _rgb_metrics(
|
def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
|
||||||
images, images_pred, masks, masks_pred, masks_crop, mask_renders_by_pred
|
|
||||||
):
|
|
||||||
assert masks_crop is not None
|
assert masks_crop is not None
|
||||||
if mask_renders_by_pred:
|
|
||||||
images = images[..., masks_pred.reshape(-1), :]
|
|
||||||
masks_crop = masks_crop[..., masks_pred.reshape(-1), :]
|
|
||||||
masks = masks is not None and masks[..., masks_pred.reshape(-1), :]
|
|
||||||
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
|
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
|
||||||
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
|
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
|
||||||
crop_mass = masks_crop.sum().clamp(1.0)
|
crop_mass = masks_crop.sum().clamp(1.0)
|
||||||
preds = {
|
results = {
|
||||||
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
|
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
|
||||||
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
|
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
|
||||||
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
|
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
|
||||||
}
|
}
|
||||||
if masks is not None:
|
if masks is not None:
|
||||||
masks = masks_crop * masks
|
masks = masks_crop * masks
|
||||||
preds["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
|
results["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
|
||||||
preds["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
|
results["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
|
||||||
return preds
|
return results
|
||||||
|
|
||||||
|
|
||||||
def _get_eikonal_loss(grad_theta):
|
def _get_eikonal_loss(grad_theta):
|
||||||
|
|||||||
@@ -5,13 +5,11 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.dataset.utils import is_known_frame
|
from pytorch3d.implicitron.dataset.utils import is_known_frame
|
||||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
from pytorch3d.implicitron.tools.config import registry
|
||||||
NewViewSynthesisPrediction,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.tools.point_cloud_utils import (
|
from pytorch3d.implicitron.tools.point_cloud_utils import (
|
||||||
get_rgbd_point_cloud,
|
get_rgbd_point_cloud,
|
||||||
render_point_cloud_pytorch3d,
|
render_point_cloud_pytorch3d,
|
||||||
@@ -19,41 +17,43 @@ from pytorch3d.implicitron.tools.point_cloud_utils import (
|
|||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from pytorch3d.structures import Pointclouds
|
from pytorch3d.structures import Pointclouds
|
||||||
|
|
||||||
|
from .base_model import ImplicitronModelBase, ImplicitronRender
|
||||||
|
from .renderer.base import EvaluationMode
|
||||||
|
|
||||||
class ModelDBIR(torch.nn.Module):
|
|
||||||
|
@registry.register
|
||||||
|
class ModelDBIR(ImplicitronModelBase, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
A simple depth-based image rendering model.
|
A simple depth-based image rendering model.
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
image_size: int = 256,
|
|
||||||
bg_color: float = 0.0,
|
|
||||||
max_points: int = -1,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initializes a simple DBIR model.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
image_size: The size of the rendered rectangular images.
|
render_image_width: The width of the rendered rectangular images.
|
||||||
|
render_image_height: The height of the rendered rectangular images.
|
||||||
bg_color: The color of the background.
|
bg_color: The color of the background.
|
||||||
max_points: Maximum number of points in the point cloud
|
max_points: Maximum number of points in the point cloud
|
||||||
formed by unprojecting all source view depths.
|
formed by unprojecting all source view depths.
|
||||||
If more points are present, they are randomly subsampled
|
If more points are present, they are randomly subsampled
|
||||||
to #max_size points without replacement.
|
to this number of points without replacement.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
render_image_width: int = 256
|
||||||
|
render_image_height: int = 256
|
||||||
|
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||||
|
max_points: int = -1
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.image_size = image_size
|
|
||||||
self.bg_color = bg_color
|
|
||||||
self.max_points = max_points
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
*, # force keyword-only arguments
|
||||||
|
image_rgb: Optional[torch.Tensor],
|
||||||
camera: CamerasBase,
|
camera: CamerasBase,
|
||||||
image_rgb: torch.Tensor,
|
fg_probability: Optional[torch.Tensor],
|
||||||
depth_map: torch.Tensor,
|
mask_crop: Optional[torch.Tensor],
|
||||||
fg_probability: torch.Tensor,
|
depth_map: Optional[torch.Tensor],
|
||||||
|
sequence_name: Optional[List[str]],
|
||||||
|
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||||
frame_type: List[str],
|
frame_type: List[str],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Dict[str, Any]: # TODO: return a namedtuple or dataclass
|
) -> Dict[str, Any]: # TODO: return a namedtuple or dataclass
|
||||||
@@ -72,26 +72,39 @@ class ModelDBIR(torch.nn.Module):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
preds: A dict with the following fields:
|
preds: A dict with the following fields:
|
||||||
nvs_prediction: The rendered colors, depth and mask
|
implicitron_render: The rendered colors, depth and mask
|
||||||
of the target views.
|
of the target views.
|
||||||
point_cloud: The point cloud of the scene. It's renders are
|
point_cloud: The point cloud of the scene. It's renders are
|
||||||
stored in `nvs_prediction`.
|
stored in `implicitron_render`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if image_rgb is None:
|
||||||
|
raise ValueError("ModelDBIR needs image input")
|
||||||
|
|
||||||
|
if fg_probability is None:
|
||||||
|
raise ValueError("ModelDBIR needs foreground mask input")
|
||||||
|
|
||||||
|
if depth_map is None:
|
||||||
|
raise ValueError("ModelDBIR needs depth map input")
|
||||||
|
|
||||||
is_known = is_known_frame(frame_type)
|
is_known = is_known_frame(frame_type)
|
||||||
is_known_idx = torch.where(is_known)[0]
|
is_known_idx = torch.where(is_known)[0]
|
||||||
|
|
||||||
mask_fg = (fg_probability > 0.5).type_as(image_rgb)
|
mask_fg = (fg_probability > 0.5).type_as(image_rgb)
|
||||||
|
|
||||||
point_cloud = get_rgbd_point_cloud(
|
point_cloud = get_rgbd_point_cloud(
|
||||||
|
# pyre-fixme[6]: For 1st param expected `Union[List[int], int,
|
||||||
|
# LongTensor]` but got `Tensor`.
|
||||||
camera[is_known_idx],
|
camera[is_known_idx],
|
||||||
image_rgb[is_known_idx],
|
image_rgb[is_known_idx],
|
||||||
depth_map[is_known_idx],
|
depth_map[is_known_idx],
|
||||||
mask_fg[is_known_idx],
|
mask_fg[is_known_idx],
|
||||||
)
|
)
|
||||||
|
|
||||||
pcl_size = int(point_cloud.num_points_per_cloud())
|
pcl_size = point_cloud.num_points_per_cloud().item()
|
||||||
if (self.max_points > 0) and (pcl_size > self.max_points):
|
if (self.max_points > 0) and (pcl_size > self.max_points):
|
||||||
|
# pyre-fixme[6]: For 1st param expected `int` but got `Union[bool,
|
||||||
|
# float, int]`.
|
||||||
prm = torch.randperm(pcl_size)[: self.max_points]
|
prm = torch.randperm(pcl_size)[: self.max_points]
|
||||||
point_cloud = Pointclouds(
|
point_cloud = Pointclouds(
|
||||||
point_cloud.points_padded()[:, prm, :],
|
point_cloud.points_padded()[:, prm, :],
|
||||||
@@ -108,7 +121,7 @@ class ModelDBIR(torch.nn.Module):
|
|||||||
_image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d(
|
_image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d(
|
||||||
camera[int(tgt_idx)],
|
camera[int(tgt_idx)],
|
||||||
point_cloud,
|
point_cloud,
|
||||||
render_size=(self.image_size, self.image_size),
|
render_size=(self.render_image_height, self.render_image_width),
|
||||||
point_radius=1e-2,
|
point_radius=1e-2,
|
||||||
topk=10,
|
topk=10,
|
||||||
bg_color=self.bg_color,
|
bg_color=self.bg_color,
|
||||||
@@ -121,7 +134,7 @@ class ModelDBIR(torch.nn.Module):
|
|||||||
image_render.append(_image_render)
|
image_render.append(_image_render)
|
||||||
mask_render.append(_mask_render)
|
mask_render.append(_mask_render)
|
||||||
|
|
||||||
nvs_prediction = NewViewSynthesisPrediction(
|
implicitron_render = ImplicitronRender(
|
||||||
**{
|
**{
|
||||||
k: torch.cat(v, dim=0)
|
k: torch.cat(v, dim=0)
|
||||||
for k, v in zip(
|
for k, v in zip(
|
||||||
@@ -132,7 +145,7 @@ class ModelDBIR(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
preds = {
|
preds = {
|
||||||
"nvs_prediction": nvs_prediction,
|
"implicitron_render": implicitron_render,
|
||||||
"point_cloud": point_cloud,
|
"point_cloud": point_cloud,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ class RendererOutput:
|
|||||||
prev_stage: Optional[RendererOutput] = None
|
prev_stage: Optional[RendererOutput] = None
|
||||||
normals: Optional[torch.Tensor] = None
|
normals: Optional[torch.Tensor] = None
|
||||||
points: Optional[torch.Tensor] = None # TODO: redundant with depths
|
points: Optional[torch.Tensor] = None # TODO: redundant with depths
|
||||||
|
weights: Optional[torch.Tensor] = None
|
||||||
aux: Dict[str, Any] = field(default_factory=lambda: {})
|
aux: Dict[str, Any] = field(default_factory=lambda: {})
|
||||||
|
|
||||||
|
|
||||||
@@ -87,7 +88,7 @@ class BaseRenderer(ABC, ReplaceableBase):
|
|||||||
ray_bundle,
|
ray_bundle,
|
||||||
implicit_functions: List[ImplicitFunctionWrapper],
|
implicit_functions: List[ImplicitFunctionWrapper],
|
||||||
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> RendererOutput:
|
) -> RendererOutput:
|
||||||
"""
|
"""
|
||||||
Each Renderer should implement its own forward function
|
Each Renderer should implement its own forward function
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ logger = logging.getLogger(__name__)
|
|||||||
class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Implements the learnable LSTM raymarching function from SRN [1].
|
Implements the learnable LSTM raymarching function from SRN [1].
|
||||||
|
This requires there to be one implicit function, and it is expected to be
|
||||||
|
like SRNImplicitFunction or SRNHyperNetImplicitFunction.
|
||||||
|
|
||||||
Settings:
|
Settings:
|
||||||
num_raymarch_steps: The number of LSTM raymarching steps.
|
num_raymarch_steps: The number of LSTM raymarching steps.
|
||||||
@@ -32,6 +34,11 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
hidden_size: The dimensionality of the LSTM's hidden state.
|
hidden_size: The dimensionality of the LSTM's hidden state.
|
||||||
n_feature_channels: The number of feature channels returned by the
|
n_feature_channels: The number of feature channels returned by the
|
||||||
implicit_function evaluated at each raymarching step.
|
implicit_function evaluated at each raymarching step.
|
||||||
|
bg_color: If supplied, used as the background color. Otherwise the pixel
|
||||||
|
generator is used everywhere. This has to have length either 1
|
||||||
|
(for a constant value for all output channels) or equal to the number
|
||||||
|
of output channels (which is `out_features` on the pixel generator,
|
||||||
|
typically 3.)
|
||||||
verbose: If `True`, logs raymarching debug info.
|
verbose: If `True`, logs raymarching debug info.
|
||||||
|
|
||||||
References:
|
References:
|
||||||
@@ -45,6 +52,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
init_depth_noise_std: float = 5e-4
|
init_depth_noise_std: float = 5e-4
|
||||||
hidden_size: int = 16
|
hidden_size: int = 16
|
||||||
n_feature_channels: int = 256
|
n_feature_channels: int = 256
|
||||||
|
bg_color: Optional[List[float]] = None
|
||||||
verbose: bool = False
|
verbose: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -117,13 +125,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
msg = (
|
msg = (
|
||||||
f"{t}: mu={float(signed_distance.mean()):1.2e};"
|
f"{t}: mu={float(signed_distance.mean()):1.2e};"
|
||||||
+ f" std={float(signed_distance.std()):1.2e};"
|
+ f" std={float(signed_distance.std()):1.2e};"
|
||||||
# pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
|
|
||||||
# typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
|
|
||||||
# param but got `Tensor`.
|
|
||||||
+ f" mu_d={float(ray_bundle_t.lengths.mean()):1.2e};"
|
+ f" mu_d={float(ray_bundle_t.lengths.mean()):1.2e};"
|
||||||
# pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
|
|
||||||
# typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
|
|
||||||
# param but got `Tensor`.
|
|
||||||
+ f" std_d={float(ray_bundle_t.lengths.std()):1.2e};"
|
+ f" std_d={float(ray_bundle_t.lengths.std()):1.2e};"
|
||||||
)
|
)
|
||||||
logger.info(msg)
|
logger.info(msg)
|
||||||
@@ -153,6 +155,10 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
dim=-1, keepdim=True
|
dim=-1, keepdim=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.bg_color is not None:
|
||||||
|
background = features.new_tensor(self.bg_color)
|
||||||
|
features = torch.lerp(background, features, mask)
|
||||||
|
|
||||||
return RendererOutput(
|
return RendererOutput(
|
||||||
features=features[..., 0, :],
|
features=features[..., 0, :],
|
||||||
depths=depth,
|
depths=depth,
|
||||||
|
|||||||
@@ -4,18 +4,21 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Tuple
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools.config import registry
|
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
|
||||||
|
from pytorch3d.renderer import RayBundle
|
||||||
|
|
||||||
from .base import BaseRenderer, EvaluationMode, RendererOutput
|
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
|
||||||
from .ray_point_refiner import RayPointRefiner
|
from .ray_point_refiner import RayPointRefiner
|
||||||
from .raymarcher import GenericRaymarcher
|
from .raymarcher import RaymarcherBase
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
|
class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
||||||
|
BaseRenderer, torch.nn.Module
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Implements the multi-pass rendering function, in particular,
|
Implements the multi-pass rendering function, in particular,
|
||||||
with emission-absorption ray marching used in NeRF [1]. First, it evaluates
|
with emission-absorption ray marching used in NeRF [1]. First, it evaluates
|
||||||
@@ -33,7 +36,17 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
```
|
```
|
||||||
and the final rendered quantities are computed by a dot-product of ray values
|
and the final rendered quantities are computed by a dot-product of ray values
|
||||||
with the weights, e.g. `features = sum_n(weight_n * ray_features_n)`.
|
with the weights, e.g. `features = sum_n(weight_n * ray_features_n)`.
|
||||||
See below for possible values of `cap_fn` and `weight_fn`.
|
|
||||||
|
By default, for the EA raymarcher from [1] (
|
||||||
|
activated with `self.raymarcher_class_type="EmissionAbsorptionRaymarcher"`
|
||||||
|
):
|
||||||
|
```
|
||||||
|
cap_fn(x) = 1 - exp(-x),
|
||||||
|
weight_fn(x) = w * x.
|
||||||
|
```
|
||||||
|
Note that the latter can altered by changing `self.raymarcher_class_type`,
|
||||||
|
e.g. to "CumsumRaymarcher" which implements the cumulative-sum raymarcher
|
||||||
|
from NeuralVolumes [2].
|
||||||
|
|
||||||
Settings:
|
Settings:
|
||||||
n_pts_per_ray_fine_training: The number of points sampled per ray for the
|
n_pts_per_ray_fine_training: The number of points sampled per ray for the
|
||||||
@@ -46,42 +59,33 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
evaluation.
|
evaluation.
|
||||||
append_coarse_samples_to_fine: Add the fine ray points to the coarse points
|
append_coarse_samples_to_fine: Add the fine ray points to the coarse points
|
||||||
after sampling.
|
after sampling.
|
||||||
bg_color: The background color. A tuple of either 1 element or of D elements,
|
|
||||||
where D matches the feature dimensionality; it is broadcasted when necessary.
|
|
||||||
density_noise_std_train: Standard deviation of the noise added to the
|
density_noise_std_train: Standard deviation of the noise added to the
|
||||||
opacity field.
|
opacity field.
|
||||||
capping_function: The capping function of the raymarcher.
|
return_weights: Enables returning the rendering weights of the EA raymarcher.
|
||||||
Options:
|
Setting to `True` can lead to a prohibitivelly large memory consumption.
|
||||||
- "exponential" (`cap_fn(x) = 1 - exp(-x)`)
|
raymarcher_class_type: The type of self.raymarcher corresponding to
|
||||||
- "cap1" (`cap_fn(x) = min(x, 1)`)
|
a child of `RaymarcherBase` in the registry.
|
||||||
Set to "exponential" for the standard Emission Absorption raymarching.
|
raymarcher: The raymarcher object used to convert per-point features
|
||||||
weight_function: The weighting function of the raymarcher.
|
and opacities to a feature render.
|
||||||
Options:
|
|
||||||
- "product" (`weight_fn(w, x) = w * x`)
|
|
||||||
- "minimum" (`weight_fn(w, x) = min(w, x)`)
|
|
||||||
Set to "product" for the standard Emission Absorption raymarching.
|
|
||||||
background_opacity: The raw opacity value (i.e. before exponentiation)
|
|
||||||
of the background.
|
|
||||||
blend_output: If `True`, alpha-blends the output renders with the
|
|
||||||
background color using the rendered opacity mask.
|
|
||||||
|
|
||||||
References:
|
References:
|
||||||
[1] Mildenhall, Ben, et al. "Nerf: Representing scenes as neural radiance
|
[1] Mildenhall, Ben, et al. "Nerf: Representing Scenes as Neural Radiance
|
||||||
fields for view synthesis." ECCV 2020.
|
Fields for View Synthesis." ECCV 2020.
|
||||||
|
[2] Lombardi, Stephen, et al. "Neural Volumes: Learning Dynamic Renderable
|
||||||
|
Volumes from Images." SIGGRAPH 2019.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
raymarcher_class_type: str = "EmissionAbsorptionRaymarcher"
|
||||||
|
raymarcher: RaymarcherBase
|
||||||
|
|
||||||
n_pts_per_ray_fine_training: int = 64
|
n_pts_per_ray_fine_training: int = 64
|
||||||
n_pts_per_ray_fine_evaluation: int = 64
|
n_pts_per_ray_fine_evaluation: int = 64
|
||||||
stratified_sampling_coarse_training: bool = True
|
stratified_sampling_coarse_training: bool = True
|
||||||
stratified_sampling_coarse_evaluation: bool = False
|
stratified_sampling_coarse_evaluation: bool = False
|
||||||
append_coarse_samples_to_fine: bool = True
|
append_coarse_samples_to_fine: bool = True
|
||||||
bg_color: Tuple[float, ...] = (0.0,)
|
|
||||||
density_noise_std_train: float = 0.0
|
density_noise_std_train: float = 0.0
|
||||||
capping_function: str = "exponential" # exponential | cap1
|
return_weights: bool = False
|
||||||
weight_function: str = "product" # product | minimum
|
|
||||||
background_opacity: float = 1e10
|
|
||||||
blend_output: bool = False
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -97,22 +101,14 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
add_input_samples=self.append_coarse_samples_to_fine,
|
add_input_samples=self.append_coarse_samples_to_fine,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
run_auto_creation(self)
|
||||||
self._raymarcher = GenericRaymarcher(
|
|
||||||
1,
|
|
||||||
self.bg_color,
|
|
||||||
capping_function=self.capping_function,
|
|
||||||
weight_function=self.weight_function,
|
|
||||||
background_opacity=self.background_opacity,
|
|
||||||
blend_output=self.blend_output,
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ray_bundle,
|
ray_bundle: RayBundle,
|
||||||
implicit_functions=[],
|
implicit_functions: List[ImplicitFunctionWrapper],
|
||||||
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||||
**kwargs
|
**kwargs,
|
||||||
) -> RendererOutput:
|
) -> RendererOutput:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -149,14 +145,16 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
else 0.0
|
else 0.0
|
||||||
)
|
)
|
||||||
|
|
||||||
features, depth, mask, weights, aux = self._raymarcher(
|
output = self.raymarcher(
|
||||||
*implicit_functions[0](ray_bundle),
|
*implicit_functions[0](ray_bundle),
|
||||||
ray_lengths=ray_bundle.lengths,
|
ray_lengths=ray_bundle.lengths,
|
||||||
density_noise_std=density_noise_std,
|
density_noise_std=density_noise_std,
|
||||||
)
|
)
|
||||||
output = RendererOutput(
|
output.prev_stage = prev_stage
|
||||||
features=features, depths=depth, masks=mask, aux=aux, prev_stage=prev_stage
|
|
||||||
)
|
weights = output.weights
|
||||||
|
if not self.return_weights:
|
||||||
|
output.weights = None
|
||||||
|
|
||||||
# we may need to make a recursive call
|
# we may need to make a recursive call
|
||||||
if len(implicit_functions) > 1:
|
if len(implicit_functions) > 1:
|
||||||
|
|||||||
@@ -4,21 +4,52 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from dataclasses import field
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools import camera_utils
|
from pytorch3d.implicitron.tools import camera_utils
|
||||||
from pytorch3d.implicitron.tools.config import Configurable
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle
|
from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
from .base import EvaluationMode, RenderSamplingMode
|
from .base import EvaluationMode, RenderSamplingMode
|
||||||
|
|
||||||
|
|
||||||
class RaySampler(Configurable, torch.nn.Module):
|
class RaySamplerBase(ReplaceableBase):
|
||||||
|
"""
|
||||||
|
Base class for ray samplers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
cameras: CamerasBase,
|
||||||
|
evaluation_mode: EvaluationMode,
|
||||||
|
mask: Optional[torch.Tensor] = None,
|
||||||
|
) -> RayBundle:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
cameras: A batch of `batch_size` cameras from which the rays are emitted.
|
||||||
|
evaluation_mode: one of `EvaluationMode.TRAINING` or
|
||||||
|
`EvaluationMode.EVALUATION` which determines the sampling mode
|
||||||
|
that is used.
|
||||||
|
mask: Active for the `RenderSamplingMode.MASK_SAMPLE` sampling mode.
|
||||||
|
Defines a non-negative mask of shape
|
||||||
|
`(batch_size, image_height, image_width)` where each per-pixel
|
||||||
|
value is proportional to the probability of sampling the
|
||||||
|
corresponding pixel's ray.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ray_bundle: A `RayBundle` object containing the parametrizations of the
|
||||||
|
sampled rendering rays.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||||
|
"""
|
||||||
Samples a fixed number of points along rays which are in turn sampled for
|
Samples a fixed number of points along rays which are in turn sampled for
|
||||||
each camera in a batch.
|
each camera in a batch.
|
||||||
|
|
||||||
@@ -29,46 +60,19 @@ class RaySampler(Configurable, torch.nn.Module):
|
|||||||
for training and evaluation by setting `self.sampling_mode_training`
|
for training and evaluation by setting `self.sampling_mode_training`
|
||||||
and `self.sampling_mode_training` accordingly.
|
and `self.sampling_mode_training` accordingly.
|
||||||
|
|
||||||
The class allows two modes of sampling points along the rays:
|
The class allows to adjust the sampling points along rays by overwriting the
|
||||||
1) Sampling between fixed near and far z-planes:
|
`AbstractMaskRaySampler._get_min_max_depth_bounds` function which returns
|
||||||
Active when `self.scene_extent <= 0`, samples points along each ray
|
the near/far planes (`min_depth`/`max_depth`) `NDCMultinomialRaysampler`.
|
||||||
with approximately uniform spacing of z-coordinates between
|
|
||||||
the minimum depth `self.min_depth` and the maximum depth `self.max_depth`.
|
|
||||||
This sampling is useful for rendering scenes where the camera is
|
|
||||||
in a constant distance from the focal point of the scene.
|
|
||||||
2) Adaptive near/far plane estimation around the world scene center:
|
|
||||||
Active when `self.scene_extent > 0`. Samples points on each
|
|
||||||
ray between near and far planes whose depths are determined based on
|
|
||||||
the distance from the camera center to a predefined scene center.
|
|
||||||
More specifically,
|
|
||||||
`min_depth = max(
|
|
||||||
(self.scene_center-camera_center).norm() - self.scene_extent, eps
|
|
||||||
)` and
|
|
||||||
`max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.
|
|
||||||
This sampling is ideal for object-centric scenes whose contents are
|
|
||||||
centered around a known `self.scene_center` and fit into a bounding sphere
|
|
||||||
with a radius of `self.scene_extent`.
|
|
||||||
|
|
||||||
Similar to the sampling mode, the sampling parameters can be set separately
|
|
||||||
for training and evaluation.
|
|
||||||
|
|
||||||
Settings:
|
Settings:
|
||||||
image_width: The horizontal size of the image grid.
|
image_width: The horizontal size of the image grid.
|
||||||
image_height: The vertical size of the image grid.
|
image_height: The vertical size of the image grid.
|
||||||
scene_center: The xyz coordinates of the center of the scene used
|
|
||||||
along with `scene_extent` to compute the min and max depth planes
|
|
||||||
for sampling ray-points.
|
|
||||||
scene_extent: The radius of the scene bounding sphere centered at `scene_center`.
|
|
||||||
If `scene_extent <= 0`, the raysampler samples points between
|
|
||||||
`self.min_depth` and `self.max_depth` depths instead.
|
|
||||||
sampling_mode_training: The ray sampling mode for training. This should be a str
|
sampling_mode_training: The ray sampling mode for training. This should be a str
|
||||||
option from the RenderSamplingMode Enum
|
option from the RenderSamplingMode Enum
|
||||||
sampling_mode_evaluation: Same as above but for evaluation.
|
sampling_mode_evaluation: Same as above but for evaluation.
|
||||||
n_pts_per_ray_training: The number of points sampled along each ray during training.
|
n_pts_per_ray_training: The number of points sampled along each ray during training.
|
||||||
n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
|
n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
|
||||||
n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid
|
n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid
|
||||||
min_depth: The minimum depth of a ray-point. Active when `self.scene_extent > 0`.
|
|
||||||
max_depth: The maximum depth of a ray-point. Active when `self.scene_extent > 0`.
|
|
||||||
stratified_point_sampling_training: if set, performs stratified random sampling
|
stratified_point_sampling_training: if set, performs stratified random sampling
|
||||||
along the ray; otherwise takes ray points at deterministic offsets.
|
along the ray; otherwise takes ray points at deterministic offsets.
|
||||||
stratified_point_sampling_evaluation: Same as above but for evaluation.
|
stratified_point_sampling_evaluation: Same as above but for evaluation.
|
||||||
@@ -77,24 +81,17 @@ class RaySampler(Configurable, torch.nn.Module):
|
|||||||
|
|
||||||
image_width: int = 400
|
image_width: int = 400
|
||||||
image_height: int = 400
|
image_height: int = 400
|
||||||
scene_center: Tuple[float, float, float] = field(
|
|
||||||
default_factory=lambda: (0.0, 0.0, 0.0)
|
|
||||||
)
|
|
||||||
scene_extent: float = 0.0
|
|
||||||
sampling_mode_training: str = "mask_sample"
|
sampling_mode_training: str = "mask_sample"
|
||||||
sampling_mode_evaluation: str = "full_grid"
|
sampling_mode_evaluation: str = "full_grid"
|
||||||
n_pts_per_ray_training: int = 64
|
n_pts_per_ray_training: int = 64
|
||||||
n_pts_per_ray_evaluation: int = 64
|
n_pts_per_ray_evaluation: int = 64
|
||||||
n_rays_per_image_sampled_from_mask: int = 1024
|
n_rays_per_image_sampled_from_mask: int = 1024
|
||||||
min_depth: float = 0.1
|
|
||||||
max_depth: float = 8.0
|
|
||||||
# stratified sampling vs taking points at deterministic offsets
|
# stratified sampling vs taking points at deterministic offsets
|
||||||
stratified_point_sampling_training: bool = True
|
stratified_point_sampling_training: bool = True
|
||||||
stratified_point_sampling_evaluation: bool = False
|
stratified_point_sampling_evaluation: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.scene_center = torch.FloatTensor(self.scene_center)
|
|
||||||
|
|
||||||
self._sampling_mode = {
|
self._sampling_mode = {
|
||||||
EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
|
EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
|
||||||
@@ -108,8 +105,8 @@ class RaySampler(Configurable, torch.nn.Module):
|
|||||||
image_width=self.image_width,
|
image_width=self.image_width,
|
||||||
image_height=self.image_height,
|
image_height=self.image_height,
|
||||||
n_pts_per_ray=self.n_pts_per_ray_training,
|
n_pts_per_ray=self.n_pts_per_ray_training,
|
||||||
min_depth=self.min_depth,
|
min_depth=0.0,
|
||||||
max_depth=self.max_depth,
|
max_depth=0.0,
|
||||||
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
|
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
|
||||||
if self._sampling_mode[EvaluationMode.TRAINING]
|
if self._sampling_mode[EvaluationMode.TRAINING]
|
||||||
== RenderSamplingMode.MASK_SAMPLE
|
== RenderSamplingMode.MASK_SAMPLE
|
||||||
@@ -121,8 +118,8 @@ class RaySampler(Configurable, torch.nn.Module):
|
|||||||
image_width=self.image_width,
|
image_width=self.image_width,
|
||||||
image_height=self.image_height,
|
image_height=self.image_height,
|
||||||
n_pts_per_ray=self.n_pts_per_ray_evaluation,
|
n_pts_per_ray=self.n_pts_per_ray_evaluation,
|
||||||
min_depth=self.min_depth,
|
min_depth=0.0,
|
||||||
max_depth=self.max_depth,
|
max_depth=0.0,
|
||||||
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
|
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
|
||||||
if self._sampling_mode[EvaluationMode.EVALUATION]
|
if self._sampling_mode[EvaluationMode.EVALUATION]
|
||||||
== RenderSamplingMode.MASK_SAMPLE
|
== RenderSamplingMode.MASK_SAMPLE
|
||||||
@@ -132,6 +129,9 @@ class RaySampler(Configurable, torch.nn.Module):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
cameras: CamerasBase,
|
cameras: CamerasBase,
|
||||||
@@ -163,18 +163,11 @@ class RaySampler(Configurable, torch.nn.Module):
|
|||||||
):
|
):
|
||||||
sample_mask = torch.nn.functional.interpolate(
|
sample_mask = torch.nn.functional.interpolate(
|
||||||
mask,
|
mask,
|
||||||
# pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
|
|
||||||
# `List[int]`.
|
|
||||||
size=[self.image_height, self.image_width],
|
size=[self.image_height, self.image_width],
|
||||||
mode="nearest",
|
mode="nearest",
|
||||||
)[:, 0]
|
)[:, 0]
|
||||||
|
|
||||||
if self.scene_extent > 0.0:
|
min_depth, max_depth = self._get_min_max_depth_bounds(cameras)
|
||||||
# Override the min/max depth set in initialization based on the
|
|
||||||
# input cameras.
|
|
||||||
min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
|
|
||||||
cameras, self.scene_center, self.scene_extent
|
|
||||||
)
|
|
||||||
|
|
||||||
# pyre-fixme[29]:
|
# pyre-fixme[29]:
|
||||||
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
|
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
|
||||||
@@ -183,8 +176,75 @@ class RaySampler(Configurable, torch.nn.Module):
|
|||||||
ray_bundle = self._raysamplers[evaluation_mode](
|
ray_bundle = self._raysamplers[evaluation_mode](
|
||||||
cameras=cameras,
|
cameras=cameras,
|
||||||
mask=sample_mask,
|
mask=sample_mask,
|
||||||
min_depth=float(min_depth[0]) if self.scene_extent > 0.0 else None,
|
min_depth=min_depth,
|
||||||
max_depth=float(max_depth[0]) if self.scene_extent > 0.0 else None,
|
max_depth=max_depth,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ray_bundle
|
return ray_bundle
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class AdaptiveRaySampler(AbstractMaskRaySampler):
|
||||||
|
"""
|
||||||
|
Adaptively samples points on each ray between near and far planes whose
|
||||||
|
depths are determined based on the distance from the camera center
|
||||||
|
to a predefined scene center.
|
||||||
|
|
||||||
|
More specifically,
|
||||||
|
`min_depth = max(
|
||||||
|
(self.scene_center-camera_center).norm() - self.scene_extent, eps
|
||||||
|
)` and
|
||||||
|
`max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.
|
||||||
|
|
||||||
|
This sampling is ideal for object-centric scenes whose contents are
|
||||||
|
centered around a known `self.scene_center` and fit into a bounding sphere
|
||||||
|
with a radius of `self.scene_extent`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scene_center: The xyz coordinates of the center of the scene used
|
||||||
|
along with `scene_extent` to compute the min and max depth planes
|
||||||
|
for sampling ray-points.
|
||||||
|
scene_extent: The radius of the scene bounding box centered at `scene_center`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
scene_extent: float = 8.0
|
||||||
|
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__post_init__()
|
||||||
|
if self.scene_extent <= 0.0:
|
||||||
|
raise ValueError("Adaptive raysampler requires self.scene_extent > 0.")
|
||||||
|
self._scene_center = torch.FloatTensor(self.scene_center)
|
||||||
|
|
||||||
|
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
|
||||||
|
"""
|
||||||
|
Returns the adaptivelly calculated near/far planes.
|
||||||
|
"""
|
||||||
|
min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
|
||||||
|
cameras, self._scene_center, self.scene_extent
|
||||||
|
)
|
||||||
|
return float(min_depth[0]), float(max_depth[0])
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class NearFarRaySampler(AbstractMaskRaySampler):
|
||||||
|
"""
|
||||||
|
Samples a fixed number of points between fixed near and far z-planes.
|
||||||
|
Specifically, samples points along each ray with approximately uniform spacing
|
||||||
|
of z-coordinates between the minimum depth `self.min_depth` and the maximum depth
|
||||||
|
`self.max_depth`. This sampling is useful for rendering scenes where the camera is
|
||||||
|
in a constant distance from the focal point of the scene.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
min_depth: The minimum depth of a ray-point.
|
||||||
|
max_depth: The maximum depth of a ray-point.
|
||||||
|
"""
|
||||||
|
|
||||||
|
min_depth: float = 0.1
|
||||||
|
max_depth: float = 8.0
|
||||||
|
|
||||||
|
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
|
||||||
|
"""
|
||||||
|
Returns the stored near/far planes.
|
||||||
|
"""
|
||||||
|
return self.min_depth, self.max_depth
|
||||||
|
|||||||
@@ -123,12 +123,12 @@ class RayTracing(Configurable, nn.Module):
|
|||||||
|
|
||||||
ray_directions = ray_directions.reshape(-1, 3)
|
ray_directions = ray_directions.reshape(-1, 3)
|
||||||
mask_intersect = mask_intersect.reshape(-1)
|
mask_intersect = mask_intersect.reshape(-1)
|
||||||
|
# pyre-fixme[9]: object_mask has type `BoolTensor`; used as `Tensor`.
|
||||||
object_mask = object_mask.reshape(-1)
|
object_mask = object_mask.reshape(-1)
|
||||||
|
|
||||||
in_mask = ~network_object_mask & object_mask & ~sampler_mask
|
in_mask = ~network_object_mask & object_mask & ~sampler_mask
|
||||||
out_mask = ~object_mask & ~sampler_mask
|
out_mask = ~object_mask & ~sampler_mask
|
||||||
|
|
||||||
# pyre-fixme[16]: `Tensor` has no attribute `__invert__`.
|
|
||||||
mask_left_out = (in_mask | out_mask) & ~mask_intersect
|
mask_left_out = (in_mask | out_mask) & ~mask_intersect
|
||||||
if (
|
if (
|
||||||
mask_left_out.sum() > 0
|
mask_left_out.sum() > 0
|
||||||
@@ -295,7 +295,7 @@ class RayTracing(Configurable, nn.Module):
|
|||||||
) and not_proj_iters < self.line_step_iters:
|
) and not_proj_iters < self.line_step_iters:
|
||||||
# Step backwards
|
# Step backwards
|
||||||
acc_start_dis[not_projected_start] -= (
|
acc_start_dis[not_projected_start] -= (
|
||||||
(1 - self.line_search_step) / (2 ** not_proj_iters)
|
(1 - self.line_search_step) / (2**not_proj_iters)
|
||||||
) * curr_sdf_start[not_projected_start]
|
) * curr_sdf_start[not_projected_start]
|
||||||
curr_start_points[not_projected_start] = (
|
curr_start_points[not_projected_start] = (
|
||||||
cam_loc
|
cam_loc
|
||||||
@@ -303,7 +303,7 @@ class RayTracing(Configurable, nn.Module):
|
|||||||
).reshape(-1, 3)[not_projected_start]
|
).reshape(-1, 3)[not_projected_start]
|
||||||
|
|
||||||
acc_end_dis[not_projected_end] += (
|
acc_end_dis[not_projected_end] += (
|
||||||
(1 - self.line_search_step) / (2 ** not_proj_iters)
|
(1 - self.line_search_step) / (2**not_proj_iters)
|
||||||
) * curr_sdf_end[not_projected_end]
|
) * curr_sdf_end[not_projected_end]
|
||||||
curr_end_points[not_projected_end] = (
|
curr_end_points[not_projected_end] = (
|
||||||
cam_loc
|
cam_loc
|
||||||
@@ -410,10 +410,17 @@ class RayTracing(Configurable, nn.Module):
|
|||||||
if n_p_out > 0:
|
if n_p_out > 0:
|
||||||
out_pts_idx = torch.argmin(sdf_val[p_out_mask, :], -1)
|
out_pts_idx = torch.argmin(sdf_val[p_out_mask, :], -1)
|
||||||
sampler_pts[mask_intersect_idx[p_out_mask]] = points[p_out_mask, :, :][
|
sampler_pts[mask_intersect_idx[p_out_mask]] = points[p_out_mask, :, :][
|
||||||
torch.arange(n_p_out), out_pts_idx, :
|
# pyre-fixme[6]: For 1st param expected `Union[bool, float, int]`
|
||||||
|
# but got `Tensor`.
|
||||||
|
torch.arange(n_p_out),
|
||||||
|
out_pts_idx,
|
||||||
|
:,
|
||||||
]
|
]
|
||||||
sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[
|
sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[
|
||||||
p_out_mask, :
|
p_out_mask,
|
||||||
|
:
|
||||||
|
# pyre-fixme[6]: For 1st param expected `Union[bool, float, int]` but
|
||||||
|
# got `Tensor`.
|
||||||
][torch.arange(n_p_out), out_pts_idx]
|
][torch.arange(n_p_out), out_pts_idx]
|
||||||
|
|
||||||
# Get Network object mask
|
# Get Network object mask
|
||||||
@@ -434,10 +441,16 @@ class RayTracing(Configurable, nn.Module):
|
|||||||
secant_pts
|
secant_pts
|
||||||
]
|
]
|
||||||
z_low = pts_intervals[secant_pts][
|
z_low = pts_intervals[secant_pts][
|
||||||
torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1
|
# pyre-fixme[6]: For 1st param expected `Union[bool, float, int]`
|
||||||
|
# but got `Tensor`.
|
||||||
|
torch.arange(n_secant_pts),
|
||||||
|
sampler_pts_ind[secant_pts] - 1,
|
||||||
]
|
]
|
||||||
sdf_low = sdf_val[secant_pts][
|
sdf_low = sdf_val[secant_pts][
|
||||||
torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1
|
# pyre-fixme[6]: For 1st param expected `Union[bool, float, int]`
|
||||||
|
# but got `Tensor`.
|
||||||
|
torch.arange(n_secant_pts),
|
||||||
|
sampler_pts_ind[secant_pts] - 1,
|
||||||
]
|
]
|
||||||
cam_loc_secant = cam_loc.reshape(-1, 3)[mask_intersect_idx[secant_pts]]
|
cam_loc_secant = cam_loc.reshape(-1, 3)[mask_intersect_idx[secant_pts]]
|
||||||
ray_directions_secant = ray_directions.reshape((-1, 3))[
|
ray_directions_secant = ray_directions.reshape((-1, 3))[
|
||||||
@@ -514,6 +527,7 @@ class RayTracing(Configurable, nn.Module):
|
|||||||
mask_max_dis = max_dis[mask].unsqueeze(-1)
|
mask_max_dis = max_dis[mask].unsqueeze(-1)
|
||||||
mask_min_dis = min_dis[mask].unsqueeze(-1)
|
mask_min_dis = min_dis[mask].unsqueeze(-1)
|
||||||
steps = (
|
steps = (
|
||||||
|
# pyre-fixme[6]: For 1st param expected `int` but got `Tensor`.
|
||||||
steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis)
|
steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis)
|
||||||
+ mask_min_dis
|
+ mask_min_dis
|
||||||
)
|
)
|
||||||
@@ -533,8 +547,13 @@ class RayTracing(Configurable, nn.Module):
|
|||||||
mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n)
|
mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n)
|
||||||
min_vals, min_idx = mask_sdf_all.min(-1)
|
min_vals, min_idx = mask_sdf_all.min(-1)
|
||||||
min_mask_points = mask_points_all.reshape(-1, n, 3)[
|
min_mask_points = mask_points_all.reshape(-1, n, 3)[
|
||||||
torch.arange(0, n_mask_points), min_idx
|
# pyre-fixme[6]: For 2nd param expected `Union[bool, float, int]` but
|
||||||
|
# got `Tensor`.
|
||||||
|
torch.arange(0, n_mask_points),
|
||||||
|
min_idx,
|
||||||
]
|
]
|
||||||
|
# pyre-fixme[6]: For 2nd param expected `Union[bool, float, int]` but got
|
||||||
|
# `Tensor`.
|
||||||
min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx]
|
min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx]
|
||||||
|
|
||||||
return min_mask_points, min_mask_dist
|
return min_mask_points, min_mask_dist
|
||||||
@@ -553,7 +572,8 @@ def _get_sphere_intersection(
|
|||||||
# cam_loc = cam_loc.unsqueeze(-1)
|
# cam_loc = cam_loc.unsqueeze(-1)
|
||||||
# ray_cam_dot = torch.bmm(ray_directions, cam_loc).squeeze()
|
# ray_cam_dot = torch.bmm(ray_directions, cam_loc).squeeze()
|
||||||
ray_cam_dot = (ray_directions * cam_loc).sum(-1) # n_images x n_rays
|
ray_cam_dot = (ray_directions * cam_loc).sum(-1) # n_images x n_rays
|
||||||
under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, dim=-1) ** 2 - r ** 2)
|
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||||
|
under_sqrt = ray_cam_dot**2 - (cam_loc.norm(2, dim=-1) ** 2 - r**2)
|
||||||
|
|
||||||
under_sqrt = under_sqrt.reshape(-1)
|
under_sqrt = under_sqrt.reshape(-1)
|
||||||
mask_intersect = under_sqrt > 0
|
mask_intersect = under_sqrt > 0
|
||||||
|
|||||||
@@ -4,51 +4,99 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from typing import Any, Callable, Dict, Tuple, Union
|
from typing import Any, Callable, Dict, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch3d.implicitron.models.renderer.base import RendererOutput
|
||||||
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
from pytorch3d.renderer.implicit.raymarching import _check_raymarcher_inputs
|
from pytorch3d.renderer.implicit.raymarching import _check_raymarcher_inputs
|
||||||
|
|
||||||
|
|
||||||
_TTensor = torch.Tensor
|
_TTensor = torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class GenericRaymarcher(torch.nn.Module):
|
class RaymarcherBase(ReplaceableBase):
|
||||||
|
"""
|
||||||
|
Defines a base class for raymarchers. Specifically, a raymarcher is responsible
|
||||||
|
for taking a set of features and density descriptors along rendering rays
|
||||||
|
and marching along them in order to generate a feature render.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
rays_densities: torch.Tensor,
|
||||||
|
rays_features: torch.Tensor,
|
||||||
|
aux: Dict[str, Any],
|
||||||
|
) -> RendererOutput:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
rays_densities: Per-ray density values represented with a tensor
|
||||||
|
of shape `(..., n_points_per_ray, 1)`.
|
||||||
|
rays_features: Per-ray feature values represented with a tensor
|
||||||
|
of shape `(..., n_points_per_ray, feature_dim)`.
|
||||||
|
aux: a dictionary with extra information.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
This generalizes the `pytorch3d.renderer.EmissionAbsorptionRaymarcher`
|
This generalizes the `pytorch3d.renderer.EmissionAbsorptionRaymarcher`
|
||||||
and NeuralVolumes' Accumulative ray marcher. It additionally returns
|
and NeuralVolumes' cumsum ray marcher. It additionally returns
|
||||||
the rendering weights that can be used in the NVS pipeline to carry out
|
the rendering weights that can be used in the NVS pipeline to carry out
|
||||||
the importance ray-sampling in the refining pass.
|
the importance ray-sampling in the refining pass.
|
||||||
Different from `EmissionAbsorptionRaymarcher`, it takes raw
|
Different from `pytorch3d.renderer.EmissionAbsorptionRaymarcher`, it takes raw
|
||||||
(non-exponentiated) densities.
|
(non-exponentiated) densities.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
bg_color: background_color. Must be of shape (1,) or (feature_dim,)
|
surface_thickness: The thickness of the raymarched surface.
|
||||||
|
bg_color: The background color. A tuple of either 1 element or of D elements,
|
||||||
|
where D matches the feature dimensionality; it is broadcast when necessary.
|
||||||
|
background_opacity: The raw opacity value (i.e. before exponentiation)
|
||||||
|
of the background.
|
||||||
|
density_relu: If `True`, passes the input density through ReLU before
|
||||||
|
raymarching.
|
||||||
|
blend_output: If `True`, alpha-blends the output renders with the
|
||||||
|
background color using the rendered opacity mask.
|
||||||
|
|
||||||
|
capping_function: The capping function of the raymarcher.
|
||||||
|
Options:
|
||||||
|
- "exponential" (`cap_fn(x) = 1 - exp(-x)`)
|
||||||
|
- "cap1" (`cap_fn(x) = min(x, 1)`)
|
||||||
|
Set to "exponential" for the standard Emission Absorption raymarching.
|
||||||
|
weight_function: The weighting function of the raymarcher.
|
||||||
|
Options:
|
||||||
|
- "product" (`weight_fn(w, x) = w * x`)
|
||||||
|
- "minimum" (`weight_fn(w, x) = min(w, x)`)
|
||||||
|
Set to "product" for the standard Emission Absorption raymarching.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
surface_thickness: int = 1
|
||||||
self,
|
bg_color: Tuple[float, ...] = (0.0,)
|
||||||
surface_thickness: int = 1,
|
background_opacity: float = 0.0
|
||||||
bg_color: Union[Tuple[float, ...], _TTensor] = (0.0,),
|
density_relu: bool = True
|
||||||
capping_function: str = "exponential", # exponential | cap1
|
blend_output: bool = False
|
||||||
weight_function: str = "product", # product | minimum
|
|
||||||
background_opacity: float = 0.0,
|
@property
|
||||||
density_relu: bool = True,
|
def capping_function_type(self) -> str:
|
||||||
blend_output: bool = True,
|
raise NotImplementedError()
|
||||||
):
|
|
||||||
|
@property
|
||||||
|
def weight_function_type(self) -> str:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
surface_thickness: Denotes the overlap between the absorption
|
surface_thickness: Denotes the overlap between the absorption
|
||||||
function and the density function.
|
function and the density function.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.surface_thickness = surface_thickness
|
|
||||||
self.density_relu = density_relu
|
|
||||||
self.background_opacity = background_opacity
|
|
||||||
self.blend_output = blend_output
|
|
||||||
if not isinstance(bg_color, torch.Tensor):
|
|
||||||
bg_color = torch.tensor(bg_color)
|
|
||||||
|
|
||||||
|
bg_color = torch.tensor(self.bg_color)
|
||||||
if bg_color.ndim != 1:
|
if bg_color.ndim != 1:
|
||||||
raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
|
raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
|
||||||
|
|
||||||
@@ -57,12 +105,12 @@ class GenericRaymarcher(torch.nn.Module):
|
|||||||
self._capping_function: Callable[[_TTensor], _TTensor] = {
|
self._capping_function: Callable[[_TTensor], _TTensor] = {
|
||||||
"exponential": lambda x: 1.0 - torch.exp(-x),
|
"exponential": lambda x: 1.0 - torch.exp(-x),
|
||||||
"cap1": lambda x: x.clamp(max=1.0),
|
"cap1": lambda x: x.clamp(max=1.0),
|
||||||
}[capping_function]
|
}[self.capping_function_type]
|
||||||
|
|
||||||
self._weight_function: Callable[[_TTensor, _TTensor], _TTensor] = {
|
self._weight_function: Callable[[_TTensor, _TTensor], _TTensor] = {
|
||||||
"product": lambda curr, acc: curr * acc,
|
"product": lambda curr, acc: curr * acc,
|
||||||
"minimum": lambda curr, acc: torch.minimum(curr, acc),
|
"minimum": lambda curr, acc: torch.minimum(curr, acc),
|
||||||
}[weight_function]
|
}[self.weight_function_type]
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -71,7 +119,8 @@ class GenericRaymarcher(torch.nn.Module):
|
|||||||
aux: Dict[str, Any],
|
aux: Dict[str, Any],
|
||||||
ray_lengths: torch.Tensor,
|
ray_lengths: torch.Tensor,
|
||||||
density_noise_std: float = 0.0,
|
density_noise_std: float = 0.0,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
|
**kwargs,
|
||||||
|
) -> RendererOutput:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
rays_densities: Per-ray density values represented with a tensor
|
rays_densities: Per-ray density values represented with a tensor
|
||||||
@@ -87,7 +136,7 @@ class GenericRaymarcher(torch.nn.Module):
|
|||||||
features: A tensor of shape `(..., feature_dim)` containing
|
features: A tensor of shape `(..., feature_dim)` containing
|
||||||
the rendered features for each ray.
|
the rendered features for each ray.
|
||||||
depth: A tensor of shape `(..., 1)` containing estimated depth.
|
depth: A tensor of shape `(..., 1)` containing estimated depth.
|
||||||
opacities: A tensor of shape `(..., 1)` containing rendered opacsities.
|
opacities: A tensor of shape `(..., 1)` containing rendered opacities.
|
||||||
weights: A tensor of shape `(..., n_points_per_ray)` containing
|
weights: A tensor of shape `(..., n_points_per_ray)` containing
|
||||||
the ray-specific non-negative opacity weights. In general, they
|
the ray-specific non-negative opacity weights. In general, they
|
||||||
don't sum to 1 but do not overcome it, i.e.
|
don't sum to 1 but do not overcome it, i.e.
|
||||||
@@ -113,16 +162,15 @@ class GenericRaymarcher(torch.nn.Module):
|
|||||||
rays_densities = rays_densities[..., 0]
|
rays_densities = rays_densities[..., 0]
|
||||||
|
|
||||||
if density_noise_std > 0.0:
|
if density_noise_std > 0.0:
|
||||||
rays_densities = (
|
noise: _TTensor = torch.randn_like(rays_densities).mul(density_noise_std)
|
||||||
rays_densities + torch.randn_like(rays_densities) * density_noise_std
|
rays_densities = rays_densities + noise
|
||||||
)
|
|
||||||
if self.density_relu:
|
if self.density_relu:
|
||||||
rays_densities = torch.relu(rays_densities)
|
rays_densities = torch.relu(rays_densities)
|
||||||
|
|
||||||
weighted_densities = deltas * rays_densities
|
weighted_densities = deltas * rays_densities
|
||||||
capped_densities = self._capping_function(weighted_densities)
|
capped_densities = self._capping_function(weighted_densities) # pyre-ignore: 29
|
||||||
|
|
||||||
rays_opacities = self._capping_function(
|
rays_opacities = self._capping_function( # pyre-ignore: 29
|
||||||
torch.cumsum(weighted_densities, dim=-1)
|
torch.cumsum(weighted_densities, dim=-1)
|
||||||
)
|
)
|
||||||
opacities = rays_opacities[..., -1:]
|
opacities = rays_opacities[..., -1:]
|
||||||
@@ -131,7 +179,9 @@ class GenericRaymarcher(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
absorption_shifted[..., : self.surface_thickness] = 1.0
|
absorption_shifted[..., : self.surface_thickness] = 1.0
|
||||||
|
|
||||||
weights = self._weight_function(capped_densities, absorption_shifted)
|
weights = self._weight_function( # pyre-ignore: 29
|
||||||
|
capped_densities, absorption_shifted
|
||||||
|
)
|
||||||
features = (weights[..., None] * rays_features).sum(dim=-2)
|
features = (weights[..., None] * rays_features).sum(dim=-2)
|
||||||
depth = (weights * ray_lengths)[..., None].sum(dim=-2)
|
depth = (weights * ray_lengths)[..., None].sum(dim=-2)
|
||||||
|
|
||||||
@@ -140,4 +190,42 @@ class GenericRaymarcher(torch.nn.Module):
|
|||||||
raise ValueError("Wrong number of background color channels.")
|
raise ValueError("Wrong number of background color channels.")
|
||||||
features = alpha * features + (1 - opacities) * self._bg_color
|
features = alpha * features + (1 - opacities) * self._bg_color
|
||||||
|
|
||||||
return features, depth, opacities, weights, aux
|
return RendererOutput(
|
||||||
|
features=features,
|
||||||
|
depths=depth,
|
||||||
|
masks=opacities,
|
||||||
|
weights=weights,
|
||||||
|
aux=aux,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class EmissionAbsorptionRaymarcher(AccumulativeRaymarcherBase):
|
||||||
|
"""
|
||||||
|
Implements the EmissionAbsorption raymarcher.
|
||||||
|
"""
|
||||||
|
|
||||||
|
background_opacity: float = 1e10
|
||||||
|
|
||||||
|
@property
|
||||||
|
def capping_function_type(self) -> str:
|
||||||
|
return "exponential"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weight_function_type(self) -> str:
|
||||||
|
return "product"
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class CumsumRaymarcher(AccumulativeRaymarcherBase):
|
||||||
|
"""
|
||||||
|
Implements the NeuralVolumes' cumulative-sum raymarcher.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def capping_function_type(self) -> str:
|
||||||
|
return "cap1"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def weight_function_type(self) -> str:
|
||||||
|
return "minimum"
|
||||||
|
|||||||
@@ -16,6 +16,30 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class RayNormalColoringNetwork(torch.nn.Module):
|
class RayNormalColoringNetwork(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Members:
|
||||||
|
d_in and feature_vector_size: Sum of these is the input
|
||||||
|
dimension. These must add up to the sum of
|
||||||
|
- 3 [for the points]
|
||||||
|
- 3 unless mode=no_normal [for the normals]
|
||||||
|
- 3 unless mode=no_view_dir [for view directions]
|
||||||
|
- the feature size, [number of channels in feature_vectors]
|
||||||
|
|
||||||
|
d_out: dimension of output.
|
||||||
|
mode: One of "idr", "no_view_dir" or "no_normal" to allow omitting
|
||||||
|
part of the network input.
|
||||||
|
dims: list of hidden layer sizes.
|
||||||
|
weight_norm: whether to apply weight normalization to each layer.
|
||||||
|
n_harmonic_functions_dir:
|
||||||
|
If >0, use a harmonic embedding with this number of
|
||||||
|
harmonic functions for the view direction. Otherwise view directions
|
||||||
|
are fed without embedding, unless mode is `no_view_dir`.
|
||||||
|
pooled_feature_dim: If a pooling function is in use (provided as
|
||||||
|
pooling_fn to forward()) this must be its number of features.
|
||||||
|
Otherwise this must be set to 0. (If used from GenericModel,
|
||||||
|
this will be set automatically.)
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
feature_vector_size: int = 3,
|
feature_vector_size: int = 3,
|
||||||
|
|||||||
@@ -132,7 +132,11 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
eik_bounding_box: float = self.object_bounding_sphere
|
eik_bounding_box: float = self.object_bounding_sphere
|
||||||
n_eik_points = batch_size * num_pixels // 2
|
n_eik_points = batch_size * num_pixels // 2
|
||||||
eikonal_points = torch.empty(
|
eikonal_points = torch.empty(
|
||||||
n_eik_points, 3, device=self._bg_color.device
|
n_eik_points,
|
||||||
|
3,
|
||||||
|
# pyre-fixme[6]: For 3rd param expected `Union[None, str, device]`
|
||||||
|
# but got `Union[device, Tensor, Module]`.
|
||||||
|
device=self._bg_color.device,
|
||||||
).uniform_(-eik_bounding_box, eik_bounding_box)
|
).uniform_(-eik_bounding_box, eik_bounding_box)
|
||||||
eikonal_pixel_points = points.clone()
|
eikonal_pixel_points = points.clone()
|
||||||
eikonal_pixel_points = eikonal_pixel_points.detach()
|
eikonal_pixel_points = eikonal_pixel_points.detach()
|
||||||
@@ -196,7 +200,9 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
|||||||
pooling_fn=None, # TODO
|
pooling_fn=None, # TODO
|
||||||
)
|
)
|
||||||
mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
|
mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
|
||||||
-self.soft_mask_alpha * sdf_output[~surface_mask]
|
# pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
|
||||||
|
-self.soft_mask_alpha
|
||||||
|
* sdf_output[~surface_mask]
|
||||||
)
|
)
|
||||||
|
|
||||||
# scatter points with surface_mask
|
# scatter points with surface_mask
|
||||||
|
|||||||
5
pytorch3d/implicitron/models/view_pooler/__init__.py
Normal file
5
pytorch3d/implicitron/models/view_pooler/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
@@ -6,11 +6,11 @@
|
|||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Optional, Sequence, Union
|
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from pytorch3d.implicitron.models.view_pooling.view_sampling import (
|
from pytorch3d.implicitron.models.view_pooler.view_sampler import (
|
||||||
cameras_points_cartesian_product,
|
cameras_points_cartesian_product,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
@@ -82,6 +82,33 @@ class FeatureAggregatorBase(ABC, ReplaceableBase):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_aggregated_feature_dim(
|
||||||
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Returns the final dimensionality of the output aggregated features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feats_or_feats_dim: Either a `dict` of sampled features `{f_i: t_i}` corresponding
|
||||||
|
to the `feats_sampled` argument of `forward`,
|
||||||
|
or an `int` representing the sum of dimensionalities of each `t_i`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
aggregated_feature_dim: The final dimensionality of the output
|
||||||
|
aggregated features.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def has_aggregation(self) -> bool:
|
||||||
|
"""
|
||||||
|
Specifies whether the aggregator reduces the output `reduce_dim` dimension to 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
has_aggregation: `True` if `reduce_dim==1`, else `False`.
|
||||||
|
"""
|
||||||
|
return hasattr(self, "reduction_functions")
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
||||||
@@ -94,8 +121,10 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
def get_aggregated_feature_dim(
|
||||||
return _get_reduction_aggregator_feature_dim(feats, [])
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
|
):
|
||||||
|
return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, [])
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -147,7 +176,7 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
|||||||
the stack of source-view-specific features to a single feature.
|
the stack of source-view-specific features to a single feature.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
reduction_functions: Sequence[ReductionFunction] = (
|
reduction_functions: Tuple[ReductionFunction, ...] = (
|
||||||
ReductionFunction.AVG,
|
ReductionFunction.AVG,
|
||||||
ReductionFunction.STD,
|
ReductionFunction.STD,
|
||||||
)
|
)
|
||||||
@@ -155,8 +184,12 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
def get_aggregated_feature_dim(
|
||||||
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
|
):
|
||||||
|
return _get_reduction_aggregator_feature_dim(
|
||||||
|
feats_or_feats_dim, self.reduction_functions
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -236,7 +269,7 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
|
|||||||
used when calculating the angle-based aggregation weights.
|
used when calculating the angle-based aggregation weights.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
reduction_functions: Sequence[ReductionFunction] = (
|
reduction_functions: Tuple[ReductionFunction, ...] = (
|
||||||
ReductionFunction.AVG,
|
ReductionFunction.AVG,
|
||||||
ReductionFunction.STD,
|
ReductionFunction.STD,
|
||||||
)
|
)
|
||||||
@@ -246,8 +279,12 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
def get_aggregated_feature_dim(
|
||||||
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
|
):
|
||||||
|
return _get_reduction_aggregator_feature_dim(
|
||||||
|
feats_or_feats_dim, self.reduction_functions
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -345,8 +382,10 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB
|
|||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
def get_aggregated_feature_dim(
|
||||||
return _get_reduction_aggregator_feature_dim(feats, [])
|
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||||
|
):
|
||||||
|
return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, [])
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -511,7 +550,6 @@ def _get_ray_dir_dot_prods(camera: CamerasBase, pts: torch.Tensor):
|
|||||||
# torch.Tensor, torch.nn.modules.module.Module]` is not a function.
|
# torch.Tensor, torch.nn.modules.module.Module]` is not a function.
|
||||||
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N...
|
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N...
|
||||||
camera_rep.T[:, None],
|
camera_rep.T[:, None],
|
||||||
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N...
|
|
||||||
camera_rep.R.permute(0, 2, 1),
|
camera_rep.R.permute(0, 2, 1),
|
||||||
).reshape(-1, *([1] * (pts.ndim - 2)), 3)
|
).reshape(-1, *([1] * (pts.ndim - 2)), 3)
|
||||||
# cam_centers_rep = camera_rep.get_camera_center().reshape(
|
# cam_centers_rep = camera_rep.get_camera_center().reshape(
|
||||||
@@ -610,6 +648,7 @@ def _avgmaxstd_reduction_function(
|
|||||||
x_aggr = torch.cat(pooled_features, dim=-1)
|
x_aggr = torch.cat(pooled_features, dim=-1)
|
||||||
|
|
||||||
# zero out features that were all masked out
|
# zero out features that were all masked out
|
||||||
|
# pyre-fixme[16]: `bool` has no attribute `type_as`.
|
||||||
any_active = (w.max(dim=dim, keepdim=True).values > 1e-4).type_as(x_aggr)
|
any_active = (w.max(dim=dim, keepdim=True).values > 1e-4).type_as(x_aggr)
|
||||||
x_aggr = x_aggr * any_active[..., None]
|
x_aggr = x_aggr * any_active[..., None]
|
||||||
|
|
||||||
@@ -637,6 +676,7 @@ def _std_reduction_function(
|
|||||||
):
|
):
|
||||||
if mu is None:
|
if mu is None:
|
||||||
mu = _avg_reduction_function(x, w, dim=dim)
|
mu = _avg_reduction_function(x, w, dim=dim)
|
||||||
|
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||||
std = wmean((x - mu) ** 2, w, dim=dim, eps=1e-2).clamp(1e-4).sqrt()
|
std = wmean((x - mu) ** 2, w, dim=dim, eps=1e-2).clamp(1e-4).sqrt()
|
||||||
# FIXME: somehow this is extremely heavy in mem?
|
# FIXME: somehow this is extremely heavy in mem?
|
||||||
return std
|
return std
|
||||||
128
pytorch3d/implicitron/models/view_pooler/view_pooler.py
Normal file
128
pytorch3d/implicitron/models/view_pooler/view_pooler.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.tools.config import Configurable, run_auto_creation
|
||||||
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
|
from .feature_aggregator import FeatureAggregatorBase
|
||||||
|
from .view_sampler import ViewSampler
|
||||||
|
|
||||||
|
|
||||||
|
# pyre-ignore: 13
|
||||||
|
class ViewPooler(Configurable, torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Implements sampling of image-based features at the 2d projections of a set
|
||||||
|
of 3D points, and a subsequent aggregation of the resulting set of features
|
||||||
|
per-point.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
view_sampler: An instance of ViewSampler which is used for sampling of
|
||||||
|
image-based features at the 2D projections of a set
|
||||||
|
of 3D points.
|
||||||
|
feature_aggregator_class_type: The name of the feature aggregator class which
|
||||||
|
is available in the global registry.
|
||||||
|
feature_aggregator: A feature aggregator class which inherits from
|
||||||
|
FeatureAggregatorBase. Typically, the aggregated features and their
|
||||||
|
masks are output by a `ViewSampler` which samples feature tensors extracted
|
||||||
|
from a set of source images. FeatureAggregator executes step (4) above.
|
||||||
|
"""
|
||||||
|
|
||||||
|
view_sampler: ViewSampler
|
||||||
|
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
|
||||||
|
feature_aggregator: FeatureAggregatorBase
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
super().__init__()
|
||||||
|
run_auto_creation(self)
|
||||||
|
|
||||||
|
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
||||||
|
"""
|
||||||
|
Returns the final dimensionality of the output aggregated features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
feats: Either a `dict` of sampled features `{f_i: t_i}` corresponding
|
||||||
|
to the `feats_sampled` argument of `feature_aggregator,forward`,
|
||||||
|
or an `int` representing the sum of dimensionalities of each `t_i`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
aggregated_feature_dim: The final dimensionality of the output
|
||||||
|
aggregated features.
|
||||||
|
"""
|
||||||
|
return self.feature_aggregator.get_aggregated_feature_dim(feats)
|
||||||
|
|
||||||
|
def has_aggregation(self):
|
||||||
|
"""
|
||||||
|
Specifies whether the `feature_aggregator` reduces the output `reduce_dim`
|
||||||
|
dimension to 1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
has_aggregation: `True` if `reduce_dim==1`, else `False`.
|
||||||
|
"""
|
||||||
|
return self.feature_aggregator.has_aggregation()
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
*, # force kw args
|
||||||
|
pts: torch.Tensor,
|
||||||
|
seq_id_pts: Union[List[int], List[str], torch.LongTensor],
|
||||||
|
camera: CamerasBase,
|
||||||
|
seq_id_camera: Union[List[int], List[str], torch.LongTensor],
|
||||||
|
feats: Dict[str, torch.Tensor],
|
||||||
|
masks: Optional[torch.Tensor],
|
||||||
|
**kwargs,
|
||||||
|
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Project each point cloud from a batch of point clouds to corresponding
|
||||||
|
input cameras, sample features at the 2D projection locations in a batch
|
||||||
|
of source images, and aggregate the pointwise sampled features.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pts: A tensor of shape `[pts_batch x n_pts x 3]` in world coords.
|
||||||
|
seq_id_pts: LongTensor of shape `[pts_batch]` denoting the ids of the scenes
|
||||||
|
from which `pts` were extracted, or a list of string names.
|
||||||
|
camera: 'n_cameras' cameras, each coresponding to a batch element of `feats`.
|
||||||
|
seq_id_camera: LongTensor of shape `[n_cameras]` denoting the ids of the scenes
|
||||||
|
corresponding to cameras in `camera`, or a list of string names.
|
||||||
|
feats: a dict of tensors of per-image features `{feat_i: T_i}`.
|
||||||
|
Each tensor `T_i` is of shape `[n_cameras x dim_i x H_i x W_i]`.
|
||||||
|
masks: `[n_cameras x 1 x H x W]`, define valid image regions
|
||||||
|
for sampling `feats`.
|
||||||
|
Returns:
|
||||||
|
feats_aggregated: If `feature_aggregator.concatenate_output==True`, a tensor
|
||||||
|
of shape `(pts_batch, reduce_dim, n_pts, sum(dim_1, ... dim_N))`
|
||||||
|
containing the aggregated features. `reduce_dim` depends on
|
||||||
|
the specific feature aggregator implementation and typically
|
||||||
|
equals 1 or `n_cameras`.
|
||||||
|
If `feature_aggregator.concatenate_output==False`, the aggregator
|
||||||
|
does not concatenate the aggregated features and returns a dictionary
|
||||||
|
of per-feature aggregations `{f_i: t_i_aggregated}` instead.
|
||||||
|
Each `t_i_aggregated` is of shape
|
||||||
|
`(pts_batch, reduce_dim, n_pts, aggr_dim_i)`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# (1) Sample features and masks at the ray points
|
||||||
|
sampled_feats, sampled_masks = self.view_sampler(
|
||||||
|
pts=pts,
|
||||||
|
seq_id_pts=seq_id_pts,
|
||||||
|
camera=camera,
|
||||||
|
seq_id_camera=seq_id_camera,
|
||||||
|
feats=feats,
|
||||||
|
masks=masks,
|
||||||
|
)
|
||||||
|
|
||||||
|
# (2) Aggregate features from multiple views
|
||||||
|
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||||
|
feats_aggregated = self.feature_aggregator( # noqa: E731
|
||||||
|
sampled_feats,
|
||||||
|
sampled_masks,
|
||||||
|
pts=pts,
|
||||||
|
camera=camera,
|
||||||
|
) # TODO: do we need to pass a callback rather than compute here?
|
||||||
|
|
||||||
|
return feats_aggregated
|
||||||
@@ -205,7 +205,11 @@ def handle_seq_id(
|
|||||||
if not torch.is_tensor(seq_id):
|
if not torch.is_tensor(seq_id):
|
||||||
if isinstance(seq_id[0], str):
|
if isinstance(seq_id[0], str):
|
||||||
seq_id = [hash(s) for s in seq_id]
|
seq_id = [hash(s) for s in seq_id]
|
||||||
|
# pyre-fixme[9]: seq_id has type `Union[List[int], List[str], LongTensor]`;
|
||||||
|
# used as `Tensor`.
|
||||||
seq_id = torch.tensor(seq_id, dtype=torch.long, device=device)
|
seq_id = torch.tensor(seq_id, dtype=torch.long, device=device)
|
||||||
|
# pyre-fixme[16]: Item `List` of `Union[List[int], List[str], LongTensor]` has
|
||||||
|
# no attribute `to`.
|
||||||
return seq_id.to(device)
|
return seq_id.to(device)
|
||||||
|
|
||||||
|
|
||||||
@@ -287,5 +291,7 @@ def cameras_points_cartesian_product(
|
|||||||
)
|
)
|
||||||
.reshape(batch_pts * n_cameras)
|
.reshape(batch_pts * n_cameras)
|
||||||
)
|
)
|
||||||
|
# pyre-fixme[6]: For 1st param expected `Union[List[int], int, LongTensor]` but
|
||||||
|
# got `Tensor`.
|
||||||
camera_rep = camera[idx_cams]
|
camera_rep = camera[idx_cams]
|
||||||
return camera_rep, pts_rep
|
return camera_rep, pts_rep
|
||||||
@@ -215,7 +215,6 @@ class BatchLinear(nn.Module):
|
|||||||
def last_hyper_layer_init(m) -> None:
|
def last_hyper_layer_init(m) -> None:
|
||||||
if type(m) == nn.Linear:
|
if type(m) == nn.Linear:
|
||||||
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity="relu", mode="fan_in")
|
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity="relu", mode="fan_in")
|
||||||
# pyre-fixme[41]: `data` cannot be reassigned. It is a read-only property.
|
|
||||||
m.weight.data *= 1e-1
|
m.weight.data *= 1e-1
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -101,7 +101,7 @@ def volumetric_camera_overlaps(
|
|||||||
"""
|
"""
|
||||||
device = cameras.device
|
device = cameras.device
|
||||||
ba = cameras.R.shape[0]
|
ba = cameras.R.shape[0]
|
||||||
n_vox = int(resol ** 3)
|
n_vox = int(resol**3)
|
||||||
grid = pt3d.structures.Volumes(
|
grid = pt3d.structures.Volumes(
|
||||||
densities=torch.zeros([1, 1, resol, resol, resol], device=device),
|
densities=torch.zeros([1, 1, resol, resol, resol], device=device),
|
||||||
volume_translation=-torch.FloatTensor(scene_center)[None].to(device),
|
volume_translation=-torch.FloatTensor(scene_center)[None].to(device),
|
||||||
|
|||||||
@@ -102,13 +102,14 @@ def fit_circle_in_2d(
|
|||||||
Circle2D object
|
Circle2D object
|
||||||
"""
|
"""
|
||||||
design = torch.cat([points2d, torch.ones_like(points2d[:, :1])], dim=1)
|
design = torch.cat([points2d, torch.ones_like(points2d[:, :1])], dim=1)
|
||||||
rhs = (points2d ** 2).sum(1)
|
rhs = (points2d**2).sum(1)
|
||||||
n_provided = points2d.shape[0]
|
n_provided = points2d.shape[0]
|
||||||
if n_provided < 3:
|
if n_provided < 3:
|
||||||
raise ValueError(f"{n_provided} points are not enough to determine a circle")
|
raise ValueError(f"{n_provided} points are not enough to determine a circle")
|
||||||
solution = lstsq(design, rhs)
|
solution = lstsq(design, rhs[:, None])
|
||||||
center = solution[:2] / 2
|
center = solution[:2, 0] / 2
|
||||||
radius = torch.sqrt(solution[2] + (center ** 2).sum())
|
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||||
|
radius = torch.sqrt(solution[2, 0] + (center**2).sum())
|
||||||
if n_points > 0:
|
if n_points > 0:
|
||||||
if angles is not None:
|
if angles is not None:
|
||||||
warnings.warn("n_points ignored because angles provided")
|
warnings.warn("n_points ignored because angles provided")
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import sys
|
|||||||
import warnings
|
import warnings
|
||||||
from collections import Counter, defaultdict
|
from collections import Counter, defaultdict
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from functools import partial
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||||
|
|
||||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||||
@@ -175,6 +176,9 @@ _unprocessed_warning: str = (
|
|||||||
TYPE_SUFFIX: str = "_class_type"
|
TYPE_SUFFIX: str = "_class_type"
|
||||||
ARGS_SUFFIX: str = "_args"
|
ARGS_SUFFIX: str = "_args"
|
||||||
ENABLED_SUFFIX: str = "_enabled"
|
ENABLED_SUFFIX: str = "_enabled"
|
||||||
|
CREATE_PREFIX: str = "create_"
|
||||||
|
IMPL_SUFFIX: str = "_impl"
|
||||||
|
TWEAK_SUFFIX: str = "_tweak_args"
|
||||||
|
|
||||||
|
|
||||||
class ReplaceableBase:
|
class ReplaceableBase:
|
||||||
@@ -216,6 +220,7 @@ class Configurable:
|
|||||||
|
|
||||||
|
|
||||||
_X = TypeVar("X", bound=ReplaceableBase)
|
_X = TypeVar("X", bound=ReplaceableBase)
|
||||||
|
_Y = TypeVar("Y", bound=Union[ReplaceableBase, Configurable])
|
||||||
|
|
||||||
|
|
||||||
class _Registry:
|
class _Registry:
|
||||||
@@ -259,13 +264,9 @@ class _Registry:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Cannot register {some_class}. Cannot tell what it is."
|
f"Cannot register {some_class}. Cannot tell what it is."
|
||||||
)
|
)
|
||||||
if some_class is base_class:
|
|
||||||
raise ValueError(f"Attempted to register the base class {some_class}")
|
|
||||||
self._mapping[base_class][name] = some_class
|
self._mapping[base_class][name] = some_class
|
||||||
|
|
||||||
def get(
|
def get(self, base_class_wanted: Type[_X], name: str) -> Type[_X]:
|
||||||
self, base_class_wanted: Type[ReplaceableBase], name: str
|
|
||||||
) -> Type[ReplaceableBase]:
|
|
||||||
"""
|
"""
|
||||||
Retrieve a class from the registry by name
|
Retrieve a class from the registry by name
|
||||||
|
|
||||||
@@ -293,6 +294,7 @@ class _Registry:
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{name} resolves to {result} which does not subclass {base_class_wanted}"
|
f"{name} resolves to {result} which does not subclass {base_class_wanted}"
|
||||||
)
|
)
|
||||||
|
# pyre-ignore[7]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_all(
|
def get_all(
|
||||||
@@ -306,20 +308,23 @@ class _Registry:
|
|||||||
It determines the namespace.
|
It determines the namespace.
|
||||||
This will typically be a direct subclass of ReplaceableBase.
|
This will typically be a direct subclass of ReplaceableBase.
|
||||||
Returns:
|
Returns:
|
||||||
list of class types
|
list of class types in alphabetical order of registered name.
|
||||||
"""
|
"""
|
||||||
if self._is_base_class(base_class_wanted):
|
if self._is_base_class(base_class_wanted):
|
||||||
return list(self._mapping[base_class_wanted].values())
|
source = self._mapping[base_class_wanted]
|
||||||
|
return [source[key] for key in sorted(source)]
|
||||||
|
|
||||||
base_class = self._base_class_from_class(base_class_wanted)
|
base_class = self._base_class_from_class(base_class_wanted)
|
||||||
if base_class is None:
|
if base_class is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
|
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
|
||||||
)
|
)
|
||||||
|
source = self._mapping[base_class]
|
||||||
return [
|
return [
|
||||||
class_
|
source[key]
|
||||||
for class_ in self._mapping[base_class].values()
|
for key in sorted(source)
|
||||||
if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted
|
if issubclass(source[key], base_class_wanted)
|
||||||
|
and source[key] is not base_class_wanted
|
||||||
]
|
]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -375,25 +380,68 @@ def _default_create(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Function taking one argument, the object whose member should be
|
Function taking one argument, the object whose member should be
|
||||||
initialized.
|
initialized, i.e. self.
|
||||||
"""
|
"""
|
||||||
|
impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}"
|
||||||
|
|
||||||
def inner(self):
|
def inner(self):
|
||||||
expand_args_fields(type_)
|
expand_args_fields(type_)
|
||||||
|
impl = getattr(self, impl_name)
|
||||||
args = getattr(self, name + ARGS_SUFFIX)
|
args = getattr(self, name + ARGS_SUFFIX)
|
||||||
setattr(self, name, type_(**args))
|
impl(True, args)
|
||||||
|
|
||||||
def inner_optional(self):
|
def inner_optional(self):
|
||||||
expand_args_fields(type_)
|
expand_args_fields(type_)
|
||||||
|
impl = getattr(self, impl_name)
|
||||||
enabled = getattr(self, name + ENABLED_SUFFIX)
|
enabled = getattr(self, name + ENABLED_SUFFIX)
|
||||||
if enabled:
|
|
||||||
args = getattr(self, name + ARGS_SUFFIX)
|
args = getattr(self, name + ARGS_SUFFIX)
|
||||||
|
impl(enabled, args)
|
||||||
|
|
||||||
|
def inner_pluggable(self):
|
||||||
|
type_name = getattr(self, name + TYPE_SUFFIX)
|
||||||
|
impl = getattr(self, impl_name)
|
||||||
|
if type_name is None:
|
||||||
|
args = None
|
||||||
|
else:
|
||||||
|
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}", None)
|
||||||
|
impl(type_name, args)
|
||||||
|
|
||||||
|
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||||
|
return inner_optional
|
||||||
|
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
|
||||||
|
|
||||||
|
|
||||||
|
def _default_create_impl(
|
||||||
|
name: str, type_: Type, process_type: _ProcessType
|
||||||
|
) -> Callable[[Any, Any, DictConfig], None]:
|
||||||
|
"""
|
||||||
|
Return the default internal function for initialising a member. This is a function
|
||||||
|
which could be called in the create_ function to initialise the member.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: name of the member
|
||||||
|
type_: type of the member (with any Optional removed)
|
||||||
|
process_type: Shows whether member's declared type inherits ReplaceableBase,
|
||||||
|
in which case the actual type to be created is decided at
|
||||||
|
runtime.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function taking
|
||||||
|
- self, the object whose member should be initialized.
|
||||||
|
- option for what to do. This is
|
||||||
|
- for pluggables, the type to initialise or None to do nothing
|
||||||
|
- for non pluggables, a bool indicating whether to initialise.
|
||||||
|
- the args for initializing the member.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def create_configurable(self, enabled, args):
|
||||||
|
if enabled:
|
||||||
|
expand_args_fields(type_)
|
||||||
setattr(self, name, type_(**args))
|
setattr(self, name, type_(**args))
|
||||||
else:
|
else:
|
||||||
setattr(self, name, None)
|
setattr(self, name, None)
|
||||||
|
|
||||||
def inner_pluggable(self):
|
def create_pluggable(self, type_name, args):
|
||||||
type_name = getattr(self, name + TYPE_SUFFIX)
|
|
||||||
if type_name is None:
|
if type_name is None:
|
||||||
setattr(self, name, None)
|
setattr(self, name, None)
|
||||||
return
|
return
|
||||||
@@ -408,12 +456,11 @@ def _default_create(
|
|||||||
# were made in the redefinition will not be reflected here.
|
# were made in the redefinition will not be reflected here.
|
||||||
warnings.warn(f"New implementation of {type_name} is being chosen.")
|
warnings.warn(f"New implementation of {type_name} is being chosen.")
|
||||||
expand_args_fields(chosen_class)
|
expand_args_fields(chosen_class)
|
||||||
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
|
|
||||||
setattr(self, name, chosen_class(**args))
|
setattr(self, name, chosen_class(**args))
|
||||||
|
|
||||||
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
if process_type in (_ProcessType.CONFIGURABLE, _ProcessType.OPTIONAL_CONFIGURABLE):
|
||||||
return inner_optional
|
return create_configurable
|
||||||
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
|
return create_pluggable
|
||||||
|
|
||||||
|
|
||||||
def run_auto_creation(self: Any) -> None:
|
def run_auto_creation(self: Any) -> None:
|
||||||
@@ -567,6 +614,9 @@ def _params_iter(C):
|
|||||||
|
|
||||||
|
|
||||||
def _is_immutable_type(type_: Type, val: Any) -> bool:
|
def _is_immutable_type(type_: Type, val: Any) -> bool:
|
||||||
|
if val is None:
|
||||||
|
return True
|
||||||
|
|
||||||
PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
|
PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
|
||||||
# sometimes type can be too relaxed (e.g. Any), so we also check values
|
# sometimes type can be too relaxed (e.g. Any), so we also check values
|
||||||
if isinstance(val, PRIMITIVE_TYPES):
|
if isinstance(val, PRIMITIVE_TYPES):
|
||||||
@@ -601,17 +651,19 @@ def _is_actually_dataclass(some_class) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def expand_args_fields(
|
def expand_args_fields(
|
||||||
some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = ()
|
some_class: Type[_Y], *, _do_not_process: Tuple[type, ...] = ()
|
||||||
) -> Type[_X]:
|
) -> Type[_Y]:
|
||||||
"""
|
"""
|
||||||
This expands a class which inherits Configurable or ReplaceableBase classes,
|
This expands a class which inherits Configurable or ReplaceableBase classes,
|
||||||
including dataclass processing. some_class is modified in place by this function.
|
including dataclass processing. some_class is modified in place by this function.
|
||||||
|
If expand_args_fields(some_class) has already been called, subsequent calls do
|
||||||
|
nothing and return some_class unmodified.
|
||||||
For classes of type ReplaceableBase, you can add some_class to the registry before
|
For classes of type ReplaceableBase, you can add some_class to the registry before
|
||||||
or after calling this function. But potential inner classes need to be registered
|
or after calling this function. But potential inner classes need to be registered
|
||||||
before this function is run on the outer class.
|
before this function is run on the outer class.
|
||||||
|
|
||||||
The transformations this function makes, before the concluding
|
The transformations this function makes, before the concluding
|
||||||
dataclasses.dataclass, are as follows. if X is a base class with registered
|
dataclasses.dataclass, are as follows. If X is a base class with registered
|
||||||
subclasses Y and Z, replace a class member
|
subclasses Y and Z, replace a class member
|
||||||
|
|
||||||
x: X
|
x: X
|
||||||
@@ -626,9 +678,12 @@ def expand_args_fields(
|
|||||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
||||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||||
def create_x(self):
|
def create_x(self):
|
||||||
self.x = registry.get(X, self.x_class_type)(
|
args = self.getattr(f"x_{self.x_class_type}_args")
|
||||||
**self.getattr(f"x_{self.x_class_type}_args)
|
self.create_x_impl(self.x_class_type, args)
|
||||||
)
|
def create_x_impl(self, x_type, args):
|
||||||
|
x_type = registry.get(X, x_type)
|
||||||
|
expand_args_fields(x_type)
|
||||||
|
self.x = x_type(**args)
|
||||||
x_class_type: str = "UNDEFAULTED"
|
x_class_type: str = "UNDEFAULTED"
|
||||||
|
|
||||||
without adding the optional attributes if they are already there.
|
without adding the optional attributes if they are already there.
|
||||||
@@ -648,12 +703,19 @@ def expand_args_fields(
|
|||||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||||
def create_x(self):
|
def create_x(self):
|
||||||
if self.x_class_type is None:
|
if self.x_class_type is None:
|
||||||
|
args = None
|
||||||
|
else:
|
||||||
|
args = self.getattr(f"x_{self.x_class_type}_args", None)
|
||||||
|
self.create_x_impl(self.x_class_type, args)
|
||||||
|
def create_x_impl(self, x_class_type, args):
|
||||||
|
if x_class_type is None:
|
||||||
self.x = None
|
self.x = None
|
||||||
return
|
return
|
||||||
|
|
||||||
self.x = registry.get(X, self.x_class_type)(
|
x_type = registry.get(X, x_class_type)
|
||||||
**self.getattr(f"x_{self.x_class_type}_args)
|
expand_args_fields(x_type)
|
||||||
)
|
assert args is not None
|
||||||
|
self.x = x_type(**args)
|
||||||
x_class_type: Optional[str] = "UNDEFAULTED"
|
x_class_type: Optional[str] = "UNDEFAULTED"
|
||||||
|
|
||||||
without adding the optional attributes if they are already there.
|
without adding the optional attributes if they are already there.
|
||||||
@@ -670,7 +732,14 @@ def expand_args_fields(
|
|||||||
|
|
||||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||||
def create_x(self):
|
def create_x(self):
|
||||||
self.x = X(self.x_args)
|
self.create_x_impl(True, self.x_args)
|
||||||
|
|
||||||
|
def create_x_impl(self, enabled, args):
|
||||||
|
if enabled:
|
||||||
|
expand_args_fields(X)
|
||||||
|
self.x = X(**args)
|
||||||
|
else:
|
||||||
|
self.x = None
|
||||||
|
|
||||||
Similarly, replace,
|
Similarly, replace,
|
||||||
|
|
||||||
@@ -686,8 +755,12 @@ def expand_args_fields(
|
|||||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||||
x_enabled: bool = False
|
x_enabled: bool = False
|
||||||
def create_x(self):
|
def create_x(self):
|
||||||
if self.x_enabled:
|
self.create_x_impl(self.x_enabled, self.x_args)
|
||||||
self.x = X(self.x_args)
|
|
||||||
|
def create_x_impl(self, enabled, args):
|
||||||
|
if enabled:
|
||||||
|
expand_args_fields(X)
|
||||||
|
self.x = X(**args)
|
||||||
else:
|
else:
|
||||||
self.x = None
|
self.x = None
|
||||||
|
|
||||||
@@ -695,7 +768,7 @@ def expand_args_fields(
|
|||||||
Also adds the following class members, unannotated so that dataclass
|
Also adds the following class members, unannotated so that dataclass
|
||||||
ignores them.
|
ignores them.
|
||||||
- _creation_functions: Tuple[str] of all the create_ functions,
|
- _creation_functions: Tuple[str] of all the create_ functions,
|
||||||
including those from base classes.
|
including those from base classes (not the create_x_impl ones).
|
||||||
- _known_implementations: Dict[str, Type] containing the classes which
|
- _known_implementations: Dict[str, Type] containing the classes which
|
||||||
have been found from the registry.
|
have been found from the registry.
|
||||||
(used only to raise a warning if it one has been overwritten)
|
(used only to raise a warning if it one has been overwritten)
|
||||||
@@ -703,6 +776,14 @@ def expand_args_fields(
|
|||||||
transformed, with values giving the types they were declared to have.
|
transformed, with values giving the types they were declared to have.
|
||||||
(E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
|
(E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
|
||||||
|
|
||||||
|
In addition, if the class has a member function
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
|
||||||
|
|
||||||
|
then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
|
||||||
|
the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
some_class: the class to be processed
|
some_class: the class to be processed
|
||||||
_do_not_process: Internal use for get_default_args: Because get_default_args calls
|
_do_not_process: Internal use for get_default_args: Because get_default_args calls
|
||||||
@@ -779,19 +860,29 @@ def expand_args_fields(
|
|||||||
return some_class
|
return some_class
|
||||||
|
|
||||||
|
|
||||||
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
|
def get_default_args_field(
|
||||||
|
C,
|
||||||
|
*,
|
||||||
|
_do_not_process: Tuple[type, ...] = (),
|
||||||
|
_hook: Optional[Callable[[DictConfig], None]] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Get a dataclass field which defaults to get_default_args(...)
|
Get a dataclass field which defaults to get_default_args(...)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
As for get_default_args.
|
C: As for get_default_args.
|
||||||
|
_do_not_process: As for get_default_args
|
||||||
|
_hook: Function called on the result before returning.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
function to return new DictConfig object
|
function to return new DictConfig object
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def create():
|
def create():
|
||||||
return get_default_args(C, _do_not_process=_do_not_process)
|
args = get_default_args(C, _do_not_process=_do_not_process)
|
||||||
|
if _hook is not None:
|
||||||
|
_hook(args)
|
||||||
|
return args
|
||||||
|
|
||||||
return dataclasses.field(default_factory=create)
|
return dataclasses.field(default_factory=create)
|
||||||
|
|
||||||
@@ -854,6 +945,7 @@ def _process_member(
|
|||||||
# sure they go at the end of __annotations__ in case
|
# sure they go at the end of __annotations__ in case
|
||||||
# there are non-defaulted standard class members.
|
# there are non-defaulted standard class members.
|
||||||
del some_class.__annotations__[name]
|
del some_class.__annotations__[name]
|
||||||
|
hook = getattr(some_class, name + TWEAK_SUFFIX, None)
|
||||||
|
|
||||||
if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
|
if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
|
||||||
type_name = name + TYPE_SUFFIX
|
type_name = name + TYPE_SUFFIX
|
||||||
@@ -879,11 +971,17 @@ def _process_member(
|
|||||||
f"Cannot generate {args_name} because it is already present."
|
f"Cannot generate {args_name} because it is already present."
|
||||||
)
|
)
|
||||||
some_class.__annotations__[args_name] = DictConfig
|
some_class.__annotations__[args_name] = DictConfig
|
||||||
|
if hook is not None:
|
||||||
|
hook_closed = partial(hook, derived_type)
|
||||||
|
else:
|
||||||
|
hook_closed = None
|
||||||
setattr(
|
setattr(
|
||||||
some_class,
|
some_class,
|
||||||
args_name,
|
args_name,
|
||||||
get_default_args_field(
|
get_default_args_field(
|
||||||
derived_type, _do_not_process=_do_not_process + (some_class,)
|
derived_type,
|
||||||
|
_do_not_process=_do_not_process + (some_class,),
|
||||||
|
_hook=hook_closed,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -896,12 +994,17 @@ def _process_member(
|
|||||||
raise ValueError(f"Cannot process {type_} inside {some_class}")
|
raise ValueError(f"Cannot process {type_} inside {some_class}")
|
||||||
|
|
||||||
some_class.__annotations__[args_name] = DictConfig
|
some_class.__annotations__[args_name] = DictConfig
|
||||||
|
if hook is not None:
|
||||||
|
hook_closed = partial(hook, type_)
|
||||||
|
else:
|
||||||
|
hook_closed = None
|
||||||
setattr(
|
setattr(
|
||||||
some_class,
|
some_class,
|
||||||
args_name,
|
args_name,
|
||||||
get_default_args_field(
|
get_default_args_field(
|
||||||
type_,
|
type_,
|
||||||
_do_not_process=_do_not_process + (some_class,),
|
_do_not_process=_do_not_process + (some_class,),
|
||||||
|
_hook=hook_closed,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||||
@@ -910,7 +1013,7 @@ def _process_member(
|
|||||||
some_class.__annotations__[enabled_name] = bool
|
some_class.__annotations__[enabled_name] = bool
|
||||||
setattr(some_class, enabled_name, False)
|
setattr(some_class, enabled_name, False)
|
||||||
|
|
||||||
creation_function_name = f"create_{name}"
|
creation_function_name = f"{CREATE_PREFIX}{name}"
|
||||||
if not hasattr(some_class, creation_function_name):
|
if not hasattr(some_class, creation_function_name):
|
||||||
setattr(
|
setattr(
|
||||||
some_class,
|
some_class,
|
||||||
@@ -919,6 +1022,14 @@ def _process_member(
|
|||||||
)
|
)
|
||||||
creation_functions.append(creation_function_name)
|
creation_functions.append(creation_function_name)
|
||||||
|
|
||||||
|
creation_function_impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}"
|
||||||
|
if not hasattr(some_class, creation_function_impl_name):
|
||||||
|
setattr(
|
||||||
|
some_class,
|
||||||
|
creation_function_impl_name,
|
||||||
|
_default_create_impl(name, type_, process_type),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def remove_unused_components(dict_: DictConfig) -> None:
|
def remove_unused_components(dict_: DictConfig) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ def cleanup_eval_depth(
|
|||||||
# the threshold is a sigma-multiple of the standard deviation of the depth
|
# the threshold is a sigma-multiple of the standard deviation of the depth
|
||||||
mu = wmean(depth.view(ba, -1, 1), mask.view(ba, -1)).view(ba, 1)
|
mu = wmean(depth.view(ba, -1, 1), mask.view(ba, -1)).view(ba, 1)
|
||||||
std = (
|
std = (
|
||||||
|
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||||
wmean((depth.view(ba, -1) - mu).view(ba, -1, 1) ** 2, mask.view(ba, -1))
|
wmean((depth.view(ba, -1) - mu).view(ba, -1, 1) ** 2, mask.view(ba, -1))
|
||||||
.clamp(1e-4)
|
.clamp(1e-4)
|
||||||
.sqrt()
|
.sqrt()
|
||||||
@@ -58,11 +59,10 @@ def cleanup_eval_depth(
|
|||||||
good_df_thr = std * sigma
|
good_df_thr = std * sigma
|
||||||
good_depth = (df <= good_df_thr).float() * pcl_mask
|
good_depth = (df <= good_df_thr).float() * pcl_mask
|
||||||
|
|
||||||
perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1)
|
# perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1)
|
||||||
# print(f'Kept {100.0 * perc_kept.mean():1.3f} % points')
|
# print(f'Kept {100.0 * perc_kept.mean():1.3f} % points')
|
||||||
|
|
||||||
good_depth_raster = torch.zeros_like(depth).view(ba, -1)
|
good_depth_raster = torch.zeros_like(depth).view(ba, -1)
|
||||||
# pyre-ignore[16]: scatter_add_
|
|
||||||
good_depth_raster.scatter_add_(1, torch.round(idx_sampled[:, 0]).long(), good_depth)
|
good_depth_raster.scatter_add_(1, torch.round(idx_sampled[:, 0]).long(), good_depth)
|
||||||
|
|
||||||
good_depth_mask = (good_depth_raster.view(ba, 1, H, W) > 0).float()
|
good_depth_mask = (good_depth_raster.view(ba, 1, H, W) > 0).float()
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ def generate_eval_video_cameras(
|
|||||||
trajectory_scale: float = 0.2,
|
trajectory_scale: float = 0.2,
|
||||||
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
|
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
|
||||||
up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
|
up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
|
||||||
focal_length: Optional[torch.FloatTensor] = None,
|
focal_length: Optional[torch.Tensor] = None,
|
||||||
principal_point: Optional[torch.FloatTensor] = None,
|
principal_point: Optional[torch.Tensor] = None,
|
||||||
time: Optional[torch.FloatTensor] = None,
|
time: Optional[torch.Tensor] = None,
|
||||||
infer_up_as_plane_normal: bool = True,
|
infer_up_as_plane_normal: bool = True,
|
||||||
traj_offset: Optional[Tuple[float, float, float]] = None,
|
traj_offset: Optional[Tuple[float, float, float]] = None,
|
||||||
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
|
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
|
||||||
@@ -200,9 +200,6 @@ def _visdom_plot_scene(
|
|||||||
|
|
||||||
viz = Visdom()
|
viz = Visdom()
|
||||||
viz.plotlyplot(p, env="cam_traj_dbg", win="cam_trajs")
|
viz.plotlyplot(p, env="cam_traj_dbg", win="cam_trajs")
|
||||||
import pdb
|
|
||||||
|
|
||||||
pdb.set_trace()
|
|
||||||
|
|
||||||
|
|
||||||
def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):
|
def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):
|
||||||
|
|||||||
@@ -5,7 +5,7 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
from typing import Union
|
from typing import Sequence, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -14,7 +14,7 @@ def mask_background(
|
|||||||
image_rgb: torch.Tensor,
|
image_rgb: torch.Tensor,
|
||||||
mask_fg: torch.Tensor,
|
mask_fg: torch.Tensor,
|
||||||
dim_color: int = 1,
|
dim_color: int = 1,
|
||||||
bg_color: Union[torch.Tensor, str, float] = 0.0,
|
bg_color: Union[torch.Tensor, Sequence, str, float] = 0.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Mask the background input image tensor `image_rgb` with `bg_color`.
|
Mask the background input image tensor `image_rgb` with `bg_color`.
|
||||||
@@ -26,9 +26,11 @@ def mask_background(
|
|||||||
# obtain the background color tensor
|
# obtain the background color tensor
|
||||||
if isinstance(bg_color, torch.Tensor):
|
if isinstance(bg_color, torch.Tensor):
|
||||||
bg_color_t = bg_color.view(1, 3, 1, 1).clone().to(image_rgb)
|
bg_color_t = bg_color.view(1, 3, 1, 1).clone().to(image_rgb)
|
||||||
elif isinstance(bg_color, float):
|
elif isinstance(bg_color, (float, tuple, list)):
|
||||||
|
if isinstance(bg_color, float):
|
||||||
|
bg_color = [bg_color] * 3
|
||||||
bg_color_t = torch.tensor(
|
bg_color_t = torch.tensor(
|
||||||
[bg_color] * 3, device=image_rgb.device, dtype=image_rgb.dtype
|
bg_color, device=image_rgb.device, dtype=image_rgb.dtype
|
||||||
).view(*tgt_view)
|
).view(*tgt_view)
|
||||||
elif isinstance(bg_color, str):
|
elif isinstance(bg_color, str):
|
||||||
if bg_color == "white":
|
if bg_color == "white":
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user