mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-02-28 09:16:00 +08:00
Compare commits
106 Commits
v0.6.2
...
classner-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e7c1f026ea | ||
|
|
cb49550486 | ||
|
|
36edf2b302 | ||
|
|
78bb6d17fa | ||
|
|
54c75b4114 | ||
|
|
3783437d2f | ||
|
|
b2dc520210 | ||
|
|
8597d4c5c1 | ||
|
|
38fd8380f7 | ||
|
|
67840f8320 | ||
|
|
9b2e570536 | ||
|
|
0f966217e5 | ||
|
|
379c8b2780 | ||
|
|
8e0c82b89a | ||
|
|
8ba9a694ee | ||
|
|
36ba079bef | ||
|
|
b95ec190af | ||
|
|
55f67b0d18 | ||
|
|
4261e59f51 | ||
|
|
af55ba01f8 | ||
|
|
d3b7f5f421 | ||
|
|
4ecc9ea89d | ||
|
|
8d10ba52b2 | ||
|
|
aa8b03f31d | ||
|
|
57a40b3688 | ||
|
|
522e5f0644 | ||
|
|
e8390d3500 | ||
|
|
4300030d7a | ||
|
|
00acf0b0c7 | ||
|
|
a94f3f4c4b | ||
|
|
efb721320a | ||
|
|
40fb189c29 | ||
|
|
4e87c2b7f1 | ||
|
|
771cf8a328 | ||
|
|
0dce883241 | ||
|
|
ae35824f21 | ||
|
|
f4dd151037 | ||
|
|
7ce8ed55e1 | ||
|
|
7e0146ece4 | ||
|
|
0e4c53c612 | ||
|
|
879495d38f | ||
|
|
5c1ca757bb | ||
|
|
3e4fb0b9d9 | ||
|
|
731ea53c80 | ||
|
|
2e42ef793f | ||
|
|
81d63c6382 | ||
|
|
28c1afaa9d | ||
|
|
cba26506b6 | ||
|
|
65f667fd2e | ||
|
|
7978ffd1e4 | ||
|
|
ea4f3260e4 | ||
|
|
023a2369ae | ||
|
|
c0f88e04a0 | ||
|
|
6275283202 | ||
|
|
1d43251391 | ||
|
|
1fb268dea6 | ||
|
|
8bc0a04e86 | ||
|
|
5cd70067e2 | ||
|
|
5b74a2cc27 | ||
|
|
49ed7b07b1 | ||
|
|
c6519f29f0 | ||
|
|
a42a89a5ba | ||
|
|
c31bf85a23 | ||
|
|
fbd3c679ac | ||
|
|
34f648ede0 | ||
|
|
f625fe1f8b | ||
|
|
7c25d34d22 | ||
|
|
c5a83f46ef | ||
|
|
1702c85bec | ||
|
|
90d00f1b2b | ||
|
|
d27ef14ec7 | ||
|
|
2d1c6d5d93 | ||
|
|
9fe15da3cd | ||
|
|
0f12c51646 | ||
|
|
79c61a2d86 | ||
|
|
69c6d06ed8 | ||
|
|
73dc109dba | ||
|
|
9ec9d057cc | ||
|
|
cd7b885169 | ||
|
|
f632c423ef | ||
|
|
f36b11fe49 | ||
|
|
ea5df60d72 | ||
|
|
4372001981 | ||
|
|
61e2b87019 | ||
|
|
0143d63ba8 | ||
|
|
899a3192b6 | ||
|
|
3b2300641a | ||
|
|
b5f3d3ce12 | ||
|
|
2c1901522a | ||
|
|
90ab219d88 | ||
|
|
9e57b994ca | ||
|
|
e767c4b548 | ||
|
|
e85fa03c5a | ||
|
|
47d06c8924 | ||
|
|
bef959c755 | ||
|
|
c21ba144e7 | ||
|
|
d737a05e55 | ||
|
|
2374d19da5 | ||
|
|
1f3953795c | ||
|
|
a6dada399d | ||
|
|
5c59841863 | ||
|
|
2c64635daa | ||
|
|
ec9580a1d4 | ||
|
|
44cb00e468 | ||
|
|
44ca5f95d9 | ||
|
|
a51a300827 |
@@ -81,7 +81,7 @@ jobs:
|
||||
command: |
|
||||
export LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64
|
||||
python3 setup.py build_ext --inplace
|
||||
- run: LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64 python -m unittest discover -v -s tests
|
||||
- run: LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64 python -m unittest discover -v -s tests -t .
|
||||
- run: python3 setup.py bdist_wheel
|
||||
|
||||
binary_linux_wheel:
|
||||
@@ -182,23 +182,23 @@ workflows:
|
||||
# context: DOCKERHUB_TOKEN
|
||||
{{workflows()}}
|
||||
- binary_linux_conda_cuda:
|
||||
name: testrun_conda_cuda_py37_cu102_pyt170
|
||||
name: testrun_conda_cuda_py37_cu102_pyt190
|
||||
context: DOCKERHUB_TOKEN
|
||||
python_version: "3.7"
|
||||
pytorch_version: '1.7.0'
|
||||
pytorch_version: '1.9.0'
|
||||
cu_version: "cu102"
|
||||
- binary_macos_wheel:
|
||||
cu_version: cpu
|
||||
name: macos_wheel_py37_cpu
|
||||
python_version: '3.7'
|
||||
pytorch_version: '1.9.0'
|
||||
pytorch_version: '1.12.0'
|
||||
- binary_macos_wheel:
|
||||
cu_version: cpu
|
||||
name: macos_wheel_py38_cpu
|
||||
python_version: '3.8'
|
||||
pytorch_version: '1.9.0'
|
||||
pytorch_version: '1.12.0'
|
||||
- binary_macos_wheel:
|
||||
cu_version: cpu
|
||||
name: macos_wheel_py39_cpu
|
||||
python_version: '3.9'
|
||||
pytorch_version: '1.9.0'
|
||||
pytorch_version: '1.12.0'
|
||||
|
||||
@@ -81,7 +81,7 @@ jobs:
|
||||
command: |
|
||||
export LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64
|
||||
python3 setup.py build_ext --inplace
|
||||
- run: LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64 python -m unittest discover -v -s tests
|
||||
- run: LD_LIBRARY_PATH=$LD_LIBARY_PATH:/usr/local/cuda-11.3/lib64 python -m unittest discover -v -s tests -t .
|
||||
- run: python3 setup.py bdist_wheel
|
||||
|
||||
binary_linux_wheel:
|
||||
@@ -180,42 +180,6 @@ workflows:
|
||||
jobs:
|
||||
# - main:
|
||||
# context: DOCKERHUB_TOKEN
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu101
|
||||
name: linux_conda_py37_cu101_pyt170
|
||||
python_version: '3.7'
|
||||
pytorch_version: 1.7.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
name: linux_conda_py37_cu102_pyt170
|
||||
python_version: '3.7'
|
||||
pytorch_version: 1.7.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu110
|
||||
name: linux_conda_py37_cu110_pyt170
|
||||
python_version: '3.7'
|
||||
pytorch_version: 1.7.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu101
|
||||
name: linux_conda_py37_cu101_pyt171
|
||||
python_version: '3.7'
|
||||
pytorch_version: 1.7.1
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
name: linux_conda_py37_cu102_pyt171
|
||||
python_version: '3.7'
|
||||
pytorch_version: 1.7.1
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu110
|
||||
name: linux_conda_py37_cu110_pyt171
|
||||
python_version: '3.7'
|
||||
pytorch_version: 1.7.1
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu101
|
||||
@@ -359,42 +323,26 @@ workflows:
|
||||
name: linux_conda_py37_cu115_pyt1110
|
||||
python_version: '3.7'
|
||||
pytorch_version: 1.11.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu101
|
||||
name: linux_conda_py38_cu101_pyt170
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.7.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
name: linux_conda_py38_cu102_pyt170
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.7.0
|
||||
name: linux_conda_py37_cu102_pyt1120
|
||||
python_version: '3.7'
|
||||
pytorch_version: 1.12.0
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda113
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu110
|
||||
name: linux_conda_py38_cu110_pyt170
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.7.0
|
||||
cu_version: cu113
|
||||
name: linux_conda_py37_cu113_pyt1120
|
||||
python_version: '3.7'
|
||||
pytorch_version: 1.12.0
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda116
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu101
|
||||
name: linux_conda_py38_cu101_pyt171
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.7.1
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
name: linux_conda_py38_cu102_pyt171
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.7.1
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu110
|
||||
name: linux_conda_py38_cu110_pyt171
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.7.1
|
||||
cu_version: cu116
|
||||
name: linux_conda_py37_cu116_pyt1120
|
||||
python_version: '3.7'
|
||||
pytorch_version: 1.12.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu101
|
||||
@@ -538,24 +486,26 @@ workflows:
|
||||
name: linux_conda_py38_cu115_pyt1110
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.11.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu101
|
||||
name: linux_conda_py39_cu101_pyt171
|
||||
python_version: '3.9'
|
||||
pytorch_version: 1.7.1
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
name: linux_conda_py39_cu102_pyt171
|
||||
python_version: '3.9'
|
||||
pytorch_version: 1.7.1
|
||||
name: linux_conda_py38_cu102_pyt1120
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.12.0
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda113
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu110
|
||||
name: linux_conda_py39_cu110_pyt171
|
||||
python_version: '3.9'
|
||||
pytorch_version: 1.7.1
|
||||
cu_version: cu113
|
||||
name: linux_conda_py38_cu113_pyt1120
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.12.0
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda116
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu116
|
||||
name: linux_conda_py38_cu116_pyt1120
|
||||
python_version: '3.8'
|
||||
pytorch_version: 1.12.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu101
|
||||
@@ -699,6 +649,26 @@ workflows:
|
||||
name: linux_conda_py39_cu115_pyt1110
|
||||
python_version: '3.9'
|
||||
pytorch_version: 1.11.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
name: linux_conda_py39_cu102_pyt1120
|
||||
python_version: '3.9'
|
||||
pytorch_version: 1.12.0
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda113
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu113
|
||||
name: linux_conda_py39_cu113_pyt1120
|
||||
python_version: '3.9'
|
||||
pytorch_version: 1.12.0
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda116
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu116
|
||||
name: linux_conda_py39_cu116_pyt1120
|
||||
python_version: '3.9'
|
||||
pytorch_version: 1.12.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
@@ -725,24 +695,44 @@ workflows:
|
||||
name: linux_conda_py310_cu115_pyt1110
|
||||
python_version: '3.10'
|
||||
pytorch_version: 1.11.0
|
||||
- binary_linux_conda:
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu102
|
||||
name: linux_conda_py310_cu102_pyt1120
|
||||
python_version: '3.10'
|
||||
pytorch_version: 1.12.0
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda113
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu113
|
||||
name: linux_conda_py310_cu113_pyt1120
|
||||
python_version: '3.10'
|
||||
pytorch_version: 1.12.0
|
||||
- binary_linux_conda:
|
||||
conda_docker_image: pytorch/conda-builder:cuda116
|
||||
context: DOCKERHUB_TOKEN
|
||||
cu_version: cu116
|
||||
name: linux_conda_py310_cu116_pyt1120
|
||||
python_version: '3.10'
|
||||
pytorch_version: 1.12.0
|
||||
- binary_linux_conda_cuda:
|
||||
name: testrun_conda_cuda_py37_cu102_pyt170
|
||||
name: testrun_conda_cuda_py37_cu102_pyt190
|
||||
context: DOCKERHUB_TOKEN
|
||||
python_version: "3.7"
|
||||
pytorch_version: '1.7.0'
|
||||
pytorch_version: '1.9.0'
|
||||
cu_version: "cu102"
|
||||
- binary_macos_wheel:
|
||||
cu_version: cpu
|
||||
name: macos_wheel_py37_cpu
|
||||
python_version: '3.7'
|
||||
pytorch_version: '1.9.0'
|
||||
pytorch_version: '1.12.0'
|
||||
- binary_macos_wheel:
|
||||
cu_version: cpu
|
||||
name: macos_wheel_py38_cpu
|
||||
python_version: '3.8'
|
||||
pytorch_version: '1.9.0'
|
||||
pytorch_version: '1.12.0'
|
||||
- binary_macos_wheel:
|
||||
cu_version: cpu
|
||||
name: macos_wheel_py39_cpu
|
||||
python_version: '3.9'
|
||||
pytorch_version: '1.9.0'
|
||||
pytorch_version: '1.12.0'
|
||||
|
||||
@@ -20,8 +20,6 @@ from packaging import version
|
||||
# version of pytorch.
|
||||
# Pytorch 1.4 also supports cuda 10.0 but we no longer build for cuda 10.0 at all.
|
||||
CONDA_CUDA_VERSIONS = {
|
||||
"1.7.0": ["cu101", "cu102", "cu110"],
|
||||
"1.7.1": ["cu101", "cu102", "cu110"],
|
||||
"1.8.0": ["cu101", "cu102", "cu111"],
|
||||
"1.8.1": ["cu101", "cu102", "cu111"],
|
||||
"1.9.0": ["cu102", "cu111"],
|
||||
@@ -30,15 +28,20 @@ CONDA_CUDA_VERSIONS = {
|
||||
"1.10.1": ["cu102", "cu111", "cu113"],
|
||||
"1.10.2": ["cu102", "cu111", "cu113"],
|
||||
"1.11.0": ["cu102", "cu111", "cu113", "cu115"],
|
||||
"1.12.0": ["cu102", "cu113", "cu116"],
|
||||
}
|
||||
|
||||
|
||||
def conda_docker_image_for_cuda(cuda_version):
|
||||
if cuda_version in ("cu101", "cu102", "cu111"):
|
||||
return None
|
||||
if cuda_version == "cu113":
|
||||
return "pytorch/conda-builder:cuda113"
|
||||
if cuda_version == "cu115":
|
||||
return "pytorch/conda-builder:cuda115"
|
||||
return None
|
||||
if cuda_version == "cu116":
|
||||
return "pytorch/conda-builder:cuda116"
|
||||
raise ValueError("Unknown cuda version")
|
||||
|
||||
|
||||
def pytorch_versions_for_python(python_version):
|
||||
|
||||
14
INSTALL.md
14
INSTALL.md
@@ -9,7 +9,7 @@ The core library is written in PyTorch. Several components have underlying imple
|
||||
|
||||
- Linux or macOS or Windows
|
||||
- Python 3.6, 3.7, 3.8 or 3.9
|
||||
- PyTorch 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2 or 1.11.0.
|
||||
- PyTorch 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0 or 1.12.0.
|
||||
- torchvision that matches the PyTorch installation. You can install them together as explained at pytorch.org to make sure of this.
|
||||
- gcc & g++ ≥ 4.9
|
||||
- [fvcore](https://github.com/facebookresearch/fvcore)
|
||||
@@ -78,7 +78,7 @@ Or, to install a nightly (non-official, alpha) build:
|
||||
conda install pytorch3d -c pytorch3d-nightly
|
||||
```
|
||||
### 2. Install from PyPI, on Mac only.
|
||||
This works with pytorch 1.9.0 only. The build is CPU only.
|
||||
This works with pytorch 1.12.0 only. The build is CPU only.
|
||||
```
|
||||
pip install pytorch3d
|
||||
```
|
||||
@@ -87,9 +87,9 @@ pip install pytorch3d
|
||||
We have prebuilt wheels with CUDA for Linux for PyTorch 1.11.0, for each of the supported CUDA versions,
|
||||
for Python 3.7, 3.8 and 3.9. This is for ease of use on Google Colab.
|
||||
These are installed in a special way.
|
||||
For example, to install for Python 3.8, PyTorch 1.11.0 and CUDA 10.2
|
||||
For example, to install for Python 3.8, PyTorch 1.11.0 and CUDA 11.3
|
||||
```
|
||||
pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu102_pyt1110/download.html
|
||||
pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1110/download.html
|
||||
```
|
||||
|
||||
In general, from inside IPython, or in Google Colab or a jupyter notebook, you can install with
|
||||
@@ -147,10 +147,10 @@ After any necessary patching, you can go to "x64 Native Tools Command Prompt for
|
||||
cd pytorch3d
|
||||
python3 setup.py install
|
||||
```
|
||||
After installing, verify whether all unit tests have passed
|
||||
|
||||
After installing, you can run **unit tests**
|
||||
```
|
||||
cd tests
|
||||
python3 -m unittest discover -p *.py
|
||||
python3 -m unittest discover -v -s tests -t .
|
||||
```
|
||||
|
||||
# FAQ
|
||||
|
||||
@@ -46,3 +46,26 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
NeRF https://github.com/bmild/nerf/
|
||||
|
||||
Copyright (c) 2020 bmild
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
@@ -7,22 +7,10 @@
|
||||
|
||||
# Run this script at project root by "./dev/linter.sh" before you commit
|
||||
|
||||
{
|
||||
V=$(black --version|cut '-d ' -f3)
|
||||
code='import distutils.version; assert "19.3" < distutils.version.LooseVersion("'$V'")'
|
||||
PYTHON=false
|
||||
command -v python > /dev/null && PYTHON=python
|
||||
command -v python3 > /dev/null && PYTHON=python3
|
||||
${PYTHON} -c "${code}" 2> /dev/null
|
||||
} || {
|
||||
echo "Linter requires black 19.3b0 or higher!"
|
||||
exit 1
|
||||
}
|
||||
|
||||
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
|
||||
DIR=$(dirname "${DIR}")
|
||||
|
||||
if [[ -f "${DIR}/tests/TARGETS" ]]
|
||||
if [[ -f "${DIR}/TARGETS" ]]
|
||||
then
|
||||
pyfmt "${DIR}"
|
||||
else
|
||||
@@ -42,7 +30,7 @@ clangformat=$(command -v clang-format-8 || echo clang-format)
|
||||
find "${DIR}" -regex ".*\.\(cpp\|c\|cc\|cu\|cuh\|cxx\|h\|hh\|hpp\|hxx\|tcc\|mm\|m\)" -print0 | xargs -0 "${clangformat}" -i
|
||||
|
||||
# Run arc and pyre internally only.
|
||||
if [[ -f "${DIR}/tests/TARGETS" ]]
|
||||
if [[ -f "${DIR}/TARGETS" ]]
|
||||
then
|
||||
(cd "${DIR}"; command -v arc > /dev/null && arc lint) || true
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ def tests_from_file(path: Path, base: str) -> List[str]:
|
||||
|
||||
def main() -> None:
|
||||
files = get_test_files()
|
||||
test_root = Path(__file__).parent.parent / "tests"
|
||||
test_root = Path(__file__).parent.parent
|
||||
all_tests = []
|
||||
for f in files:
|
||||
file_base = str(f.relative_to(test_root))[:-3].replace("/", ".")
|
||||
|
||||
@@ -3,7 +3,6 @@ API Documentation
|
||||
|
||||
.. toctree::
|
||||
|
||||
common
|
||||
structures
|
||||
io
|
||||
loss
|
||||
@@ -12,3 +11,5 @@ API Documentation
|
||||
transforms
|
||||
utils
|
||||
datasets
|
||||
common
|
||||
vis
|
||||
|
||||
6
docs/modules/vis.rst
Normal file
6
docs/modules/vis.rst
Normal file
@@ -0,0 +1,6 @@
|
||||
pytorch3d.vis
|
||||
===========================
|
||||
|
||||
.. automodule:: pytorch3d.vis
|
||||
:members:
|
||||
:undoc-members:
|
||||
@@ -8,3 +8,4 @@
|
||||
sudo docker run --rm -v "$PWD/../../:/inside" pytorch/conda-cuda bash inside/packaging/linux_wheels/inside.sh
|
||||
sudo docker run --rm -v "$PWD/../../:/inside" -e SELECTED_CUDA=cu113 pytorch/conda-builder:cuda113 bash inside/packaging/linux_wheels/inside.sh
|
||||
sudo docker run --rm -v "$PWD/../../:/inside" -e SELECTED_CUDA=cu115 pytorch/conda-builder:cuda115 bash inside/packaging/linux_wheels/inside.sh
|
||||
sudo docker run --rm -v "$PWD/../../:/inside" -e SELECTED_CUDA=cu116 pytorch/conda-builder:cuda116 bash inside/packaging/linux_wheels/inside.sh
|
||||
|
||||
@@ -60,7 +60,7 @@ do
|
||||
|
||||
for cu_version in ${CONDA_CUDA_VERSIONS[$pytorch_version]}
|
||||
do
|
||||
if [[ "cu113 cu115" == *$cu_version* ]]
|
||||
if [[ "cu113 cu115 cu116" == *$cu_version* ]]
|
||||
# ^^^ CUDA versions listed here have to be built
|
||||
# in their own containers.
|
||||
then
|
||||
@@ -74,6 +74,11 @@ do
|
||||
fi
|
||||
|
||||
case "$cu_version" in
|
||||
cu116)
|
||||
export CUDA_HOME=/usr/local/cuda-11.6/
|
||||
export CUDA_TAG=11.6
|
||||
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_50,code=compute_50"
|
||||
;;
|
||||
cu115)
|
||||
export CUDA_HOME=/usr/local/cuda-11.5/
|
||||
export CUDA_TAG=11.5
|
||||
@@ -124,6 +129,7 @@ do
|
||||
|
||||
conda create -y -n "$tag" "python=$python_version"
|
||||
conda activate "$tag"
|
||||
# shellcheck disable=SC2086
|
||||
conda install -y -c pytorch $extra_channel "pytorch=$pytorch_version" "cudatoolkit=$CUDA_TAG" torchvision
|
||||
pip install fvcore iopath
|
||||
echo "python version" "$python_version" "pytorch version" "$pytorch_version" "cuda version" "$cu_version" "tag" "$tag"
|
||||
|
||||
@@ -55,6 +55,17 @@ setup_cuda() {
|
||||
|
||||
# Now work out the CUDA settings
|
||||
case "$CU_VERSION" in
|
||||
cu116)
|
||||
if [[ "$OSTYPE" == "msys" ]]; then
|
||||
export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.6"
|
||||
else
|
||||
export CUDA_HOME=/usr/local/cuda-11.6/
|
||||
fi
|
||||
export FORCE_CUDA=1
|
||||
# Hard-coding gencode flags is temporary situation until
|
||||
# https://github.com/pytorch/pytorch/pull/23408 lands
|
||||
export NVCC_FLAGS="-gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_50,code=compute_50"
|
||||
;;
|
||||
cu115)
|
||||
if [[ "$OSTYPE" == "msys" ]]; then
|
||||
export CUDA_HOME="C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.5"
|
||||
@@ -304,6 +315,9 @@ setup_conda_cudatoolkit_constraint() {
|
||||
export CONDA_CUDATOOLKIT_CONSTRAINT=""
|
||||
else
|
||||
case "$CU_VERSION" in
|
||||
cu116)
|
||||
export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.6,<11.7 # [not osx]"
|
||||
;;
|
||||
cu115)
|
||||
export CONDA_CUDATOOLKIT_CONSTRAINT="- cudatoolkit >=11.5,<11.6 # [not osx]"
|
||||
;;
|
||||
|
||||
@@ -45,9 +45,12 @@ test:
|
||||
- docs
|
||||
requires:
|
||||
- imageio
|
||||
- hydra-core
|
||||
- accelerate
|
||||
- lpips
|
||||
commands:
|
||||
#pytest .
|
||||
python -m unittest discover -v -s tests
|
||||
python -m unittest discover -v -s tests -t .
|
||||
|
||||
|
||||
about:
|
||||
|
||||
@@ -5,7 +5,7 @@ Implicitron is a PyTorch3D-based framework for new-view synthesis via modeling t
|
||||
# License
|
||||
|
||||
Implicitron is distributed as part of PyTorch3D under the [BSD license](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE).
|
||||
It includes code from [SRN](http://github.com/vsitzmann/scene-representation-networks) and [IDR](http://github.com/lioryariv/idr) repos.
|
||||
It includes code from the [NeRF](https://github.com/bmild/nerf), [SRN](http://github.com/vsitzmann/scene-representation-networks) and [IDR](http://github.com/lioryariv/idr) repos.
|
||||
See [LICENSE-3RD-PARTY](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE-3RD-PARTY) for their licenses.
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ Only configuration can be changed (see [Configuration system](#configuration-sys
|
||||
For this setup, install the dependencies and PyTorch3D from conda following [the guide](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#1-install-with-cuda-support-from-anaconda-cloud-on-linux-only). Then, install implicitron-specific dependencies:
|
||||
|
||||
```shell
|
||||
pip install "hydra-core>=1.1" visdom lpips matplotlib
|
||||
pip install "hydra-core>=1.1" visdom lpips matplotlib accelerate
|
||||
```
|
||||
|
||||
Runner executable is available as `pytorch3d_implicitron_runner` shell command.
|
||||
@@ -49,7 +49,7 @@ Please follow the instructions to [install PyTorch3D from a local clone](https:/
|
||||
Then, install Implicitron-specific dependencies:
|
||||
|
||||
```shell
|
||||
pip install "hydra-core>=1.1" visdom lpips matplotlib
|
||||
pip install "hydra-core>=1.1" visdom lpips matplotlib accelerate
|
||||
```
|
||||
|
||||
You are still encouraged to implement custom plugins as above where possible as it makes reusing the code easier.
|
||||
@@ -66,7 +66,8 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace
|
||||
To run training, pass a yaml config file, followed by a list of overridden arguments.
|
||||
For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
|
||||
```shell
|
||||
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
|
||||
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
|
||||
```
|
||||
|
||||
Here, `--config-path` points to the config path relative to `pytorch3d_implicitron_runner` location;
|
||||
@@ -84,7 +85,8 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad
|
||||
|
||||
E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
|
||||
```shell
|
||||
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<CO3D_DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
|
||||
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
|
||||
```
|
||||
Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`.
|
||||
|
||||
@@ -202,7 +204,7 @@ to replace the implementation and potentially override the parameters.
|
||||
# Code and config structure
|
||||
|
||||
As per above, the config structure is parsed automatically from the module hierarchy.
|
||||
In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `dataset_args` node.
|
||||
In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `data_source_args` node.
|
||||
|
||||
Here is the class structure (single-line edges show aggregation, while double lines show available implementations):
|
||||
```
|
||||
@@ -224,7 +226,8 @@ generic_model_args: GenericModel
|
||||
└-- hypernet_args: SRNRaymarchHyperNet
|
||||
└-- pixel_generator_args: SRNPixelGenerator
|
||||
╘== IdrFeatureField
|
||||
└-- image_feature_extractor_args: ResNetFeatureExtractor
|
||||
└-- image_feature_extractor_*_args: FeatureExtractorBase
|
||||
╘== ResNetFeatureExtractor
|
||||
└-- view_sampler_args: ViewSampler
|
||||
└-- feature_aggregator_*_args: FeatureAggregatorBase
|
||||
╘== IdentityFeatureAggregator
|
||||
@@ -232,8 +235,9 @@ generic_model_args: GenericModel
|
||||
╘== AngleWeightedReductionFeatureAggregator
|
||||
╘== ReductionFeatureAggregator
|
||||
solver_args: init_optimizer
|
||||
dataset_args: dataset_zoo
|
||||
dataloader_args: dataloader_zoo
|
||||
data_source_args: ImplicitronDataSource
|
||||
└-- dataset_map_provider_*_args
|
||||
└-- data_loader_map_provider_*_args
|
||||
```
|
||||
|
||||
Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
|
||||
|
||||
@@ -5,29 +5,22 @@ exp_dir: ./data/exps/base/
|
||||
architecture: generic
|
||||
visualize_interval: 0
|
||||
visdom_port: 8097
|
||||
dataloader_args:
|
||||
batch_size: 10
|
||||
dataset_len: 1000
|
||||
dataset_len_val: 1
|
||||
num_workers: 8
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
- 3
|
||||
- 4
|
||||
- 5
|
||||
- 6
|
||||
- 7
|
||||
- 8
|
||||
- 9
|
||||
- 10
|
||||
dataset_args:
|
||||
dataset_root: ${oc.env:CO3D_DATASET_ROOT}"
|
||||
load_point_clouds: false
|
||||
mask_depths: false
|
||||
mask_images: false
|
||||
n_frames_per_sequence: -1
|
||||
test_on_train: true
|
||||
test_restrict_sequence_id: 0
|
||||
data_source_args:
|
||||
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
|
||||
dataset_map_provider_class_type: JsonIndexDatasetMapProvider
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
dataset_length_train: 1000
|
||||
dataset_length_val: 1
|
||||
num_workers: 8
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
dataset_root: ${oc.env:CO3D_DATASET_ROOT}
|
||||
n_frames_per_sequence: -1
|
||||
test_on_train: true
|
||||
test_restrict_sequence_id: 0
|
||||
dataset_JsonIndexDataset_args:
|
||||
load_point_clouds: false
|
||||
mask_depths: false
|
||||
mask_images: false
|
||||
generic_model_args:
|
||||
loss_weights:
|
||||
loss_mask_bce: 1.0
|
||||
@@ -49,10 +42,8 @@ generic_model_args:
|
||||
append_xyz:
|
||||
- 5
|
||||
latent_dim: 0
|
||||
raysampler_args:
|
||||
raysampler_AdaptiveRaySampler_args:
|
||||
n_rays_per_image_sampled_from_mask: 1024
|
||||
min_depth: 0.0
|
||||
max_depth: 0.0
|
||||
scene_extent: 8.0
|
||||
n_pts_per_ray_training: 64
|
||||
n_pts_per_ray_evaluation: 64
|
||||
@@ -63,9 +54,10 @@ generic_model_args:
|
||||
n_pts_per_ray_fine_evaluation: 64
|
||||
append_coarse_samples_to_fine: true
|
||||
density_noise_std_train: 1.0
|
||||
view_sampler_args:
|
||||
masked_sampling: false
|
||||
image_feature_extractor_args:
|
||||
view_pooler_args:
|
||||
view_sampler_args:
|
||||
masked_sampling: false
|
||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||
stages:
|
||||
- 1
|
||||
- 2
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
generic_model_args:
|
||||
image_feature_extractor_args:
|
||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||
add_images: true
|
||||
add_masks: true
|
||||
first_max_pool: true
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
generic_model_args:
|
||||
image_feature_extractor_args:
|
||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||
add_images: true
|
||||
add_masks: true
|
||||
first_max_pool: false
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
generic_model_args:
|
||||
image_feature_extractor_args:
|
||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||
stages:
|
||||
- 1
|
||||
- 2
|
||||
@@ -11,6 +12,7 @@ generic_model_args:
|
||||
name: resnet34
|
||||
normalize_image: true
|
||||
pretrained: true
|
||||
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
|
||||
reduction_functions:
|
||||
- AVG
|
||||
view_pooler_args:
|
||||
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
|
||||
reduction_functions:
|
||||
- AVG
|
||||
|
||||
@@ -1,31 +1,35 @@
|
||||
defaults:
|
||||
- repro_base.yaml
|
||||
- _self_
|
||||
dataloader_args:
|
||||
batch_size: 10
|
||||
dataset_len: 1000
|
||||
dataset_len_val: 1
|
||||
num_workers: 8
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
- 3
|
||||
- 4
|
||||
- 5
|
||||
- 6
|
||||
- 7
|
||||
- 8
|
||||
- 9
|
||||
- 10
|
||||
dataset_args:
|
||||
assert_single_seq: false
|
||||
dataset_name: co3d_multisequence
|
||||
load_point_clouds: false
|
||||
mask_depths: false
|
||||
mask_images: false
|
||||
n_frames_per_sequence: -1
|
||||
test_on_train: true
|
||||
test_restrict_sequence_id: 0
|
||||
data_source_args:
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
batch_size: 10
|
||||
dataset_length_train: 1000
|
||||
dataset_length_val: 1
|
||||
num_workers: 8
|
||||
train_conditioning_type: SAME
|
||||
val_conditioning_type: SAME
|
||||
test_conditioning_type: SAME
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
- 3
|
||||
- 4
|
||||
- 5
|
||||
- 6
|
||||
- 7
|
||||
- 8
|
||||
- 9
|
||||
- 10
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
assert_single_seq: false
|
||||
task_str: multisequence
|
||||
n_frames_per_sequence: -1
|
||||
test_on_train: true
|
||||
test_restrict_sequence_id: 0
|
||||
solver_args:
|
||||
max_epochs: 3000
|
||||
milestones:
|
||||
- 1000
|
||||
camera_difficulty_bin_breaks:
|
||||
- 0.666667
|
||||
- 0.833334
|
||||
|
||||
@@ -11,11 +11,12 @@ generic_model_args:
|
||||
num_passes: 1
|
||||
output_rasterized_mc: true
|
||||
sampling_mode_training: mask_sample
|
||||
view_pool: false
|
||||
sequence_autodecoder_args:
|
||||
n_instances: 20000
|
||||
init_scale: 1.0
|
||||
encoding_dim: 256
|
||||
global_encoder_class_type: SequenceAutodecoder
|
||||
global_encoder_SequenceAutodecoder_args:
|
||||
autodecoder_args:
|
||||
n_instances: 20000
|
||||
init_scale: 1.0
|
||||
encoding_dim: 256
|
||||
implicit_function_IdrFeatureField_args:
|
||||
n_harmonic_functions_xyz: 6
|
||||
bias: 0.6
|
||||
@@ -55,7 +56,7 @@ generic_model_args:
|
||||
n_harmonic_functions_dir: 4
|
||||
pooled_feature_dim: 0
|
||||
weight_norm: true
|
||||
raysampler_args:
|
||||
raysampler_AdaptiveRaySampler_args:
|
||||
n_rays_per_image_sampled_from_mask: 1024
|
||||
n_pts_per_ray_training: 0
|
||||
n_pts_per_ray_evaluation: 0
|
||||
|
||||
@@ -3,7 +3,9 @@ defaults:
|
||||
- _self_
|
||||
generic_model_args:
|
||||
chunk_size_grid: 16000
|
||||
view_pool: false
|
||||
sequence_autodecoder_args:
|
||||
n_instances: 20000
|
||||
encoding_dim: 256
|
||||
view_pooler_enabled: false
|
||||
global_encoder_class_type: SequenceAutodecoder
|
||||
global_encoder_SequenceAutodecoder_args:
|
||||
autodecoder_args:
|
||||
n_instances: 20000
|
||||
encoding_dim: 256
|
||||
|
||||
@@ -5,6 +5,6 @@ defaults:
|
||||
clip_grad: 1.0
|
||||
generic_model_args:
|
||||
chunk_size_grid: 16000
|
||||
view_pool: true
|
||||
raysampler_args:
|
||||
view_pooler_enabled: true
|
||||
raysampler_AdaptiveRaySampler_args:
|
||||
n_rays_per_image_sampled_from_mask: 850
|
||||
|
||||
@@ -4,8 +4,7 @@ defaults:
|
||||
- _self_
|
||||
generic_model_args:
|
||||
chunk_size_grid: 16000
|
||||
view_pool: true
|
||||
raysampler_args:
|
||||
raysampler_AdaptiveRaySampler_args:
|
||||
n_rays_per_image_sampled_from_mask: 800
|
||||
n_pts_per_ray_training: 32
|
||||
n_pts_per_ray_evaluation: 32
|
||||
@@ -13,4 +12,6 @@ generic_model_args:
|
||||
n_pts_per_ray_fine_training: 16
|
||||
n_pts_per_ray_fine_evaluation: 16
|
||||
implicit_function_class_type: NeRFormerImplicitFunction
|
||||
feature_aggregator_class_type: IdentityFeatureAggregator
|
||||
view_pooler_enabled: true
|
||||
view_pooler_args:
|
||||
feature_aggregator_class_type: IdentityFeatureAggregator
|
||||
|
||||
@@ -1,16 +1,6 @@
|
||||
defaults:
|
||||
- repro_multiseq_base.yaml
|
||||
- repro_feat_extractor_transformer.yaml
|
||||
- repro_multiseq_nerformer.yaml
|
||||
- _self_
|
||||
generic_model_args:
|
||||
chunk_size_grid: 16000
|
||||
view_pool: true
|
||||
raysampler_args:
|
||||
n_rays_per_image_sampled_from_mask: 800
|
||||
n_pts_per_ray_training: 32
|
||||
n_pts_per_ray_evaluation: 32
|
||||
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
||||
n_pts_per_ray_fine_training: 16
|
||||
n_pts_per_ray_fine_evaluation: 16
|
||||
implicit_function_class_type: NeRFormerImplicitFunction
|
||||
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
|
||||
view_pooler_args:
|
||||
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
|
||||
|
||||
@@ -3,7 +3,7 @@ defaults:
|
||||
- _self_
|
||||
generic_model_args:
|
||||
chunk_size_grid: 16000
|
||||
view_pool: false
|
||||
view_pooler_enabled: false
|
||||
n_train_target_views: -1
|
||||
num_passes: 1
|
||||
loss_weights:
|
||||
@@ -13,14 +13,16 @@ generic_model_args:
|
||||
loss_prev_stage_mask_bce: 0.0
|
||||
loss_autodecoder_norm: 0.001
|
||||
depth_neg_penalty: 10000.0
|
||||
sequence_autodecoder_args:
|
||||
encoding_dim: 256
|
||||
n_instances: 20000
|
||||
raysampler_args:
|
||||
global_encoder_class_type: SequenceAutodecoder
|
||||
global_encoder_SequenceAutodecoder_args:
|
||||
autodecoder_args:
|
||||
encoding_dim: 256
|
||||
n_instances: 20000
|
||||
raysampler_class_type: NearFarRaySampler
|
||||
raysampler_NearFarRaySampler_args:
|
||||
n_rays_per_image_sampled_from_mask: 2048
|
||||
min_depth: 0.05
|
||||
max_depth: 0.05
|
||||
scene_extent: 0.0
|
||||
n_pts_per_ray_training: 1
|
||||
n_pts_per_ray_evaluation: 1
|
||||
stratified_point_sampling_training: false
|
||||
|
||||
@@ -4,7 +4,6 @@ defaults:
|
||||
- _self_
|
||||
generic_model_args:
|
||||
chunk_size_grid: 32000
|
||||
view_pool: true
|
||||
num_passes: 1
|
||||
n_train_target_views: -1
|
||||
loss_weights:
|
||||
@@ -14,17 +13,18 @@ generic_model_args:
|
||||
loss_prev_stage_mask_bce: 0.0
|
||||
loss_autodecoder_norm: 0.0
|
||||
depth_neg_penalty: 10000.0
|
||||
raysampler_args:
|
||||
raysampler_class_type: NearFarRaySampler
|
||||
raysampler_NearFarRaySampler_args:
|
||||
n_rays_per_image_sampled_from_mask: 2048
|
||||
min_depth: 0.05
|
||||
max_depth: 0.05
|
||||
scene_extent: 0.0
|
||||
n_pts_per_ray_training: 1
|
||||
n_pts_per_ray_evaluation: 1
|
||||
stratified_point_sampling_training: false
|
||||
stratified_point_sampling_evaluation: false
|
||||
renderer_class_type: LSTMRenderer
|
||||
implicit_function_class_type: SRNImplicitFunction
|
||||
view_pooler_enabled: true
|
||||
solver_args:
|
||||
breed: adam
|
||||
lr: 5.0e-05
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
defaults:
|
||||
- repro_base
|
||||
- _self_
|
||||
dataloader_args:
|
||||
batch_size: 1
|
||||
dataset_len: 1000
|
||||
dataset_len_val: 1
|
||||
num_workers: 8
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
dataset_args:
|
||||
dataset_name: co3d_singlesequence
|
||||
assert_single_seq: true
|
||||
n_frames_per_sequence: -1
|
||||
test_restrict_sequence_id: 0
|
||||
test_on_train: false
|
||||
data_source_args:
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
batch_size: 1
|
||||
dataset_length_train: 1000
|
||||
dataset_length_val: 1
|
||||
num_workers: 8
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
assert_single_seq: true
|
||||
n_frames_per_sequence: -1
|
||||
test_restrict_sequence_id: 0
|
||||
test_on_train: false
|
||||
generic_model_args:
|
||||
render_image_height: 800
|
||||
render_image_width: 800
|
||||
|
||||
@@ -9,7 +9,7 @@ generic_model_args:
|
||||
loss_eikonal: 0.1
|
||||
chunk_size_grid: 65536
|
||||
num_passes: 1
|
||||
view_pool: false
|
||||
view_pooler_enabled: false
|
||||
implicit_function_IdrFeatureField_args:
|
||||
n_harmonic_functions_xyz: 6
|
||||
bias: 0.6
|
||||
@@ -49,7 +49,7 @@ generic_model_args:
|
||||
n_harmonic_functions_dir: 4
|
||||
pooled_feature_dim: 0
|
||||
weight_norm: true
|
||||
raysampler_args:
|
||||
raysampler_AdaptiveRaySampler_args:
|
||||
n_rays_per_image_sampled_from_mask: 1024
|
||||
n_pts_per_ray_training: 0
|
||||
n_pts_per_ray_evaluation: 0
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
defaults:
|
||||
- repro_singleseq_base
|
||||
- _self_
|
||||
exp_dir: ./data/nerf_single_apple/
|
||||
|
||||
@@ -4,6 +4,6 @@ defaults:
|
||||
- _self_
|
||||
generic_model_args:
|
||||
chunk_size_grid: 16000
|
||||
view_pool: true
|
||||
raysampler_args:
|
||||
view_pooler_enabled: true
|
||||
raysampler_AdaptiveRaySampler_args:
|
||||
n_rays_per_image_sampled_from_mask: 850
|
||||
|
||||
@@ -4,13 +4,14 @@ defaults:
|
||||
- _self_
|
||||
generic_model_args:
|
||||
chunk_size_grid: 16000
|
||||
view_pool: true
|
||||
view_pooler_enabled: true
|
||||
implicit_function_class_type: NeRFormerImplicitFunction
|
||||
raysampler_args:
|
||||
raysampler_AdaptiveRaySampler_args:
|
||||
n_rays_per_image_sampled_from_mask: 800
|
||||
n_pts_per_ray_training: 32
|
||||
n_pts_per_ray_evaluation: 32
|
||||
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
||||
n_pts_per_ray_fine_training: 16
|
||||
n_pts_per_ray_fine_evaluation: 16
|
||||
feature_aggregator_class_type: IdentityFeatureAggregator
|
||||
view_pooler_args:
|
||||
feature_aggregator_class_type: IdentityFeatureAggregator
|
||||
|
||||
@@ -4,7 +4,7 @@ defaults:
|
||||
generic_model_args:
|
||||
num_passes: 1
|
||||
chunk_size_grid: 32000
|
||||
view_pool: false
|
||||
view_pooler_enabled: false
|
||||
loss_weights:
|
||||
loss_rgb_mse: 200.0
|
||||
loss_prev_stage_rgb_mse: 0.0
|
||||
@@ -12,11 +12,11 @@ generic_model_args:
|
||||
loss_prev_stage_mask_bce: 0.0
|
||||
loss_autodecoder_norm: 0.0
|
||||
depth_neg_penalty: 10000.0
|
||||
raysampler_args:
|
||||
raysampler_class_type: NearFarRaySampler
|
||||
raysampler_NearFarRaySampler_args:
|
||||
n_rays_per_image_sampled_from_mask: 2048
|
||||
min_depth: 0.05
|
||||
max_depth: 0.05
|
||||
scene_extent: 0.0
|
||||
n_pts_per_ray_training: 1
|
||||
n_pts_per_ray_evaluation: 1
|
||||
stratified_point_sampling_training: false
|
||||
|
||||
@@ -5,7 +5,7 @@ defaults:
|
||||
generic_model_args:
|
||||
num_passes: 1
|
||||
chunk_size_grid: 32000
|
||||
view_pool: true
|
||||
view_pooler_enabled: true
|
||||
loss_weights:
|
||||
loss_rgb_mse: 200.0
|
||||
loss_prev_stage_rgb_mse: 0.0
|
||||
@@ -13,11 +13,11 @@ generic_model_args:
|
||||
loss_prev_stage_mask_bce: 0.0
|
||||
loss_autodecoder_norm: 0.0
|
||||
depth_neg_penalty: 10000.0
|
||||
raysampler_args:
|
||||
raysampler_class_type: NearFarRaySampler
|
||||
raysampler_NearFarRaySampler_args:
|
||||
n_rays_per_image_sampled_from_mask: 2048
|
||||
min_depth: 0.05
|
||||
max_depth: 0.05
|
||||
scene_extent: 0.0
|
||||
n_pts_per_ray_training: 1
|
||||
n_pts_per_ray_evaluation: 1
|
||||
stratified_point_sampling_training: false
|
||||
|
||||
@@ -1,18 +1,22 @@
|
||||
defaults:
|
||||
- repro_singleseq_base
|
||||
- _self_
|
||||
dataloader_args:
|
||||
batch_size: 10
|
||||
dataset_len: 1000
|
||||
dataset_len_val: 1
|
||||
num_workers: 8
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
- 3
|
||||
- 4
|
||||
- 5
|
||||
- 6
|
||||
- 7
|
||||
- 8
|
||||
- 9
|
||||
- 10
|
||||
data_source_args:
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
batch_size: 10
|
||||
dataset_length_train: 1000
|
||||
dataset_length_val: 1
|
||||
num_workers: 8
|
||||
train_conditioning_type: SAME
|
||||
val_conditioning_type: SAME
|
||||
test_conditioning_type: SAME
|
||||
images_per_seq_options:
|
||||
- 2
|
||||
- 3
|
||||
- 4
|
||||
- 5
|
||||
- 6
|
||||
- 7
|
||||
- 8
|
||||
- 9
|
||||
- 10
|
||||
|
||||
@@ -45,7 +45,6 @@ The outputs of the experiment are saved and logged in multiple ways:
|
||||
config file.
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
@@ -53,7 +52,6 @@ import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import hydra
|
||||
@@ -61,26 +59,29 @@ import lpips
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from accelerate import Accelerator
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from packaging import version
|
||||
from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
|
||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
|
||||
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
|
||||
from pytorch3d.implicitron.models.renderer.multipass_ea import (
|
||||
MultiPassEmissionAbsorptionRenderer,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.ray_sampler import AdaptiveRaySampler
|
||||
from pytorch3d.implicitron.tools import model_io, vis_utils
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
enable_get_default_args,
|
||||
get_default_args_field,
|
||||
expand_args_fields,
|
||||
remove_unused_components,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.stats import Stats
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .impl.experiment_config import ExperimentConfig
|
||||
from .impl.optimization import init_optimizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -96,9 +97,13 @@ try:
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None
|
||||
|
||||
|
||||
def init_model(
|
||||
*,
|
||||
cfg: DictConfig,
|
||||
accelerator: Optional[Accelerator] = None,
|
||||
force_load: bool = False,
|
||||
clear_stats: bool = False,
|
||||
load_model_only: bool = False,
|
||||
@@ -162,12 +167,20 @@ def init_model(
|
||||
logger.info("found previous model %s" % model_path)
|
||||
if force_load or cfg.resume:
|
||||
logger.info(" -> resuming")
|
||||
|
||||
map_location = None
|
||||
if accelerator is not None and not accelerator.is_local_main_process:
|
||||
map_location = {
|
||||
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
||||
}
|
||||
if load_model_only:
|
||||
model_state_dict = torch.load(model_io.get_model_path(model_path))
|
||||
model_state_dict = torch.load(
|
||||
model_io.get_model_path(model_path), map_location=map_location
|
||||
)
|
||||
stats_load, optimizer_state = None, None
|
||||
else:
|
||||
model_state_dict, stats_load, optimizer_state = model_io.load_model(
|
||||
model_path
|
||||
model_path, map_location=map_location
|
||||
)
|
||||
|
||||
# Determine if stats should be reset
|
||||
@@ -211,116 +224,21 @@ def init_model(
|
||||
return model, stats, optimizer_state
|
||||
|
||||
|
||||
def init_optimizer(
|
||||
model: GenericModel,
|
||||
optimizer_state: Optional[Dict[str, Any]],
|
||||
last_epoch: int,
|
||||
breed: str = "adam",
|
||||
weight_decay: float = 0.0,
|
||||
lr_policy: str = "multistep",
|
||||
lr: float = 0.0005,
|
||||
gamma: float = 0.1,
|
||||
momentum: float = 0.9,
|
||||
betas: Tuple[float] = (0.9, 0.999),
|
||||
milestones: tuple = (),
|
||||
max_epochs: int = 1000,
|
||||
):
|
||||
"""
|
||||
Initialize the optimizer (optionally from checkpoint state)
|
||||
and the learning rate scheduler.
|
||||
|
||||
Args:
|
||||
model: The model with optionally loaded weights
|
||||
optimizer_state: The state dict for the optimizer. If None
|
||||
it has not been loaded from checkpoint
|
||||
last_epoch: If the model was loaded from checkpoint this will be the
|
||||
number of the last epoch that was saved
|
||||
breed: The type of optimizer to use e.g. adam
|
||||
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
|
||||
lr_policy: The policy to use for learning rate. Currently, only "multistep:
|
||||
is supported.
|
||||
lr: The value for the initial learning rate
|
||||
gamma: Multiplicative factor of learning rate decay
|
||||
momentum: Momentum factor for SGD optimizer
|
||||
betas: Coefficients used for computing running averages of gradient and its square
|
||||
in the Adam optimizer
|
||||
milestones: List of increasing epoch indices at which the learning rate is
|
||||
modified
|
||||
max_epochs: The maximum number of epochs to run the optimizer for
|
||||
|
||||
Returns:
|
||||
optimizer: Optimizer module, optionally loaded from checkpoint
|
||||
scheduler: Learning rate scheduler module
|
||||
|
||||
Raise:
|
||||
ValueError if `breed` or `lr_policy` are not supported.
|
||||
"""
|
||||
|
||||
# Get the parameters to optimize
|
||||
if hasattr(model, "_get_param_groups"): # use the model function
|
||||
p_groups = model._get_param_groups(lr, wd=weight_decay)
|
||||
else:
|
||||
allprm = [prm for prm in model.parameters() if prm.requires_grad]
|
||||
p_groups = [{"params": allprm, "lr": lr}]
|
||||
|
||||
# Intialize the optimizer
|
||||
if breed == "sgd":
|
||||
optimizer = torch.optim.SGD(
|
||||
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
|
||||
)
|
||||
elif breed == "adagrad":
|
||||
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
|
||||
elif breed == "adam":
|
||||
optimizer = torch.optim.Adam(
|
||||
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
|
||||
)
|
||||
else:
|
||||
raise ValueError("no such solver type %s" % breed)
|
||||
logger.info(" -> solver type = %s" % breed)
|
||||
|
||||
# Load state from checkpoint
|
||||
if optimizer_state is not None:
|
||||
logger.info(" -> setting loaded optimizer state")
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
# Initialize the learning rate scheduler
|
||||
if lr_policy == "multistep":
|
||||
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer,
|
||||
milestones=milestones,
|
||||
gamma=gamma,
|
||||
)
|
||||
else:
|
||||
raise ValueError("no such lr policy %s" % lr_policy)
|
||||
|
||||
# When loading from checkpoint, this will make sure that the
|
||||
# lr is correctly set even after returning
|
||||
for _ in range(last_epoch):
|
||||
scheduler.step()
|
||||
|
||||
# Add the max epochs here
|
||||
scheduler.max_epochs = max_epochs
|
||||
|
||||
optimizer.zero_grad()
|
||||
return optimizer, scheduler
|
||||
|
||||
|
||||
enable_get_default_args(init_optimizer)
|
||||
|
||||
|
||||
def trainvalidate(
|
||||
model,
|
||||
stats,
|
||||
epoch,
|
||||
loader,
|
||||
optimizer,
|
||||
validation,
|
||||
validation: bool,
|
||||
*,
|
||||
accelerator: Optional[Accelerator],
|
||||
device: torch.device,
|
||||
bp_var: str = "objective",
|
||||
metric_print_interval: int = 5,
|
||||
visualize_interval: int = 100,
|
||||
visdom_env_root: str = "trainvalidate",
|
||||
clip_grad: float = 0.0,
|
||||
device: str = "cuda:0",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -368,11 +286,11 @@ def trainvalidate(
|
||||
|
||||
# Iterate through the batches
|
||||
n_batches = len(loader)
|
||||
for it, batch in enumerate(loader):
|
||||
for it, net_input in enumerate(loader):
|
||||
last_iter = it == n_batches - 1
|
||||
|
||||
# move to gpu where possible (in place)
|
||||
net_input = batch.to(device)
|
||||
net_input = net_input.to(device)
|
||||
|
||||
# run the forward pass
|
||||
if not validation:
|
||||
@@ -398,7 +316,11 @@ def trainvalidate(
|
||||
stats.print(stat_set=trainmode, max_it=n_batches)
|
||||
|
||||
# visualize results
|
||||
if visualize_interval > 0 and it % visualize_interval == 0:
|
||||
if (
|
||||
(accelerator is None or accelerator.is_local_main_process)
|
||||
and visualize_interval > 0
|
||||
and it % visualize_interval == 0
|
||||
):
|
||||
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
|
||||
|
||||
model.visualize(
|
||||
@@ -413,7 +335,10 @@ def trainvalidate(
|
||||
loss = preds[bp_var]
|
||||
assert torch.isfinite(loss).all(), "Non-finite loss!"
|
||||
# backprop
|
||||
loss.backward()
|
||||
if accelerator is None:
|
||||
loss.backward()
|
||||
else:
|
||||
accelerator.backward(loss)
|
||||
if clip_grad > 0.0:
|
||||
# Optionally clip the gradient norms.
|
||||
total_norm = torch.nn.utils.clip_grad_norm(
|
||||
@@ -422,18 +347,29 @@ def trainvalidate(
|
||||
if total_norm > clip_grad:
|
||||
logger.info(
|
||||
f"Clipping gradient: {total_norm}"
|
||||
+ f" with coef {clip_grad / total_norm}."
|
||||
+ f" with coef {clip_grad / float(total_norm)}."
|
||||
)
|
||||
|
||||
optimizer.step()
|
||||
|
||||
|
||||
def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
def run_training(cfg: DictConfig) -> None:
|
||||
"""
|
||||
Entry point to run the training and validation loops
|
||||
based on the specified config file.
|
||||
"""
|
||||
|
||||
# Initialize the accelerator
|
||||
if no_accelerate:
|
||||
accelerator = None
|
||||
device = torch.device("cuda:0")
|
||||
else:
|
||||
accelerator = Accelerator(device_placement=False)
|
||||
logger.info(accelerator.state)
|
||||
device = accelerator.device
|
||||
|
||||
logger.info(f"Running experiment on device: {device}")
|
||||
|
||||
# set the debug mode
|
||||
if cfg.detect_anomaly:
|
||||
logger.info("Anomaly detection!")
|
||||
@@ -452,12 +388,12 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
warnings.warn("Cant dump config due to insufficient permissions!")
|
||||
|
||||
# setup datasets
|
||||
datasets = dataset_zoo(**cfg.dataset_args)
|
||||
cfg.dataloader_args["dataset_name"] = cfg.dataset_args["dataset_name"]
|
||||
dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)
|
||||
datasource = ImplicitronDataSource(**cfg.data_source_args)
|
||||
datasets, dataloaders = datasource.get_datasets_and_dataloaders()
|
||||
task = datasource.get_task()
|
||||
|
||||
# init the model
|
||||
model, stats, optimizer_state = init_model(cfg)
|
||||
model, stats, optimizer_state = init_model(cfg=cfg, accelerator=accelerator)
|
||||
start_epoch = stats.epoch + 1
|
||||
|
||||
# move model to gpu
|
||||
@@ -465,7 +401,16 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
|
||||
# only run evaluation on the test dataloader
|
||||
if cfg.eval_only:
|
||||
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
|
||||
_eval_and_dump(
|
||||
cfg,
|
||||
task,
|
||||
datasource.all_train_cameras,
|
||||
datasets,
|
||||
dataloaders,
|
||||
model,
|
||||
stats,
|
||||
device=device,
|
||||
)
|
||||
return
|
||||
|
||||
# init the optimizer
|
||||
@@ -480,6 +425,19 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
assert scheduler.last_epoch == stats.epoch + 1
|
||||
assert scheduler.last_epoch == start_epoch
|
||||
|
||||
# Wrap all modules in the distributed library
|
||||
# Note: we don't pass the scheduler to prepare as it
|
||||
# doesn't need to be stepped at each optimizer step
|
||||
train_loader = dataloaders.train
|
||||
val_loader = dataloaders.val
|
||||
if accelerator is not None:
|
||||
(
|
||||
model,
|
||||
optimizer,
|
||||
train_loader,
|
||||
val_loader,
|
||||
) = accelerator.prepare(model, optimizer, train_loader, val_loader)
|
||||
|
||||
past_scheduler_lrs = []
|
||||
# loop through epochs
|
||||
for epoch in range(start_epoch, cfg.solver_args.max_epochs):
|
||||
@@ -499,46 +457,62 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
model,
|
||||
stats,
|
||||
epoch,
|
||||
dataloaders["train"],
|
||||
train_loader,
|
||||
optimizer,
|
||||
False,
|
||||
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
||||
device=device,
|
||||
accelerator=accelerator,
|
||||
**cfg,
|
||||
)
|
||||
|
||||
# val loop (optional)
|
||||
if "val" in dataloaders and epoch % cfg.validation_interval == 0:
|
||||
if val_loader is not None and epoch % cfg.validation_interval == 0:
|
||||
trainvalidate(
|
||||
model,
|
||||
stats,
|
||||
epoch,
|
||||
dataloaders["val"],
|
||||
val_loader,
|
||||
optimizer,
|
||||
True,
|
||||
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
||||
device=device,
|
||||
accelerator=accelerator,
|
||||
**cfg,
|
||||
)
|
||||
|
||||
# eval loop (optional)
|
||||
if (
|
||||
"test" in dataloaders
|
||||
dataloaders.test is not None
|
||||
and cfg.test_interval > 0
|
||||
and epoch % cfg.test_interval == 0
|
||||
):
|
||||
run_eval(cfg, model, stats, dataloaders["test"], device=device)
|
||||
_run_eval(
|
||||
model,
|
||||
datasource.all_train_cameras,
|
||||
dataloaders.test,
|
||||
task,
|
||||
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
|
||||
device=device,
|
||||
)
|
||||
|
||||
assert stats.epoch == epoch, "inconsistent stats!"
|
||||
|
||||
# delete previous models if required
|
||||
# save model
|
||||
if cfg.store_checkpoints:
|
||||
# save model only on the main process
|
||||
if cfg.store_checkpoints and (
|
||||
accelerator is None or accelerator.is_local_main_process
|
||||
):
|
||||
if cfg.store_checkpoints_purge > 0:
|
||||
for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
|
||||
model_io.purge_epoch(cfg.exp_dir, prev_epoch)
|
||||
outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
|
||||
model_io.safe_save_model(model, stats, outfile, optimizer=optimizer)
|
||||
unwrapped_model = (
|
||||
model if accelerator is None else accelerator.unwrap_model(model)
|
||||
)
|
||||
model_io.safe_save_model(
|
||||
unwrapped_model, stats, outfile, optimizer=optimizer
|
||||
)
|
||||
|
||||
scheduler.step()
|
||||
|
||||
@@ -547,26 +521,45 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
logger.info(f"LR change! {cur_lr} -> {new_lr}")
|
||||
|
||||
if cfg.test_when_finished:
|
||||
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
|
||||
_eval_and_dump(
|
||||
cfg,
|
||||
task,
|
||||
datasource.all_train_cameras,
|
||||
datasets,
|
||||
dataloaders,
|
||||
model,
|
||||
stats,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
def _eval_and_dump(cfg, datasets, dataloaders, model, stats, device):
|
||||
def _eval_and_dump(
|
||||
cfg,
|
||||
task: Task,
|
||||
all_train_cameras: Optional[CamerasBase],
|
||||
datasets: DatasetMap,
|
||||
dataloaders: DataLoaderMap,
|
||||
model,
|
||||
stats,
|
||||
device,
|
||||
) -> None:
|
||||
"""
|
||||
Run the evaluation loop with the test data loader and
|
||||
save the predictions to the `exp_dir`.
|
||||
"""
|
||||
|
||||
if "test" not in dataloaders:
|
||||
raise ValueError('Dataloaders have to contain the "test" entry for eval!')
|
||||
dataloader = dataloaders.test
|
||||
|
||||
eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]
|
||||
all_source_cameras = (
|
||||
_get_all_source_cameras(datasets["train"])
|
||||
if eval_task == "singlesequence"
|
||||
else None
|
||||
)
|
||||
results = run_eval(
|
||||
cfg, model, all_source_cameras, dataloaders["test"], eval_task, device=device
|
||||
if dataloader is None:
|
||||
raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
|
||||
|
||||
results = _run_eval(
|
||||
model,
|
||||
all_train_cameras,
|
||||
dataloader,
|
||||
task,
|
||||
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# add the evaluation epoch to the results
|
||||
@@ -594,7 +587,14 @@ def _get_eval_frame_data(frame_data):
|
||||
return frame_data_for_eval
|
||||
|
||||
|
||||
def run_eval(cfg, model, all_source_cameras, loader, task, device):
|
||||
def _run_eval(
|
||||
model,
|
||||
all_train_cameras,
|
||||
loader,
|
||||
task: Task,
|
||||
camera_difficulty_bin_breaks: Tuple[float, float],
|
||||
device,
|
||||
):
|
||||
"""
|
||||
Run the evaluation loop on the test dataloader
|
||||
"""
|
||||
@@ -615,104 +615,91 @@ def run_eval(cfg, model, all_source_cameras, loader, task, device):
|
||||
preds = model(
|
||||
**{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
|
||||
)
|
||||
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
|
||||
|
||||
# TODO: Cannot use accelerate gather for two reasons:.
|
||||
# (1) TypeError: Can't apply _gpu_gather_one on object of type
|
||||
# <class 'pytorch3d.implicitron.models.base_model.ImplicitronRender'>,
|
||||
# only of nested list/tuple/dicts of objects that satisfy is_torch_tensor.
|
||||
# (2) Same error above but for frame_data which contains Cameras.
|
||||
|
||||
implicitron_render = copy.deepcopy(preds["implicitron_render"])
|
||||
|
||||
per_batch_eval_results.append(
|
||||
evaluate.eval_batch(
|
||||
frame_data,
|
||||
nvs_prediction,
|
||||
implicitron_render,
|
||||
bg_color="black",
|
||||
lpips_model=lpips_model,
|
||||
source_cameras=all_source_cameras,
|
||||
source_cameras=all_train_cameras,
|
||||
)
|
||||
)
|
||||
|
||||
_, category_result = evaluate.summarize_nvs_eval_results(
|
||||
per_batch_eval_results, task
|
||||
per_batch_eval_results, task, camera_difficulty_bin_breaks
|
||||
)
|
||||
|
||||
return category_result["results"]
|
||||
|
||||
|
||||
def _get_all_source_cameras(
|
||||
dataset: ImplicitronDataset,
|
||||
num_workers: int = 8,
|
||||
) -> CamerasBase:
|
||||
"""
|
||||
Load and return all the source cameras in the training dataset
|
||||
"""
|
||||
|
||||
all_frame_data = next(
|
||||
iter(
|
||||
torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
shuffle=False,
|
||||
batch_size=len(dataset),
|
||||
num_workers=num_workers,
|
||||
collate_fn=FrameData.collate,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
is_source = ds_utils.is_known_frame(all_frame_data.frame_type)
|
||||
source_cameras = all_frame_data.camera[torch.where(is_source)[0]]
|
||||
return source_cameras
|
||||
|
||||
|
||||
def _seed_all_random_engines(seed: int):
|
||||
def _seed_all_random_engines(seed: int) -> None:
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ExperimentConfig:
|
||||
generic_model_args: DictConfig = get_default_args_field(GenericModel)
|
||||
solver_args: DictConfig = get_default_args_field(init_optimizer)
|
||||
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
|
||||
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
|
||||
architecture: str = "generic"
|
||||
detect_anomaly: bool = False
|
||||
eval_only: bool = False
|
||||
exp_dir: str = "./data/default_experiment/"
|
||||
exp_idx: int = 0
|
||||
gpu_idx: int = 0
|
||||
metric_print_interval: int = 5
|
||||
resume: bool = True
|
||||
resume_epoch: int = -1
|
||||
seed: int = 0
|
||||
store_checkpoints: bool = True
|
||||
store_checkpoints_purge: int = 1
|
||||
test_interval: int = -1
|
||||
test_when_finished: bool = False
|
||||
validation_interval: int = 1
|
||||
visdom_env: str = ""
|
||||
visdom_port: int = 8097
|
||||
visdom_server: str = "http://127.0.0.1"
|
||||
visualize_interval: int = 1000
|
||||
clip_grad: float = 0.0
|
||||
def _setup_envvars_for_cluster() -> bool:
|
||||
"""
|
||||
Prepares to run on cluster if relevant.
|
||||
Returns whether FAIR cluster in use.
|
||||
"""
|
||||
# TODO: How much of this is needed in general?
|
||||
|
||||
hydra: dict = field(
|
||||
default_factory=lambda: {
|
||||
"run": {"dir": "."}, # Make hydra not change the working dir.
|
||||
"output_subdir": None, # disable storing the .hydra logs
|
||||
}
|
||||
try:
|
||||
import submitit
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Only needed when launching on cluster with slurm and submitit
|
||||
job_env = submitit.JobEnvironment()
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
os.environ["LOCAL_RANK"] = str(job_env.local_rank)
|
||||
os.environ["RANK"] = str(job_env.global_rank)
|
||||
os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "42918"
|
||||
logger.info(
|
||||
"Num tasks %s, global_rank %s"
|
||||
% (str(job_env.num_tasks), str(job_env.global_rank))
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
expand_args_fields(ExperimentConfig)
|
||||
cs = hydra.core.config_store.ConfigStore.instance()
|
||||
cs.store(name="default_config", node=ExperimentConfig)
|
||||
|
||||
|
||||
@hydra.main(config_path="./configs/", config_name="default_config")
|
||||
def experiment(cfg: DictConfig) -> None:
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
||||
# Set the device
|
||||
device = "cpu"
|
||||
if torch.cuda.is_available() and cfg.gpu_idx < torch.cuda.device_count():
|
||||
device = f"cuda:{cfg.gpu_idx}"
|
||||
logger.info(f"Running experiment on device: {device}")
|
||||
run_training(cfg, device)
|
||||
# CUDA_VISIBLE_DEVICES must have been set.
|
||||
|
||||
if "CUDA_DEVICE_ORDER" not in os.environ:
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
|
||||
if not _setup_envvars_for_cluster():
|
||||
logger.info("Running locally")
|
||||
|
||||
# TODO: The following may be needed for hydra/submitit it to work
|
||||
expand_args_fields(GenericModel)
|
||||
expand_args_fields(AdaptiveRaySampler)
|
||||
expand_args_fields(MultiPassEmissionAbsorptionRenderer)
|
||||
expand_args_fields(ImplicitronDataSource)
|
||||
|
||||
run_training(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
49
projects/implicitron_trainer/impl/experiment_config.py
Normal file
49
projects/implicitron_trainer/impl/experiment_config.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from omegaconf import DictConfig
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||
from pytorch3d.implicitron.tools.config import Configurable, get_default_args_field
|
||||
|
||||
from .optimization import init_optimizer
|
||||
|
||||
|
||||
class ExperimentConfig(Configurable):
|
||||
generic_model_args: DictConfig = get_default_args_field(GenericModel)
|
||||
solver_args: DictConfig = get_default_args_field(init_optimizer)
|
||||
data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
|
||||
architecture: str = "generic"
|
||||
detect_anomaly: bool = False
|
||||
eval_only: bool = False
|
||||
exp_dir: str = "./data/default_experiment/"
|
||||
exp_idx: int = 0
|
||||
gpu_idx: int = 0
|
||||
metric_print_interval: int = 5
|
||||
resume: bool = True
|
||||
resume_epoch: int = -1
|
||||
seed: int = 0
|
||||
store_checkpoints: bool = True
|
||||
store_checkpoints_purge: int = 1
|
||||
test_interval: int = -1
|
||||
test_when_finished: bool = False
|
||||
validation_interval: int = 1
|
||||
visdom_env: str = ""
|
||||
visdom_port: int = 8097
|
||||
visdom_server: str = "http://127.0.0.1"
|
||||
visualize_interval: int = 1000
|
||||
clip_grad: float = 0.0
|
||||
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
|
||||
|
||||
hydra: Dict[str, Any] = field(
|
||||
default_factory=lambda: {
|
||||
"run": {"dir": "."}, # Make hydra not change the working dir.
|
||||
"output_subdir": None, # disable storing the .hydra logs
|
||||
}
|
||||
)
|
||||
109
projects/implicitron_trainer/impl/optimization.py
Normal file
109
projects/implicitron_trainer/impl/optimization.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def init_optimizer(
|
||||
model: GenericModel,
|
||||
optimizer_state: Optional[Dict[str, Any]],
|
||||
last_epoch: int,
|
||||
breed: str = "adam",
|
||||
weight_decay: float = 0.0,
|
||||
lr_policy: str = "multistep",
|
||||
lr: float = 0.0005,
|
||||
gamma: float = 0.1,
|
||||
momentum: float = 0.9,
|
||||
betas: Tuple[float, ...] = (0.9, 0.999),
|
||||
milestones: Tuple[int, ...] = (),
|
||||
max_epochs: int = 1000,
|
||||
):
|
||||
"""
|
||||
Initialize the optimizer (optionally from checkpoint state)
|
||||
and the learning rate scheduler.
|
||||
|
||||
Args:
|
||||
model: The model with optionally loaded weights
|
||||
optimizer_state: The state dict for the optimizer. If None
|
||||
it has not been loaded from checkpoint
|
||||
last_epoch: If the model was loaded from checkpoint this will be the
|
||||
number of the last epoch that was saved
|
||||
breed: The type of optimizer to use e.g. adam
|
||||
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
|
||||
lr_policy: The policy to use for learning rate. Currently, only "multistep:
|
||||
is supported.
|
||||
lr: The value for the initial learning rate
|
||||
gamma: Multiplicative factor of learning rate decay
|
||||
momentum: Momentum factor for SGD optimizer
|
||||
betas: Coefficients used for computing running averages of gradient and its square
|
||||
in the Adam optimizer
|
||||
milestones: List of increasing epoch indices at which the learning rate is
|
||||
modified
|
||||
max_epochs: The maximum number of epochs to run the optimizer for
|
||||
|
||||
Returns:
|
||||
optimizer: Optimizer module, optionally loaded from checkpoint
|
||||
scheduler: Learning rate scheduler module
|
||||
|
||||
Raise:
|
||||
ValueError if `breed` or `lr_policy` are not supported.
|
||||
"""
|
||||
|
||||
# Get the parameters to optimize
|
||||
if hasattr(model, "_get_param_groups"): # use the model function
|
||||
# pyre-ignore[29]
|
||||
p_groups = model._get_param_groups(lr, wd=weight_decay)
|
||||
else:
|
||||
allprm = [prm for prm in model.parameters() if prm.requires_grad]
|
||||
p_groups = [{"params": allprm, "lr": lr}]
|
||||
|
||||
# Intialize the optimizer
|
||||
if breed == "sgd":
|
||||
optimizer = torch.optim.SGD(
|
||||
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
|
||||
)
|
||||
elif breed == "adagrad":
|
||||
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
|
||||
elif breed == "adam":
|
||||
optimizer = torch.optim.Adam(
|
||||
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
|
||||
)
|
||||
else:
|
||||
raise ValueError("no such solver type %s" % breed)
|
||||
logger.info(" -> solver type = %s" % breed)
|
||||
|
||||
# Load state from checkpoint
|
||||
if optimizer_state is not None:
|
||||
logger.info(" -> setting loaded optimizer state")
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
# Initialize the learning rate scheduler
|
||||
if lr_policy == "multistep":
|
||||
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer,
|
||||
milestones=milestones,
|
||||
gamma=gamma,
|
||||
)
|
||||
else:
|
||||
raise ValueError("no such lr policy %s" % lr_policy)
|
||||
|
||||
# When loading from checkpoint, this will make sure that the
|
||||
# lr is correctly set even after returning
|
||||
for _ in range(last_epoch):
|
||||
scheduler.step()
|
||||
|
||||
optimizer.zero_grad()
|
||||
return optimizer, scheduler
|
||||
|
||||
|
||||
enable_get_default_args(init_optimizer)
|
||||
5
projects/implicitron_trainer/tests/__init__.py
Normal file
5
projects/implicitron_trainer/tests/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
425
projects/implicitron_trainer/tests/experiment.yaml
Normal file
425
projects/implicitron_trainer/tests/experiment.yaml
Normal file
@@ -0,0 +1,425 @@
|
||||
generic_model_args:
|
||||
mask_images: true
|
||||
mask_depths: true
|
||||
render_image_width: 400
|
||||
render_image_height: 400
|
||||
mask_threshold: 0.5
|
||||
output_rasterized_mc: false
|
||||
bg_color:
|
||||
- 0.0
|
||||
- 0.0
|
||||
- 0.0
|
||||
num_passes: 1
|
||||
chunk_size_grid: 4096
|
||||
render_features_dimensions: 3
|
||||
tqdm_trigger_threshold: 16
|
||||
n_train_target_views: 1
|
||||
sampling_mode_training: mask_sample
|
||||
sampling_mode_evaluation: full_grid
|
||||
global_encoder_class_type: null
|
||||
raysampler_class_type: AdaptiveRaySampler
|
||||
renderer_class_type: MultiPassEmissionAbsorptionRenderer
|
||||
image_feature_extractor_class_type: null
|
||||
view_pooler_enabled: false
|
||||
implicit_function_class_type: NeuralRadianceFieldImplicitFunction
|
||||
view_metrics_class_type: ViewMetrics
|
||||
regularization_metrics_class_type: RegularizationMetrics
|
||||
loss_weights:
|
||||
loss_rgb_mse: 1.0
|
||||
loss_prev_stage_rgb_mse: 1.0
|
||||
loss_mask_bce: 0.0
|
||||
loss_prev_stage_mask_bce: 0.0
|
||||
log_vars:
|
||||
- loss_rgb_psnr_fg
|
||||
- loss_rgb_psnr
|
||||
- loss_rgb_mse
|
||||
- loss_rgb_huber
|
||||
- loss_depth_abs
|
||||
- loss_depth_abs_fg
|
||||
- loss_mask_neg_iou
|
||||
- loss_mask_bce
|
||||
- loss_mask_beta_prior
|
||||
- loss_eikonal
|
||||
- loss_density_tv
|
||||
- loss_depth_neg_penalty
|
||||
- loss_autodecoder_norm
|
||||
- loss_prev_stage_rgb_mse
|
||||
- loss_prev_stage_rgb_psnr_fg
|
||||
- loss_prev_stage_rgb_psnr
|
||||
- loss_prev_stage_mask_bce
|
||||
- objective
|
||||
- epoch
|
||||
- sec/it
|
||||
global_encoder_HarmonicTimeEncoder_args:
|
||||
n_harmonic_functions: 10
|
||||
append_input: true
|
||||
time_divisor: 1.0
|
||||
global_encoder_SequenceAutodecoder_args:
|
||||
autodecoder_args:
|
||||
encoding_dim: 0
|
||||
n_instances: 0
|
||||
init_scale: 1.0
|
||||
ignore_input: false
|
||||
raysampler_AdaptiveRaySampler_args:
|
||||
image_width: 400
|
||||
image_height: 400
|
||||
sampling_mode_training: mask_sample
|
||||
sampling_mode_evaluation: full_grid
|
||||
n_pts_per_ray_training: 64
|
||||
n_pts_per_ray_evaluation: 64
|
||||
n_rays_per_image_sampled_from_mask: 1024
|
||||
stratified_point_sampling_training: true
|
||||
stratified_point_sampling_evaluation: false
|
||||
scene_extent: 8.0
|
||||
scene_center:
|
||||
- 0.0
|
||||
- 0.0
|
||||
- 0.0
|
||||
raysampler_NearFarRaySampler_args:
|
||||
image_width: 400
|
||||
image_height: 400
|
||||
sampling_mode_training: mask_sample
|
||||
sampling_mode_evaluation: full_grid
|
||||
n_pts_per_ray_training: 64
|
||||
n_pts_per_ray_evaluation: 64
|
||||
n_rays_per_image_sampled_from_mask: 1024
|
||||
stratified_point_sampling_training: true
|
||||
stratified_point_sampling_evaluation: false
|
||||
min_depth: 0.1
|
||||
max_depth: 8.0
|
||||
renderer_LSTMRenderer_args:
|
||||
num_raymarch_steps: 10
|
||||
init_depth: 17.0
|
||||
init_depth_noise_std: 0.0005
|
||||
hidden_size: 16
|
||||
n_feature_channels: 256
|
||||
bg_color: null
|
||||
verbose: false
|
||||
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
||||
raymarcher_class_type: EmissionAbsorptionRaymarcher
|
||||
n_pts_per_ray_fine_training: 64
|
||||
n_pts_per_ray_fine_evaluation: 64
|
||||
stratified_sampling_coarse_training: true
|
||||
stratified_sampling_coarse_evaluation: false
|
||||
append_coarse_samples_to_fine: true
|
||||
density_noise_std_train: 0.0
|
||||
return_weights: false
|
||||
raymarcher_CumsumRaymarcher_args:
|
||||
surface_thickness: 1
|
||||
bg_color:
|
||||
- 0.0
|
||||
background_opacity: 0.0
|
||||
density_relu: true
|
||||
blend_output: false
|
||||
raymarcher_EmissionAbsorptionRaymarcher_args:
|
||||
surface_thickness: 1
|
||||
bg_color:
|
||||
- 0.0
|
||||
background_opacity: 10000000000.0
|
||||
density_relu: true
|
||||
blend_output: false
|
||||
renderer_SignedDistanceFunctionRenderer_args:
|
||||
render_features_dimensions: 3
|
||||
ray_tracer_args:
|
||||
object_bounding_sphere: 1.0
|
||||
sdf_threshold: 5.0e-05
|
||||
line_search_step: 0.5
|
||||
line_step_iters: 1
|
||||
sphere_tracing_iters: 10
|
||||
n_steps: 100
|
||||
n_secant_steps: 8
|
||||
ray_normal_coloring_network_args:
|
||||
feature_vector_size: 3
|
||||
mode: idr
|
||||
d_in: 9
|
||||
d_out: 3
|
||||
dims:
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
weight_norm: true
|
||||
n_harmonic_functions_dir: 0
|
||||
pooled_feature_dim: 0
|
||||
bg_color:
|
||||
- 0.0
|
||||
soft_mask_alpha: 50.0
|
||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||
name: resnet34
|
||||
pretrained: true
|
||||
stages:
|
||||
- 1
|
||||
- 2
|
||||
- 3
|
||||
- 4
|
||||
normalize_image: true
|
||||
image_rescale: 0.16
|
||||
first_max_pool: true
|
||||
proj_dim: 32
|
||||
l2_norm: true
|
||||
add_masks: true
|
||||
add_images: true
|
||||
global_average_pool: false
|
||||
feature_rescale: 1.0
|
||||
view_pooler_args:
|
||||
feature_aggregator_class_type: AngleWeightedReductionFeatureAggregator
|
||||
view_sampler_args:
|
||||
masked_sampling: false
|
||||
sampling_mode: bilinear
|
||||
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
|
||||
exclude_target_view: true
|
||||
exclude_target_view_mask_features: true
|
||||
concatenate_output: true
|
||||
weight_by_ray_angle_gamma: 1.0
|
||||
min_ray_angle_weight: 0.1
|
||||
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
|
||||
exclude_target_view: true
|
||||
exclude_target_view_mask_features: true
|
||||
concatenate_output: true
|
||||
reduction_functions:
|
||||
- AVG
|
||||
- STD
|
||||
weight_by_ray_angle_gamma: 1.0
|
||||
min_ray_angle_weight: 0.1
|
||||
feature_aggregator_IdentityFeatureAggregator_args:
|
||||
exclude_target_view: true
|
||||
exclude_target_view_mask_features: true
|
||||
concatenate_output: true
|
||||
feature_aggregator_ReductionFeatureAggregator_args:
|
||||
exclude_target_view: true
|
||||
exclude_target_view_mask_features: true
|
||||
concatenate_output: true
|
||||
reduction_functions:
|
||||
- AVG
|
||||
- STD
|
||||
implicit_function_IdrFeatureField_args:
|
||||
feature_vector_size: 3
|
||||
d_in: 3
|
||||
d_out: 1
|
||||
dims:
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
geometric_init: true
|
||||
bias: 1.0
|
||||
skip_in: []
|
||||
weight_norm: true
|
||||
n_harmonic_functions_xyz: 0
|
||||
pooled_feature_dim: 0
|
||||
encoding_dim: 0
|
||||
implicit_function_NeRFormerImplicitFunction_args:
|
||||
n_harmonic_functions_xyz: 10
|
||||
n_harmonic_functions_dir: 4
|
||||
n_hidden_neurons_dir: 128
|
||||
latent_dim: 0
|
||||
input_xyz: true
|
||||
xyz_ray_dir_in_camera_coords: false
|
||||
color_dim: 3
|
||||
transformer_dim_down_factor: 2.0
|
||||
n_hidden_neurons_xyz: 80
|
||||
n_layers_xyz: 2
|
||||
append_xyz:
|
||||
- 1
|
||||
implicit_function_NeuralRadianceFieldImplicitFunction_args:
|
||||
n_harmonic_functions_xyz: 10
|
||||
n_harmonic_functions_dir: 4
|
||||
n_hidden_neurons_dir: 128
|
||||
latent_dim: 0
|
||||
input_xyz: true
|
||||
xyz_ray_dir_in_camera_coords: false
|
||||
color_dim: 3
|
||||
transformer_dim_down_factor: 1.0
|
||||
n_hidden_neurons_xyz: 256
|
||||
n_layers_xyz: 8
|
||||
append_xyz:
|
||||
- 5
|
||||
implicit_function_SRNHyperNetImplicitFunction_args:
|
||||
hypernet_args:
|
||||
n_harmonic_functions: 3
|
||||
n_hidden_units: 256
|
||||
n_layers: 2
|
||||
n_hidden_units_hypernet: 256
|
||||
n_layers_hypernet: 1
|
||||
in_features: 3
|
||||
out_features: 256
|
||||
latent_dim_hypernet: 0
|
||||
latent_dim: 0
|
||||
xyz_in_camera_coords: false
|
||||
pixel_generator_args:
|
||||
n_harmonic_functions: 4
|
||||
n_hidden_units: 256
|
||||
n_hidden_units_color: 128
|
||||
n_layers: 2
|
||||
in_features: 256
|
||||
out_features: 3
|
||||
ray_dir_in_camera_coords: false
|
||||
implicit_function_SRNImplicitFunction_args:
|
||||
raymarch_function_args:
|
||||
n_harmonic_functions: 3
|
||||
n_hidden_units: 256
|
||||
n_layers: 2
|
||||
in_features: 3
|
||||
out_features: 256
|
||||
latent_dim: 0
|
||||
xyz_in_camera_coords: false
|
||||
raymarch_function: null
|
||||
pixel_generator_args:
|
||||
n_harmonic_functions: 4
|
||||
n_hidden_units: 256
|
||||
n_hidden_units_color: 128
|
||||
n_layers: 2
|
||||
in_features: 256
|
||||
out_features: 3
|
||||
ray_dir_in_camera_coords: false
|
||||
view_metrics_ViewMetrics_args: {}
|
||||
regularization_metrics_RegularizationMetrics_args: {}
|
||||
solver_args:
|
||||
breed: adam
|
||||
weight_decay: 0.0
|
||||
lr_policy: multistep
|
||||
lr: 0.0005
|
||||
gamma: 0.1
|
||||
momentum: 0.9
|
||||
betas:
|
||||
- 0.9
|
||||
- 0.999
|
||||
milestones: []
|
||||
max_epochs: 1000
|
||||
data_source_args:
|
||||
dataset_map_provider_class_type: ???
|
||||
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
|
||||
dataset_map_provider_BlenderDatasetMapProvider_args:
|
||||
base_dir: ???
|
||||
object_name: ???
|
||||
path_manager_factory_class_type: PathManagerFactory
|
||||
n_known_frames_for_test: null
|
||||
path_manager_factory_PathManagerFactory_args:
|
||||
silence_logs: true
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args:
|
||||
category: ???
|
||||
task_str: singlesequence
|
||||
dataset_root: ''
|
||||
n_frames_per_sequence: -1
|
||||
test_on_train: false
|
||||
restrict_sequence_name: []
|
||||
test_restrict_sequence_id: -1
|
||||
assert_single_seq: false
|
||||
only_test_set: false
|
||||
dataset_class_type: JsonIndexDataset
|
||||
path_manager_factory_class_type: PathManagerFactory
|
||||
dataset_JsonIndexDataset_args:
|
||||
limit_to: 0
|
||||
limit_sequences_to: 0
|
||||
exclude_sequence: []
|
||||
limit_category_to: []
|
||||
load_images: true
|
||||
load_depths: true
|
||||
load_depth_masks: true
|
||||
load_masks: true
|
||||
load_point_clouds: false
|
||||
max_points: 0
|
||||
mask_images: false
|
||||
mask_depths: false
|
||||
image_height: 800
|
||||
image_width: 800
|
||||
box_crop: true
|
||||
box_crop_mask_thr: 0.4
|
||||
box_crop_context: 0.3
|
||||
remove_empty_masks: true
|
||||
seed: 0
|
||||
sort_frames: false
|
||||
path_manager_factory_PathManagerFactory_args:
|
||||
silence_logs: true
|
||||
dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
|
||||
category: ???
|
||||
subset_name: ???
|
||||
dataset_root: ''
|
||||
test_on_train: false
|
||||
only_test_set: false
|
||||
load_eval_batches: true
|
||||
dataset_class_type: JsonIndexDataset
|
||||
path_manager_factory_class_type: PathManagerFactory
|
||||
dataset_JsonIndexDataset_args:
|
||||
path_manager: null
|
||||
frame_annotations_file: ''
|
||||
sequence_annotations_file: ''
|
||||
subset_lists_file: ''
|
||||
subsets: null
|
||||
limit_to: 0
|
||||
limit_sequences_to: 0
|
||||
pick_sequence: []
|
||||
exclude_sequence: []
|
||||
limit_category_to: []
|
||||
dataset_root: ''
|
||||
load_images: true
|
||||
load_depths: true
|
||||
load_depth_masks: true
|
||||
load_masks: true
|
||||
load_point_clouds: false
|
||||
max_points: 0
|
||||
mask_images: false
|
||||
mask_depths: false
|
||||
image_height: 800
|
||||
image_width: 800
|
||||
box_crop: true
|
||||
box_crop_mask_thr: 0.4
|
||||
box_crop_context: 0.3
|
||||
remove_empty_masks: true
|
||||
n_frames_per_sequence: -1
|
||||
seed: 0
|
||||
sort_frames: false
|
||||
eval_batches: null
|
||||
path_manager_factory_PathManagerFactory_args:
|
||||
silence_logs: true
|
||||
dataset_map_provider_LlffDatasetMapProvider_args:
|
||||
base_dir: ???
|
||||
object_name: ???
|
||||
path_manager_factory_class_type: PathManagerFactory
|
||||
n_known_frames_for_test: null
|
||||
path_manager_factory_PathManagerFactory_args:
|
||||
silence_logs: true
|
||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
batch_size: 1
|
||||
num_workers: 0
|
||||
dataset_length_train: 0
|
||||
dataset_length_val: 0
|
||||
dataset_length_test: 0
|
||||
train_conditioning_type: SAME
|
||||
val_conditioning_type: SAME
|
||||
test_conditioning_type: KNOWN
|
||||
images_per_seq_options: []
|
||||
sample_consecutive_frames: false
|
||||
consecutive_frames_max_gap: 0
|
||||
consecutive_frames_max_gap_seconds: 0.1
|
||||
architecture: generic
|
||||
detect_anomaly: false
|
||||
eval_only: false
|
||||
exp_dir: ./data/default_experiment/
|
||||
exp_idx: 0
|
||||
gpu_idx: 0
|
||||
metric_print_interval: 5
|
||||
resume: true
|
||||
resume_epoch: -1
|
||||
seed: 0
|
||||
store_checkpoints: true
|
||||
store_checkpoints_purge: 1
|
||||
test_interval: -1
|
||||
test_when_finished: false
|
||||
validation_interval: 1
|
||||
visdom_env: ''
|
||||
visdom_port: 8097
|
||||
visdom_server: http://127.0.0.1
|
||||
visualize_interval: 1000
|
||||
clip_grad: 0.0
|
||||
camera_difficulty_bin_breaks:
|
||||
- 0.97
|
||||
- 0.98
|
||||
hydra:
|
||||
run:
|
||||
dir: .
|
||||
output_subdir: null
|
||||
91
projects/implicitron_trainer/tests/test_experiment.py
Normal file
91
projects/implicitron_trainer/tests/test_experiment.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from hydra import compose, initialize_config_dir
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from .. import experiment
|
||||
|
||||
|
||||
def interactive_testing_requested() -> bool:
|
||||
"""
|
||||
Certain tests are only useful when run interactively, and so are not regularly run.
|
||||
These are activated by this funciton returning True, which the user requests by
|
||||
setting the environment variable `PYTORCH3D_INTERACTIVE_TESTING` to 1.
|
||||
"""
|
||||
return os.environ.get("PYTORCH3D_INTERACTIVE_TESTING", "") == "1"
|
||||
|
||||
|
||||
internal = os.environ.get("FB_TEST", False)
|
||||
|
||||
|
||||
DATA_DIR = Path(__file__).resolve().parent
|
||||
IMPLICITRON_CONFIGS_DIR = Path(__file__).resolve().parent.parent / "configs"
|
||||
DEBUG: bool = False
|
||||
|
||||
# TODO:
|
||||
# - add enough files to skateboard_first_5 that this works on RE.
|
||||
# - share common code with PyTorch3D tests?
|
||||
# - deal with the temporary output files this test creates
|
||||
|
||||
|
||||
class TestExperiment(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.maxDiff = None
|
||||
|
||||
def test_from_defaults(self):
|
||||
# Test making minimal changes to the dataclass defaults.
|
||||
if not interactive_testing_requested() or not internal:
|
||||
return
|
||||
cfg = OmegaConf.structured(experiment.ExperimentConfig)
|
||||
cfg.data_source_args.dataset_map_provider_class_type = (
|
||||
"JsonIndexDatasetMapProvider"
|
||||
)
|
||||
dataset_args = (
|
||||
cfg.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
)
|
||||
dataloader_args = (
|
||||
cfg.data_source_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||
)
|
||||
dataset_args.category = "skateboard"
|
||||
dataset_args.test_restrict_sequence_id = 0
|
||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||
dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5
|
||||
dataset_args.dataset_JsonIndexDataset_args.image_height = 80
|
||||
dataset_args.dataset_JsonIndexDataset_args.image_width = 80
|
||||
dataloader_args.dataset_length_train = 1
|
||||
dataloader_args.dataset_length_val = 1
|
||||
cfg.solver_args.max_epochs = 2
|
||||
|
||||
experiment.run_training(cfg)
|
||||
|
||||
def test_yaml_contents(self):
|
||||
cfg = OmegaConf.structured(experiment.ExperimentConfig)
|
||||
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
|
||||
if DEBUG:
|
||||
(DATA_DIR / "experiment.yaml").write_text(yaml)
|
||||
self.assertEqual(yaml, (DATA_DIR / "experiment.yaml").read_text())
|
||||
|
||||
def test_load_configs(self):
|
||||
config_files = []
|
||||
|
||||
for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"):
|
||||
config_files.extend(
|
||||
[
|
||||
f
|
||||
for f in IMPLICITRON_CONFIGS_DIR.glob(pattern)
|
||||
if not f.name.endswith("_base.yaml")
|
||||
]
|
||||
)
|
||||
|
||||
for file in config_files:
|
||||
with self.subTest(file.name):
|
||||
with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
|
||||
compose(file.name)
|
||||
@@ -21,15 +21,11 @@ from typing import Optional, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as Fu
|
||||
from experiment import init_model
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
||||
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
||||
from pytorch3d.implicitron.models.base import EvaluationMode
|
||||
from pytorch3d.implicitron.models.base_model import EvaluationMode
|
||||
from pytorch3d.implicitron.tools.configurable import get_default_args
|
||||
from pytorch3d.implicitron.tools.eval_video_trajectory import (
|
||||
generate_eval_video_cameras,
|
||||
@@ -41,9 +37,11 @@ from pytorch3d.implicitron.tools.vis_utils import (
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
from .experiment import init_model
|
||||
|
||||
|
||||
def render_sequence(
|
||||
dataset: ImplicitronDataset,
|
||||
dataset: DatasetBase,
|
||||
sequence_name: str,
|
||||
model: torch.nn.Module,
|
||||
video_path,
|
||||
@@ -66,6 +64,12 @@ def render_sequence(
|
||||
):
|
||||
if seed is None:
|
||||
seed = hash(sequence_name)
|
||||
|
||||
if visdom_show_preds:
|
||||
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
|
||||
else:
|
||||
viz = None
|
||||
|
||||
print(f"Loading all data of sequence '{sequence_name}'.")
|
||||
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
|
||||
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
|
||||
@@ -84,7 +88,7 @@ def render_sequence(
|
||||
up=up,
|
||||
focal_length=None,
|
||||
principal_point=torch.zeros(n_eval_cameras, 2),
|
||||
traj_offset_canonical=[0.0, 0.0, traj_offset],
|
||||
traj_offset_canonical=(0.0, 0.0, traj_offset),
|
||||
)
|
||||
|
||||
# sample the source views reproducibly
|
||||
@@ -120,7 +124,6 @@ def render_sequence(
|
||||
if visdom_show_preds and (
|
||||
n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
|
||||
):
|
||||
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
|
||||
show_predictions(
|
||||
preds_total,
|
||||
sequence_name=batch.sequence_name[0],
|
||||
@@ -248,7 +251,7 @@ def show_predictions(
|
||||
def generate_prediction_videos(
|
||||
preds,
|
||||
sequence_name,
|
||||
viz,
|
||||
viz=None,
|
||||
viz_env="visualizer",
|
||||
predicted_keys=(
|
||||
"images_render",
|
||||
@@ -276,19 +279,20 @@ def generate_prediction_videos(
|
||||
for rendered_pred in tqdm(preds):
|
||||
for k in predicted_keys:
|
||||
vws[k].write_frame(
|
||||
rendered_pred[k][0].detach().cpu().numpy(),
|
||||
rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
|
||||
resize=resize,
|
||||
)
|
||||
|
||||
for k in predicted_keys:
|
||||
vws[k].get_video(quiet=True)
|
||||
print(f"Generated {vws[k].out_path}.")
|
||||
viz.video(
|
||||
videofile=vws[k].out_path,
|
||||
env=viz_env,
|
||||
win=k, # we reuse the same window otherwise visdom dies
|
||||
opts={"title": sequence_name + " " + k},
|
||||
)
|
||||
if viz is not None:
|
||||
viz.video(
|
||||
videofile=vws[k].out_path,
|
||||
env=viz_env,
|
||||
win=k, # we reuse the same window otherwise visdom dies
|
||||
opts={"title": sequence_name + " " + k},
|
||||
)
|
||||
|
||||
|
||||
def export_scenes(
|
||||
@@ -297,7 +301,7 @@ def export_scenes(
|
||||
output_directory: Optional[str] = None,
|
||||
render_size: Tuple[int, int] = (512, 512),
|
||||
video_size: Optional[Tuple[int, int]] = None,
|
||||
split: str = "train", # train | test
|
||||
split: str = "train", # train | val | test
|
||||
n_source_views: int = 9,
|
||||
n_eval_cameras: int = 40,
|
||||
visdom_server="http://127.0.0.1",
|
||||
@@ -325,24 +329,31 @@ def export_scenes(
|
||||
config.gpu_idx = gpu_idx
|
||||
config.exp_dir = exp_dir
|
||||
# important so that the CO3D dataset gets loaded in full
|
||||
config.dataset_args.test_on_train = False
|
||||
dataset_args = (
|
||||
config.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
)
|
||||
dataset_args.test_on_train = False
|
||||
# Set the rendering image size
|
||||
config.generic_model_args.render_image_width = render_size[0]
|
||||
config.generic_model_args.render_image_height = render_size[1]
|
||||
if restrict_sequence_name is not None:
|
||||
config.dataset_args.restrict_sequence_name = restrict_sequence_name
|
||||
dataset_args.restrict_sequence_name = restrict_sequence_name
|
||||
|
||||
# Set up the CUDA env for the visualization
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
|
||||
|
||||
# Load the previously trained model
|
||||
model, _, _ = init_model(config, force_load=True, load_model_only=True)
|
||||
model, _, _ = init_model(cfg=config, force_load=True, load_model_only=True)
|
||||
model.cuda()
|
||||
model.eval()
|
||||
|
||||
# Setup the dataset
|
||||
dataset = dataset_zoo(**config.dataset_args)[split]
|
||||
datasource = ImplicitronDataSource(**config.data_source_args)
|
||||
dataset_map = datasource.dataset_map_provider.get_dataset_map()
|
||||
dataset = dataset_map[split]
|
||||
if dataset is None:
|
||||
raise ValueError(f"{split} dataset not provided")
|
||||
|
||||
# iterate over the sequences in the dataset
|
||||
for sequence_name in dataset.sequence_names():
|
||||
|
||||
@@ -97,7 +97,7 @@ def generate_eval_video_cameras(
|
||||
cam_centers_on_plane.t() @ cam_centers_on_plane
|
||||
) / cam_centers_on_plane.shape[0]
|
||||
_, e_vec = torch.symeig(cov, eigenvectors=True)
|
||||
traj_radius = (cam_centers_on_plane ** 2).sum(dim=1).sqrt().mean()
|
||||
traj_radius = (cam_centers_on_plane**2).sum(dim=1).sqrt().mean()
|
||||
angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams)
|
||||
traj = traj_radius * torch.stack(
|
||||
(torch.zeros_like(angle), angle.cos(), angle.sin()), dim=-1
|
||||
|
||||
@@ -23,6 +23,7 @@ def solve(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: # pragma: no cover
|
||||
# PyTorch version >= 1.8.0
|
||||
return torch.linalg.solve(A, B)
|
||||
|
||||
# pyre-fixme[16]: `Tuple` has no attribute `solution`.
|
||||
return torch.solve(B, A).solution
|
||||
|
||||
|
||||
@@ -67,9 +68,14 @@ def meshgrid_ij(
|
||||
Like torch.meshgrid was before PyTorch 1.10.0, i.e. with indexing set to ij
|
||||
"""
|
||||
if (
|
||||
# pyre-fixme[16]: Callable `meshgrid` has no attribute `__kwdefaults__`.
|
||||
torch.meshgrid.__kwdefaults__ is not None
|
||||
and "indexing" in torch.meshgrid.__kwdefaults__
|
||||
):
|
||||
# PyTorch >= 1.10.0
|
||||
# pyre-fixme[6]: For 1st param expected `Union[List[Tensor], Tensor]` but
|
||||
# got `Union[Sequence[Tensor], Tensor]`.
|
||||
return torch.meshgrid(*A, indexing="ij")
|
||||
# pyre-fixme[6]: For 1st param expected `Union[List[Tensor], Tensor]` but got
|
||||
# `Union[Sequence[Tensor], Tensor]`.
|
||||
return torch.meshgrid(*A)
|
||||
|
||||
@@ -26,7 +26,7 @@ def make_device(device: Device) -> torch.device:
|
||||
A matching torch.device object
|
||||
"""
|
||||
device = torch.device(device) if isinstance(device, str) else device
|
||||
if device.type == "cuda" and device.index is None: # pyre-ignore[16]
|
||||
if device.type == "cuda" and device.index is None:
|
||||
# If cuda but with no index, then the current cuda device is indicated.
|
||||
# In that case, we fix to that device
|
||||
device = torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
@@ -71,6 +71,5 @@ elif sys.version_info >= (3, 7, 0):
|
||||
def get_args(cls): # pragma: no cover
|
||||
return getattr(cls, "__args__", None)
|
||||
|
||||
|
||||
else:
|
||||
raise ImportError("This module requires Python 3.7+")
|
||||
|
||||
@@ -75,12 +75,14 @@ class _SymEig3x3(nn.Module):
|
||||
if inputs.shape[-2:] != (3, 3):
|
||||
raise ValueError("Only inputs of shape (..., 3, 3) are supported.")
|
||||
|
||||
inputs_diag = inputs.diagonal(dim1=-2, dim2=-1) # pyre-ignore[16]
|
||||
inputs_diag = inputs.diagonal(dim1=-2, dim2=-1)
|
||||
inputs_trace = inputs_diag.sum(-1)
|
||||
q = inputs_trace / 3.0
|
||||
|
||||
# Calculate squared sum of elements outside the main diagonal / 2
|
||||
p1 = ((inputs ** 2).sum(dim=(-1, -2)) - (inputs_diag ** 2).sum(-1)) / 2
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
p1 = ((inputs**2).sum(dim=(-1, -2)) - (inputs_diag**2).sum(-1)) / 2
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
p2 = ((inputs_diag - q[..., None]) ** 2).sum(dim=-1) + 2.0 * p1.clamp(self._eps)
|
||||
|
||||
p = torch.sqrt(p2 / 6.0)
|
||||
@@ -195,8 +197,9 @@ class _SymEig3x3(nn.Module):
|
||||
cross_products[..., :1, :]
|
||||
)
|
||||
|
||||
norms_sq = (cross_products ** 2).sum(dim=-1)
|
||||
max_norms_index = norms_sq.argmax(dim=-1) # pyre-ignore[16]
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
norms_sq = (cross_products**2).sum(dim=-1)
|
||||
max_norms_index = norms_sq.argmax(dim=-1)
|
||||
|
||||
# Pick only the cross-product with highest squared norm for each input
|
||||
max_cross_products = self._gather_by_index(
|
||||
@@ -227,9 +230,7 @@ class _SymEig3x3(nn.Module):
|
||||
index_shape = list(source.shape)
|
||||
index_shape[dim] = 1
|
||||
|
||||
return source.gather(dim, index.expand(index_shape)).squeeze( # pyre-ignore[16]
|
||||
dim
|
||||
)
|
||||
return source.gather(dim, index.expand(index_shape)).squeeze(dim)
|
||||
|
||||
def _get_uv(self, w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
@@ -243,7 +244,7 @@ class _SymEig3x3(nn.Module):
|
||||
Tuple of U and V unit-length vector tensors of shape (..., 3)
|
||||
"""
|
||||
|
||||
min_idx = w.abs().argmin(dim=-1) # pyre-ignore[16]
|
||||
min_idx = w.abs().argmin(dim=-1)
|
||||
rotation_2d = self._rotations_3d[min_idx].to(w)
|
||||
|
||||
u = F.normalize((rotation_2d @ w[..., None])[..., 0], dim=-1)
|
||||
|
||||
@@ -377,7 +377,7 @@ class R2N2(ShapeNetBase): # pragma: no cover
|
||||
view_idxs: Optional[List[int]] = None,
|
||||
shader_type=HardPhongShader,
|
||||
device: Device = "cpu",
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Render models with BlenderCamera by default to achieve the same orientations as the
|
||||
|
||||
@@ -140,7 +140,6 @@ def compute_extrinsic_matrix(
|
||||
# rotates the model 90 degrees about the x axis. To compensate for this quirk we
|
||||
# roll that rotation into the extrinsic matrix here
|
||||
rot = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
|
||||
# pyre-fixme[16]: `Tensor` has no attribute `mm`.
|
||||
RT = RT.mm(rot.to(RT))
|
||||
|
||||
return RT
|
||||
@@ -180,6 +179,7 @@ def read_binvox_coords(
|
||||
size, translation, scale = _read_binvox_header(f)
|
||||
storage = torch.ByteStorage.from_buffer(f.read())
|
||||
data = torch.tensor([], dtype=torch.uint8)
|
||||
# pyre-fixme[28]: Unexpected keyword argument `source`.
|
||||
data.set_(source=storage)
|
||||
vals, counts = data[::2], data[1::2]
|
||||
idxs = _compute_idxs(vals, counts)
|
||||
@@ -276,7 +276,7 @@ def _read_binvox_header(f): # pragma: no cover
|
||||
try:
|
||||
dims = [int(d) for d in dims[1:]]
|
||||
except ValueError:
|
||||
raise ValueError("Invalid header (line 2)")
|
||||
raise ValueError("Invalid header (line 2)") from None
|
||||
if len(dims) != 3 or dims[0] != dims[1] or dims[0] != dims[2]:
|
||||
raise ValueError("Invalid header (line 2)")
|
||||
size = dims[0]
|
||||
@@ -291,7 +291,7 @@ def _read_binvox_header(f): # pragma: no cover
|
||||
try:
|
||||
translation = tuple(float(t) for t in translation[1:])
|
||||
except ValueError:
|
||||
raise ValueError("Invalid header (line 3)")
|
||||
raise ValueError("Invalid header (line 3)") from None
|
||||
|
||||
# Fourth line of the header should be "scale [float]"
|
||||
line = f.readline().strip()
|
||||
|
||||
@@ -113,7 +113,7 @@ class ShapeNetBase(torch.utils.data.Dataset): # pragma: no cover
|
||||
idxs: Optional[List[int]] = None,
|
||||
shader_type=HardPhongShader,
|
||||
device: Device = "cpu",
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
If a list of model_ids are supplied, render all the objects by the given model_ids.
|
||||
@@ -227,6 +227,8 @@ class ShapeNetBase(torch.utils.data.Dataset): # pragma: no cover
|
||||
sampled_idxs = self._sample_idxs_from_category(
|
||||
sample_num=sample_num, category=category
|
||||
)
|
||||
# pyre-fixme[6]: For 1st param expected `Union[List[Tensor],
|
||||
# typing.Tuple[Tensor, ...]]` but got `Tuple[Tensor, List[int]]`.
|
||||
idxs_tensor = torch.cat((idxs_tensor, sampled_idxs))
|
||||
idxs = idxs_tensor.tolist()
|
||||
# Check if the indices are valid if idxs are supplied.
|
||||
@@ -283,4 +285,5 @@ class ShapeNetBase(torch.utils.data.Dataset): # pragma: no cover
|
||||
"category " + category if category is not None else "all categories",
|
||||
)
|
||||
warnings.warn(msg)
|
||||
# pyre-fixme[7]: Expected `List[int]` but got `Tensor`.
|
||||
return sampled_idxs
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import registry
|
||||
|
||||
from .load_blender import load_blender_data
|
||||
from .single_sequence_dataset import (
|
||||
_interpret_blender_cameras,
|
||||
SingleSceneDatasetMapProviderBase,
|
||||
)
|
||||
|
||||
|
||||
@registry.register
|
||||
class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
|
||||
"""
|
||||
Provides data for one scene from Blender synthetic dataset.
|
||||
Uses the code in load_blender.py
|
||||
|
||||
Members:
|
||||
base_dir: directory holding the data for the scene.
|
||||
object_name: The name of the scene (e.g. "lego"). This is just used as a label.
|
||||
It will typically be equal to the name of the directory self.base_dir.
|
||||
path_manager_factory: Creates path manager which may be used for
|
||||
interpreting paths.
|
||||
n_known_frames_for_test: If set, training frames are included in the val
|
||||
and test datasets, and this many random training frames are added to
|
||||
each test batch. If not set, test batches each contain just a single
|
||||
testing frame.
|
||||
"""
|
||||
|
||||
def _load_data(self) -> None:
|
||||
path_manager = self.path_manager_factory.get()
|
||||
images, poses, _, hwf, i_split = load_blender_data(
|
||||
self.base_dir,
|
||||
testskip=1,
|
||||
path_manager=path_manager,
|
||||
)
|
||||
H, W, focal = hwf
|
||||
images_masks = torch.from_numpy(images).permute(0, 3, 1, 2)
|
||||
|
||||
# pyre-ignore[16]
|
||||
self.poses = _interpret_blender_cameras(poses, focal)
|
||||
# pyre-ignore[16]
|
||||
self.images = images_masks[:, :3]
|
||||
# pyre-ignore[16]
|
||||
self.fg_probabilities = images_masks[:, 3:4]
|
||||
# pyre-ignore[16]
|
||||
self.i_split = i_split
|
||||
438
pytorch3d/implicitron/dataset/data_loader_map_provider.py
Normal file
438
pytorch3d/implicitron/dataset/data_loader_map_provider.py
Normal file
@@ -0,0 +1,438 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from torch.utils.data import (
|
||||
BatchSampler,
|
||||
ChainDataset,
|
||||
DataLoader,
|
||||
RandomSampler,
|
||||
Sampler,
|
||||
)
|
||||
|
||||
from .dataset_base import DatasetBase, FrameData
|
||||
from .dataset_map_provider import DatasetMap
|
||||
from .scene_batch_sampler import SceneBatchSampler
|
||||
from .utils import is_known_frame_scalar
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataLoaderMap:
|
||||
"""
|
||||
A collection of data loaders for Implicitron.
|
||||
|
||||
Members:
|
||||
|
||||
train: a data loader for training
|
||||
val: a data loader for validating during training
|
||||
test: a data loader for final evaluation
|
||||
"""
|
||||
|
||||
train: Optional[DataLoader[FrameData]]
|
||||
val: Optional[DataLoader[FrameData]]
|
||||
test: Optional[DataLoader[FrameData]]
|
||||
|
||||
def __getitem__(self, split: str) -> Optional[DataLoader[FrameData]]:
|
||||
"""
|
||||
Get one of the data loaders by key (name of data split)
|
||||
"""
|
||||
if split not in ["train", "val", "test"]:
|
||||
raise ValueError(f"{split} was not a valid split name (train/val/test)")
|
||||
return getattr(self, split)
|
||||
|
||||
|
||||
class DataLoaderMapProviderBase(ReplaceableBase):
|
||||
"""
|
||||
Provider of a collection of data loaders for a given collection of datasets.
|
||||
"""
|
||||
|
||||
def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
|
||||
"""
|
||||
Returns a collection of data loaders for a given collection of datasets.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DoublePoolBatchSampler(Sampler[List[int]]):
|
||||
"""
|
||||
Batch sampler for making random batches of a single frame
|
||||
from one list and a number of known frames from another list.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
first_indices: List[int],
|
||||
rest_indices: List[int],
|
||||
batch_size: int,
|
||||
replacement: bool,
|
||||
num_batches: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
first_indices: indexes of dataset items to use as the first element
|
||||
of each batch.
|
||||
rest_indices: indexes of dataset items to use as the subsequent
|
||||
elements of each batch. Not used if batch_size==1.
|
||||
batch_size: The common size of any batch.
|
||||
replacement: Whether the sampling of first items is with replacement.
|
||||
num_batches: The number of batches in an epoch. If 0 or None,
|
||||
one epoch is the length of `first_indices`.
|
||||
"""
|
||||
self.first_indices = first_indices
|
||||
self.rest_indices = rest_indices
|
||||
self.batch_size = batch_size
|
||||
self.replacement = replacement
|
||||
self.num_batches = None if num_batches == 0 else num_batches
|
||||
|
||||
if batch_size - 1 > len(rest_indices):
|
||||
raise ValueError(
|
||||
f"Cannot make up ({batch_size})-batches from {len(self.rest_indices)}"
|
||||
)
|
||||
|
||||
# copied from RandomSampler
|
||||
seed = int(torch.empty((), dtype=torch.int64).random_().item())
|
||||
self.generator = torch.Generator()
|
||||
self.generator.manual_seed(seed)
|
||||
|
||||
def __len__(self) -> int:
|
||||
if self.num_batches is not None:
|
||||
return self.num_batches
|
||||
return len(self.first_indices)
|
||||
|
||||
def __iter__(self) -> Iterator[List[int]]:
|
||||
num_batches = self.num_batches
|
||||
if self.replacement:
|
||||
i_first = torch.randint(
|
||||
len(self.first_indices),
|
||||
size=(len(self),),
|
||||
generator=self.generator,
|
||||
)
|
||||
elif num_batches is not None:
|
||||
n_copies = 1 + (num_batches - 1) // len(self.first_indices)
|
||||
raw_indices = [
|
||||
torch.randperm(len(self.first_indices), generator=self.generator)
|
||||
for _ in range(n_copies)
|
||||
]
|
||||
i_first = torch.concat(raw_indices)[:num_batches]
|
||||
else:
|
||||
i_first = torch.randperm(len(self.first_indices), generator=self.generator)
|
||||
first_indices = [self.first_indices[i] for i in i_first]
|
||||
|
||||
if self.batch_size == 1:
|
||||
for first_index in first_indices:
|
||||
yield [first_index]
|
||||
return
|
||||
|
||||
for first_index in first_indices:
|
||||
# Consider using this class in a program which sets the seed. This use
|
||||
# of randperm means that rerunning with a higher batch_size
|
||||
# results in batches whose first elements as the first run.
|
||||
i_rest = torch.randperm(
|
||||
len(self.rest_indices),
|
||||
generator=self.generator,
|
||||
)[: self.batch_size - 1]
|
||||
yield [first_index] + [self.rest_indices[i] for i in i_rest]
|
||||
|
||||
|
||||
class BatchConditioningType(Enum):
|
||||
"""
|
||||
Ways to add conditioning frames for the val and test batches.
|
||||
|
||||
SAME: Use the corresponding dataset for all elements of val batches
|
||||
without regard to frame type.
|
||||
TRAIN: Use the corresponding dataset for the first element of each
|
||||
batch, and the training dataset for the extra conditioning
|
||||
elements. No regard to frame type.
|
||||
KNOWN: Use frames from the corresponding dataset but separate them
|
||||
according to their frame_type. Each batch will contain one UNSEEN
|
||||
frame followed by many KNOWN frames.
|
||||
"""
|
||||
|
||||
SAME = "same"
|
||||
TRAIN = "train"
|
||||
KNOWN = "known"
|
||||
|
||||
|
||||
@registry.register
|
||||
class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
|
||||
"""
|
||||
Default implementation of DataLoaderMapProviderBase.
|
||||
|
||||
If a dataset returns batches from get_eval_batches(), then
|
||||
they will be what the corresponding dataloader returns,
|
||||
independently of any of the fields on this class.
|
||||
|
||||
If conditioning is not required, then the batch size should
|
||||
be set as 1, and most of the fields do not matter.
|
||||
|
||||
If conditioning is required, each batch will contain one main
|
||||
frame first to predict and the, rest of the elements are for
|
||||
conditioning.
|
||||
|
||||
If images_per_seq_options is left empty, the conditioning
|
||||
frames are picked according to the conditioning type given.
|
||||
This does not have regard to the order of frames in a
|
||||
scene, or which frames belong to what scene.
|
||||
|
||||
If images_per_seq_options is given, then the conditioning types
|
||||
must be SAME and the remaining fields are used.
|
||||
|
||||
Members:
|
||||
batch_size: The size of the batch of the data loader.
|
||||
num_workers: Number of data-loading threads in each data loader.
|
||||
dataset_length_train: The number of batches in a training epoch. Or 0 to mean
|
||||
an epoch is the length of the training set.
|
||||
dataset_length_val: The number of batches in a validation epoch. Or 0 to mean
|
||||
an epoch is the length of the validation set.
|
||||
dataset_length_test: The number of batches in a testing epoch. Or 0 to mean
|
||||
an epoch is the length of the test set.
|
||||
train_conditioning_type: Whether the train data loader should use
|
||||
only known frames for conditioning.
|
||||
Only used if batch_size>1 and train dataset is
|
||||
present and does not return eval_batches.
|
||||
val_conditioning_type: Whether the val data loader should use
|
||||
training frames or known frames for conditioning.
|
||||
Only used if batch_size>1 and val dataset is
|
||||
present and does not return eval_batches.
|
||||
test_conditioning_type: Whether the test data loader should use
|
||||
training frames or known frames for conditioning.
|
||||
Only used if batch_size>1 and test dataset is
|
||||
present and does not return eval_batches.
|
||||
images_per_seq_options: Possible numbers of frames sampled per sequence in a batch.
|
||||
If a conditioning_type is KNOWN or TRAIN, then this must be left at its initial
|
||||
value. Empty (the default) means that we are not careful about which frames
|
||||
come from which scene.
|
||||
sample_consecutive_frames: if True, will sample a contiguous interval of frames
|
||||
in the sequence. It first sorts the frames by timestimps when available,
|
||||
otherwise by frame numbers, finds the connected segments within the sequence
|
||||
of sufficient length, then samples a random pivot element among them and
|
||||
ideally uses it as a middle of the temporal window, shifting the borders
|
||||
where necessary. This strategy mitigates the bias against shorter segments
|
||||
and their boundaries.
|
||||
consecutive_frames_max_gap: if a number > 0, then used to define the maximum
|
||||
difference in frame_number of neighbouring frames when forming connected
|
||||
segments; if both this and consecutive_frames_max_gap_seconds are 0s,
|
||||
the whole sequence is considered a segment regardless of frame numbers.
|
||||
consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
|
||||
maximum difference in frame_timestamp of neighbouring frames when forming
|
||||
connected segments; if both this and consecutive_frames_max_gap are 0s,
|
||||
the whole sequence is considered a segment regardless of frame timestamps.
|
||||
"""
|
||||
|
||||
batch_size: int = 1
|
||||
num_workers: int = 0
|
||||
dataset_length_train: int = 0
|
||||
dataset_length_val: int = 0
|
||||
dataset_length_test: int = 0
|
||||
train_conditioning_type: BatchConditioningType = BatchConditioningType.SAME
|
||||
val_conditioning_type: BatchConditioningType = BatchConditioningType.SAME
|
||||
test_conditioning_type: BatchConditioningType = BatchConditioningType.KNOWN
|
||||
images_per_seq_options: Tuple[int, ...] = ()
|
||||
sample_consecutive_frames: bool = False
|
||||
consecutive_frames_max_gap: int = 0
|
||||
consecutive_frames_max_gap_seconds: float = 0.1
|
||||
|
||||
def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
|
||||
"""
|
||||
Returns a collection of data loaders for a given collection of datasets.
|
||||
"""
|
||||
return DataLoaderMap(
|
||||
train=self._make_data_loader(
|
||||
datasets.train,
|
||||
self.dataset_length_train,
|
||||
datasets.train,
|
||||
self.train_conditioning_type,
|
||||
),
|
||||
val=self._make_data_loader(
|
||||
datasets.val,
|
||||
self.dataset_length_val,
|
||||
datasets.train,
|
||||
self.val_conditioning_type,
|
||||
),
|
||||
test=self._make_data_loader(
|
||||
datasets.test,
|
||||
self.dataset_length_test,
|
||||
datasets.train,
|
||||
self.test_conditioning_type,
|
||||
),
|
||||
)
|
||||
|
||||
def _make_data_loader(
|
||||
self,
|
||||
dataset: Optional[DatasetBase],
|
||||
num_batches: int,
|
||||
train_dataset: Optional[DatasetBase],
|
||||
conditioning_type: BatchConditioningType,
|
||||
) -> Optional[DataLoader[FrameData]]:
|
||||
"""
|
||||
Returns the dataloader for a dataset.
|
||||
|
||||
Args:
|
||||
dataset: the dataset
|
||||
num_batches: possible ceiling on number of batches per epoch
|
||||
train_dataset: the training dataset, used if conditioning_type==TRAIN
|
||||
conditioning_type: source for padding of batches
|
||||
"""
|
||||
if dataset is None:
|
||||
return None
|
||||
|
||||
data_loader_kwargs = {
|
||||
"num_workers": self.num_workers,
|
||||
"collate_fn": dataset.frame_data_type.collate,
|
||||
}
|
||||
|
||||
eval_batches = dataset.get_eval_batches()
|
||||
if eval_batches is not None:
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_sampler=eval_batches,
|
||||
**data_loader_kwargs,
|
||||
)
|
||||
|
||||
scenes_matter = len(self.images_per_seq_options) > 0
|
||||
if scenes_matter and conditioning_type != BatchConditioningType.SAME:
|
||||
raise ValueError(
|
||||
f"{conditioning_type} cannot be used with images_per_seq "
|
||||
+ str(self.images_per_seq_options)
|
||||
)
|
||||
|
||||
if self.batch_size == 1 or (
|
||||
not scenes_matter and conditioning_type == BatchConditioningType.SAME
|
||||
):
|
||||
return self._simple_loader(dataset, num_batches, data_loader_kwargs)
|
||||
|
||||
if scenes_matter:
|
||||
assert conditioning_type == BatchConditioningType.SAME
|
||||
batch_sampler = SceneBatchSampler(
|
||||
dataset,
|
||||
self.batch_size,
|
||||
num_batches=len(dataset) if num_batches <= 0 else num_batches,
|
||||
images_per_seq_options=self.images_per_seq_options,
|
||||
sample_consecutive_frames=self.sample_consecutive_frames,
|
||||
consecutive_frames_max_gap=self.consecutive_frames_max_gap,
|
||||
consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
|
||||
)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
**data_loader_kwargs,
|
||||
)
|
||||
|
||||
if conditioning_type == BatchConditioningType.TRAIN:
|
||||
return self._train_loader(
|
||||
dataset, train_dataset, num_batches, data_loader_kwargs
|
||||
)
|
||||
|
||||
assert conditioning_type == BatchConditioningType.KNOWN
|
||||
return self._known_loader(dataset, num_batches, data_loader_kwargs)
|
||||
|
||||
def _simple_loader(
|
||||
self,
|
||||
dataset: DatasetBase,
|
||||
num_batches: int,
|
||||
data_loader_kwargs: dict,
|
||||
) -> DataLoader[FrameData]:
|
||||
"""
|
||||
Return a simple loader for frames in the dataset.
|
||||
|
||||
This is equivalent to
|
||||
Dataloader(dataset, batch_size=self.batch_size, **data_loader_kwargs)
|
||||
except that num_batches is fixed.
|
||||
|
||||
Args:
|
||||
dataset: the dataset
|
||||
num_batches: possible ceiling on number of batches per epoch
|
||||
data_loader_kwargs: common args for dataloader
|
||||
"""
|
||||
if num_batches > 0:
|
||||
num_samples = self.batch_size * num_batches
|
||||
replacement = True
|
||||
else:
|
||||
num_samples = None
|
||||
replacement = False
|
||||
sampler = RandomSampler(
|
||||
dataset, replacement=replacement, num_samples=num_samples
|
||||
)
|
||||
batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
**data_loader_kwargs,
|
||||
)
|
||||
|
||||
def _train_loader(
|
||||
self,
|
||||
dataset: DatasetBase,
|
||||
train_dataset: Optional[DatasetBase],
|
||||
num_batches: int,
|
||||
data_loader_kwargs: dict,
|
||||
) -> DataLoader[FrameData]:
|
||||
"""
|
||||
Return the loader for TRAIN conditioning.
|
||||
|
||||
Args:
|
||||
dataset: the dataset
|
||||
train_dataset: the training dataset
|
||||
num_batches: possible ceiling on number of batches per epoch
|
||||
data_loader_kwargs: common args for dataloader
|
||||
"""
|
||||
if train_dataset is None:
|
||||
raise ValueError("No training data for conditioning.")
|
||||
length = len(dataset)
|
||||
first_indices = list(range(length))
|
||||
rest_indices = list(range(length, length + len(train_dataset)))
|
||||
sampler = DoublePoolBatchSampler(
|
||||
first_indices=first_indices,
|
||||
rest_indices=rest_indices,
|
||||
batch_size=self.batch_size,
|
||||
replacement=True,
|
||||
num_batches=num_batches,
|
||||
)
|
||||
return DataLoader(
|
||||
ChainDataset([dataset, train_dataset]),
|
||||
batch_sampler=sampler,
|
||||
**data_loader_kwargs,
|
||||
)
|
||||
|
||||
def _known_loader(
|
||||
self,
|
||||
dataset: DatasetBase,
|
||||
num_batches: int,
|
||||
data_loader_kwargs: dict,
|
||||
) -> DataLoader[FrameData]:
|
||||
"""
|
||||
Return the loader for KNOWN conditioning.
|
||||
|
||||
Args:
|
||||
dataset: the dataset
|
||||
num_batches: possible ceiling on number of batches per epoch
|
||||
data_loader_kwargs: common args for dataloader
|
||||
"""
|
||||
first_indices, rest_indices = [], []
|
||||
for idx in range(len(dataset)):
|
||||
frame_type = dataset[idx].frame_type
|
||||
assert isinstance(frame_type, str)
|
||||
if is_known_frame_scalar(frame_type):
|
||||
rest_indices.append(idx)
|
||||
else:
|
||||
first_indices.append(idx)
|
||||
sampler = DoublePoolBatchSampler(
|
||||
first_indices=first_indices,
|
||||
rest_indices=rest_indices,
|
||||
batch_size=self.batch_size,
|
||||
replacement=True,
|
||||
num_batches=num_batches,
|
||||
)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_sampler=sampler,
|
||||
**data_loader_kwargs,
|
||||
)
|
||||
79
pytorch3d/implicitron/dataset/data_source.py
Normal file
79
pytorch3d/implicitron/dataset/data_source.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
registry,
|
||||
ReplaceableBase,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
|
||||
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
|
||||
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
|
||||
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
|
||||
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
|
||||
|
||||
|
||||
class DataSourceBase(ReplaceableBase):
|
||||
"""
|
||||
Base class for a data source in Implicitron. It encapsulates Dataset
|
||||
and DataLoader configuration.
|
||||
"""
|
||||
|
||||
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def all_train_cameras(self) -> Optional[CamerasBase]:
|
||||
"""
|
||||
If the data is all for a single scene, a list
|
||||
of the known training cameras for that scene, which is
|
||||
used for evaluating the viewpoint difficulty of the
|
||||
unseen cameras.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@registry.register
|
||||
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
||||
"""
|
||||
Represents the data used in Implicitron. This is the only implementation
|
||||
of DataSourceBase provided.
|
||||
|
||||
Members:
|
||||
dataset_map_provider_class_type: identifies type for dataset_map_provider.
|
||||
e.g. JsonIndexDatasetMapProvider for Co3D.
|
||||
data_loader_map_provider_class_type: identifies type for data_loader_map_provider.
|
||||
"""
|
||||
|
||||
dataset_map_provider: DatasetMapProviderBase
|
||||
dataset_map_provider_class_type: str
|
||||
data_loader_map_provider: DataLoaderMapProviderBase
|
||||
data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
self._all_train_cameras_cache: Optional[Tuple[Optional[CamerasBase]]] = None
|
||||
|
||||
def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
|
||||
datasets = self.dataset_map_provider.get_dataset_map()
|
||||
dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
|
||||
return datasets, dataloaders
|
||||
|
||||
def get_task(self) -> Task:
|
||||
return self.dataset_map_provider.get_task()
|
||||
|
||||
@property
|
||||
def all_train_cameras(self) -> Optional[CamerasBase]:
|
||||
if self._all_train_cameras_cache is None: # pyre-ignore[16]
|
||||
all_train_cameras = self.dataset_map_provider.get_all_train_cameras()
|
||||
self._all_train_cameras_cache = (all_train_cameras,)
|
||||
|
||||
return self._all_train_cameras_cache[0]
|
||||
@@ -1,100 +0,0 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Dict, Sequence
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||
|
||||
from .implicitron_dataset import FrameData, ImplicitronDatasetBase
|
||||
from .scene_batch_sampler import SceneBatchSampler
|
||||
|
||||
|
||||
def dataloader_zoo(
|
||||
datasets: Dict[str, ImplicitronDatasetBase],
|
||||
dataset_name: str = "co3d_singlesequence",
|
||||
batch_size: int = 1,
|
||||
num_workers: int = 0,
|
||||
dataset_len: int = 1000,
|
||||
dataset_len_val: int = 1,
|
||||
images_per_seq_options: Sequence[int] = (2,),
|
||||
sample_consecutive_frames: bool = False,
|
||||
consecutive_frames_max_gap: int = 0,
|
||||
consecutive_frames_max_gap_seconds: float = 0.1,
|
||||
) -> Dict[str, torch.utils.data.DataLoader]:
|
||||
"""
|
||||
Returns a set of dataloaders for a given set of datasets.
|
||||
|
||||
Args:
|
||||
datasets: A dictionary containing the
|
||||
`"dataset_subset_name": torch_dataset_object` key, value pairs.
|
||||
dataset_name: The name of the returned dataset.
|
||||
batch_size: The size of the batch of the dataloader.
|
||||
num_workers: Number data-loading threads.
|
||||
dataset_len: The number of batches in a training epoch.
|
||||
dataset_len_val: The number of batches in a validation epoch.
|
||||
images_per_seq_options: Possible numbers of images sampled per sequence.
|
||||
sample_consecutive_frames: if True, will sample a contiguous interval of frames
|
||||
in the sequence. It first sorts the frames by timestimps when available,
|
||||
otherwise by frame numbers, finds the connected segments within the sequence
|
||||
of sufficient length, then samples a random pivot element among them and
|
||||
ideally uses it as a middle of the temporal window, shifting the borders
|
||||
where necessary. This strategy mitigates the bias against shorter segments
|
||||
and their boundaries.
|
||||
consecutive_frames_max_gap: if a number > 0, then used to define the maximum
|
||||
difference in frame_number of neighbouring frames when forming connected
|
||||
segments; if both this and consecutive_frames_max_gap_seconds are 0s,
|
||||
the whole sequence is considered a segment regardless of frame numbers.
|
||||
consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
|
||||
maximum difference in frame_timestamp of neighbouring frames when forming
|
||||
connected segments; if both this and consecutive_frames_max_gap are 0s,
|
||||
the whole sequence is considered a segment regardless of frame timestamps.
|
||||
|
||||
Returns:
|
||||
dataloaders: A dictionary containing the
|
||||
`"dataset_subset_name": torch_dataloader_object` key, value pairs.
|
||||
"""
|
||||
if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
|
||||
raise ValueError(f"Unsupported dataset: {dataset_name}")
|
||||
|
||||
dataloaders = {}
|
||||
|
||||
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
|
||||
for dataset_set, dataset in datasets.items():
|
||||
num_samples = {
|
||||
"train": dataset_len,
|
||||
"val": dataset_len_val,
|
||||
"test": None,
|
||||
}[dataset_set]
|
||||
|
||||
if dataset_set == "test":
|
||||
batch_sampler = dataset.get_eval_batches()
|
||||
else:
|
||||
assert num_samples is not None
|
||||
num_samples = len(dataset) if num_samples <= 0 else num_samples
|
||||
batch_sampler = SceneBatchSampler(
|
||||
dataset,
|
||||
batch_size,
|
||||
num_batches=num_samples,
|
||||
images_per_seq_options=images_per_seq_options,
|
||||
sample_consecutive_frames=sample_consecutive_frames,
|
||||
consecutive_frames_max_gap=consecutive_frames_max_gap,
|
||||
)
|
||||
|
||||
dataloaders[dataset_set] = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=FrameData.collate,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset: {dataset_name}")
|
||||
|
||||
return dataloaders
|
||||
|
||||
|
||||
enable_get_default_args(dataloader_zoo)
|
||||
306
pytorch3d/implicitron/dataset/dataset_base.py
Normal file
306
pytorch3d/implicitron/dataset/dataset_base.py
Normal file
@@ -0,0 +1,306 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||
|
||||
|
||||
@dataclass
class FrameData(Mapping[str, Any]):
    """
    A type of the elements returned by indexing the dataset object.
    It can represent both individual frames and batches of thereof;
    in this documentation, the sizes of tensors refer to single frames;
    add the first batch dimension for the collation result.

    Args:
        frame_number: The number of the frame within its sequence.
            0-based continuous integers.
        sequence_name: The unique name of the frame's sequence.
        sequence_category: The object category of the sequence.
        frame_timestamp: The time elapsed since the start of a sequence in sec.
        image_size_hw: The size of the image in pixels; (height, width) tensor
            of shape (2,).
        image_path: The qualified path to the loaded image (with dataset_root).
        image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
            of the frame; elements are floats in [0, 1].
        mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
            regions. Regions can be invalid (mask_crop[i,j]=0) in case they
            are a result of zero-padding of the image after cropping around
            the object bounding box; elements are floats in {0.0, 1.0}.
        depth_path: The qualified path to the frame's depth map.
        depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
            of the frame; values correspond to distances from the camera;
            use `depth_mask` and `mask_crop` to filter for valid pixels.
        depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
            depth map that are valid for evaluation, they have been checked for
            consistency across views; elements are floats in {0.0, 1.0}.
        mask_path: A qualified path to the foreground probability mask.
        fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
            pixels belonging to the captured object; elements are floats
            in [0, 1].
        bbox_xywh: The bounding box tightly enclosing the foreground object in the
            format (x0, y0, width, height). The convention assumes that
            `x0+width` and `y0+height` includes the boundary of the box.
            I.e., to slice out the corresponding crop from an image tensor `I`
            we execute `crop = I[..., y0:y0+height, x0:x0+width]`
        crop_bbox_xywh: The bounding box denoting the boundaries of `image_rgb`
            in the original image coordinates in the format (x0, y0, width, height).
            The convention is the same as for `bbox_xywh`. `crop_bbox_xywh` differs
            from `bbox_xywh` due to padding (which can happen e.g. due to
            setting `JsonIndexDataset.box_crop_context > 0`)
        camera: A PyTorch3D camera object corresponding the frame's viewpoint,
            corrected for cropping if it happened.
        camera_quality_score: The score proportional to the confidence of the
            frame's camera estimation (the higher the more accurate).
        point_cloud_quality_score: The score proportional to the accuracy of the
            frame's sequence point cloud (the higher the more accurate).
        sequence_point_cloud_path: The path to the sequence's point cloud.
        sequence_point_cloud: A PyTorch3D Pointclouds object holding the
            point cloud corresponding to the frame's sequence. When the object
            represents a batch of frames, point clouds may be deduplicated;
            see `sequence_point_cloud_idx`.
        sequence_point_cloud_idx: Integer indices mapping frame indices to the
            corresponding point clouds in `sequence_point_cloud`; to get the
            corresponding point cloud to `image_rgb[i]`, use
            `sequence_point_cloud[sequence_point_cloud_idx[i]]`.
        frame_type: The type of the loaded frame specified in
            `subset_lists_file`, if provided.
        meta: A dict for storing additional frame information.
    """

    frame_number: Optional[torch.LongTensor]
    sequence_name: Union[str, List[str]]
    sequence_category: Union[str, List[str]]
    frame_timestamp: Optional[torch.Tensor] = None
    image_size_hw: Optional[torch.Tensor] = None
    image_path: Union[str, List[str], None] = None
    image_rgb: Optional[torch.Tensor] = None
    # masks out padding added due to cropping the square bit
    mask_crop: Optional[torch.Tensor] = None
    depth_path: Union[str, List[str], None] = None
    depth_map: Optional[torch.Tensor] = None
    depth_mask: Optional[torch.Tensor] = None
    mask_path: Union[str, List[str], None] = None
    fg_probability: Optional[torch.Tensor] = None
    bbox_xywh: Optional[torch.Tensor] = None
    crop_bbox_xywh: Optional[torch.Tensor] = None
    camera: Optional[PerspectiveCameras] = None
    camera_quality_score: Optional[torch.Tensor] = None
    point_cloud_quality_score: Optional[torch.Tensor] = None
    sequence_point_cloud_path: Union[str, List[str], None] = None
    sequence_point_cloud: Optional[Pointclouds] = None
    sequence_point_cloud_idx: Optional[torch.Tensor] = None
    frame_type: Union[str, List[str], None] = None  # known | unseen
    meta: dict = field(default_factory=lambda: {})

    def to(self, *args, **kwargs):
        # Mimics torch.Tensor.to: every member that supports `.to`
        # (tensors, point clouds, cameras) is converted; all other members
        # are carried over by reference. Returns a new FrameData.
        new_params = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
                new_params[f.name] = value.to(*args, **kwargs)
            else:
                new_params[f.name] = value
        return type(self)(**new_params)

    def cpu(self):
        return self.to(device=torch.device("cpu"))

    def cuda(self):
        return self.to(device=torch.device("cuda"))

    # the following functions make sure **frame_data can be passed to functions
    def __iter__(self):
        for f in fields(self):
            yield f.name

    def __getitem__(self, key):
        return getattr(self, key)

    def __len__(self):
        return len(fields(self))

    @classmethod
    def collate(cls, batch):
        """
        Given a list objects `batch` of class `cls`, collates them into a batched
        representation suitable for processing with deep networks.
        """

        elem = batch[0]

        if isinstance(elem, cls):
            # Deduplicate point clouds shared (by object identity) between
            # frames of the same sequence; sequence_point_cloud_idx records,
            # for each frame, the index of its cloud in the deduplicated list.
            pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
            id_to_idx = defaultdict(list)
            for i, pc_id in enumerate(pointcloud_ids):
                id_to_idx[pc_id].append(i)

            sequence_point_cloud = []
            # start with -1 so the assert below catches any unassigned frame
            sequence_point_cloud_idx = -np.ones((len(batch),))
            for i, ind in enumerate(id_to_idx.values()):
                sequence_point_cloud_idx[ind] = i
                sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
            assert (sequence_point_cloud_idx >= 0).all()

            override_fields = {
                "sequence_point_cloud": sequence_point_cloud,
                "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
            }
            # note that the pre-collate value of sequence_point_cloud_idx is unused

            # Recursively collate each field; a field collates to None when any
            # element of the batch is missing (None) for that field.
            collated = {}
            for f in fields(elem):
                list_values = override_fields.get(
                    f.name, [getattr(d, f.name) for d in batch]
                )
                collated[f.name] = (
                    cls.collate(list_values)
                    if all(list_value is not None for list_value in list_values)
                    else None
                )
            return cls(**collated)

        elif isinstance(elem, Pointclouds):
            return join_pointclouds_as_batch(batch)

        elif isinstance(elem, CamerasBase):
            # TODO: don't store K; enforce working in NDC space
            return join_cameras_as_batch(batch)
        else:
            # plain tensors / numbers / strings: defer to PyTorch's default
            return torch.utils.data._utils.collate.default_collate(batch)
|
||||
|
||||
|
||||
class _GenericWorkaround:
|
||||
"""
|
||||
OmegaConf.structured has a weirdness when you try to apply
|
||||
it to a dataclass whose first base class is a Generic which is not
|
||||
Dict. The issue is with a function called get_dict_key_value_types
|
||||
in omegaconf/_utils.py.
|
||||
For example this fails:
|
||||
|
||||
@dataclass(eq=False)
|
||||
class D(torch.utils.data.Dataset[int]):
|
||||
a: int = 3
|
||||
|
||||
OmegaConf.structured(D)
|
||||
|
||||
We avoid the problem by adding this class as an extra base class.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(eq=False)
class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
    """
    Base class to describe a dataset to be used with Implicitron.

    The dataset is made up of frames, and the frames are grouped into sequences.
    Each sequence has a name (a string).
    (A sequence could be a video, or a set of images of one scene.)

    This means they have a __getitem__ which returns an instance of a FrameData,
    which will describe one frame in one sequence.
    """

    # _seq_to_idx is a member which implementations can define.
    # It maps sequence name to the sequence's global frame indices.
    # It is used for the default implementations of some functions in this class.
    # Implementations which override them are free to ignore it.
    # _seq_to_idx: Dict[str, List[int]] = field(init=False)

    def __len__(self) -> int:
        raise NotImplementedError()

    def get_frame_numbers_and_timestamps(
        self, idxs: Sequence[int]
    ) -> List[Tuple[int, float]]:
        """
        If the sequences in the dataset are videos rather than
        unordered views, then the dataset should override this method to
        return the index and timestamp in their videos of the frames whose
        indices are given in `idxs`. In addition,
        the values in _seq_to_idx should be in ascending order.
        If timestamps are absent, they should be replaced with a constant.

        This is used for letting SceneBatchSampler identify consecutive
        frames.

        Args:
            idxs: frame indices in self

        Returns:
            list (one entry per index in `idxs`) of tuples of
                - frame index in video
                - timestamp of frame in video

        Raises:
            ValueError: if the dataset does not contain videos (the default).
        """
        raise ValueError("This dataset does not contain videos.")

    def get_eval_batches(self) -> Optional[List[List[int]]]:
        """
        Returns the predefined evaluation batches as lists of dataset
        indices, or None if the dataset does not define any.
        """
        return None

    def sequence_names(self) -> Iterable[str]:
        """Returns an iterator over sequence names in the dataset."""
        # pyre-ignore[16]
        return self._seq_to_idx.keys()

    def sequence_frames_in_order(
        self, seq_name: str
    ) -> Iterator[Tuple[float, int, int]]:
        """Returns an iterator over the frame indices in a given sequence.
        We attempt to first sort by timestamp (if they are available),
        then by frame number.

        Args:
            seq_name: the name of the sequence.

        Returns:
            an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
            where `frame_no` is the index within the sequence, and
            `dataset_idx` is the index within the dataset.
            `None` timestamps are replaced with 0s.
        """
        # pyre-ignore[16]
        seq_frame_indices = self._seq_to_idx[seq_name]
        nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)

        yield from sorted(
            [
                # Substitute 0 for absent (None) timestamps, as promised in
                # the docstring; otherwise sorting a mix of None and numeric
                # timestamps would raise TypeError in Python 3.
                (timestamp if timestamp is not None else 0, frame_no, idx)
                for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
            ]
        )

    def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
        """Same as `sequence_frames_in_order` but returns the iterator over
        only dataset indices.
        """
        for _, _, idx in self.sequence_frames_in_order(seq_name):
            yield idx

    # frame_data_type is the actual type of frames returned by the dataset.
    # Collation uses its classmethod `collate`
    frame_data_type: ClassVar[Type[FrameData]] = FrameData
|
||||
120
pytorch3d/implicitron/dataset/dataset_map_provider.py
Normal file
120
pytorch3d/implicitron/dataset/dataset_map_provider.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Iterator, Optional
|
||||
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .dataset_base import DatasetBase
|
||||
|
||||
|
||||
@dataclass
class DatasetMap:
    """
    A collection of datasets for implicitron.

    Members:

        train: a dataset for training
        val: a dataset for validating during training
        test: a dataset for final evaluation
    """

    train: Optional[DatasetBase]
    val: Optional[DatasetBase]
    test: Optional[DatasetBase]

    def __getitem__(self, split: str) -> Optional[DatasetBase]:
        """
        Get one of the datasets by key (name of data split)
        """
        if split in ("train", "val", "test"):
            return getattr(self, split)
        raise ValueError(f"{split} was not a valid split name (train/val/test)")

    def iter_datasets(self) -> Iterator[DatasetBase]:
        """
        Iterator over all datasets.
        """
        # yield the defined splits in the canonical train/val/test order
        for dataset in (self.train, self.val, self.test):
            if dataset is not None:
                yield dataset
|
||||
|
||||
|
||||
class Task(Enum):
    """The kind of implicitron reconstruction task: a single scene,
    or generalization over many scenes."""

    SINGLE_SEQUENCE = "singlesequence"
    MULTI_SEQUENCE = "multisequence"
|
||||
|
||||
|
||||
class DatasetMapProviderBase(ReplaceableBase):
    """
    Base class for a provider of training / validation and testing
    dataset objects.
    """

    def get_dataset_map(self) -> DatasetMap:
        """
        Build or fetch the datasets.

        Returns:
            An object containing the torch.Dataset objects in train/val/test fields.
        """
        raise NotImplementedError()

    def get_task(self) -> Task:
        """Return the kind of task (single- vs. multi-sequence) this data is for."""
        raise NotImplementedError()

    def get_all_train_cameras(self) -> Optional[CamerasBase]:
        """
        If the data is all for a single scene, returns a list
        of the known training cameras for that scene, which is
        used for evaluating the difficulty of the unknown
        cameras. Otherwise return None.
        """
        raise NotImplementedError()
|
||||
|
||||
|
||||
@registry.register
class PathManagerFactory(ReplaceableBase):
    """
    Base class and default implementation of a tool which dataset_map_provider
    implementations may use to construct a path manager if needed.

    Args:
        silence_logs: Whether to reduce log output from iopath library.
    """

    silence_logs: bool = True

    def get(self) -> Optional[PathManager]:
        """
        Makes a PathManager if needed.
        For open source users, this function should always return None.
        Internally, this allows manifold access.
        """
        if os.environ.get("INSIDE_RE_WORKER", False):
            return None

        try:
            from iopath.fb.manifold import ManifoldPathHandler
        except ImportError:
            # Not running inside Meta: no manifold support, no PathManager.
            return None

        if self.silence_logs:
            for noisy_logger in ("iopath.fb.manifold", "iopath.common.file_io"):
                logging.getLogger(noisy_logger).setLevel(logging.CRITICAL)

        manager = PathManager()
        manager.register_handler(ManifoldPathHandler())

        return manager
|
||||
@@ -1,263 +0,0 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||
|
||||
from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase
|
||||
from .utils import (
|
||||
DATASET_TYPE_KNOWN,
|
||||
DATASET_TYPE_TEST,
|
||||
DATASET_TYPE_TRAIN,
|
||||
DATASET_TYPE_UNKNOWN,
|
||||
)
|
||||
|
||||
|
||||
# TODO from dataset.dataset_configs import DATASET_CONFIGS
# Default keyword arguments forwarded to the ImplicitronDataset constructor
# (see aux_dataset_kwargs in dataset_zoo below).
DATASET_CONFIGS: Dict[str, Dict[str, Any]] = {
    "default": {
        "box_crop": True,
        "box_crop_context": 0.3,
        "image_width": 800,
        "image_height": 800,
        "remove_empty_masks": True,
    }
}

# fmt: off
# The CO3D object categories. NOTE(review): stored reversed relative to the
# literal below — the ordering presumably matters to downstream consumers.
CO3D_CATEGORIES: List[str] = list(reversed([
    "baseballbat", "banana", "bicycle", "microwave", "tv",
    "cellphone", "toilet", "hairdryer", "couch", "kite", "pizza",
    "umbrella", "wineglass", "laptop",
    "hotdog", "stopsign", "frisbee", "baseballglove",
    "cup", "parkingmeter", "backpack", "toyplane", "toybus",
    "handbag", "chair", "keyboard", "car", "motorcycle",
    "carrot", "bottle", "sandwich", "remote", "bowl", "skateboard",
    "toaster", "mouse", "toytrain", "book", "toytruck",
    "orange", "broccoli", "plant", "teddybear",
    "suitcase", "bench", "ball", "cake",
    "vase", "hydrant", "apple", "donut",
]))
# fmt: on

# Root folder of the CO3D dataset, read from the environment at import time;
# empty string when CO3D_DATASET_ROOT is unset.
_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
|
||||
|
||||
|
||||
def dataset_zoo(
    dataset_name: str = "co3d_singlesequence",
    dataset_root: str = _CO3D_DATASET_ROOT,
    category: str = "DEFAULT",
    limit_to: int = -1,
    limit_sequences_to: int = -1,
    n_frames_per_sequence: int = -1,
    test_on_train: bool = False,
    load_point_clouds: bool = False,
    mask_images: bool = False,
    mask_depths: bool = False,
    restrict_sequence_name: Sequence[str] = (),
    test_restrict_sequence_id: int = -1,
    assert_single_seq: bool = False,
    only_test_set: bool = False,
    aux_dataset_kwargs: dict = DATASET_CONFIGS["default"],
    path_manager: Optional[PathManager] = None,
) -> Dict[str, ImplicitronDatasetBase]:
    """
    Generates the training / validation and testing dataset objects.

    Args:
        dataset_name: The name of the returned dataset.
        dataset_root: The root folder of the dataset.
        category: The object category of the dataset.
        limit_to: Limit the dataset to the first #limit_to frames.
        limit_sequences_to: Limit the dataset to the first
            #limit_sequences_to sequences.
        n_frames_per_sequence: Randomly sample #n_frames_per_sequence frames
            in each sequence.
        test_on_train: Construct validation and test datasets from
            the training subset.
        load_point_clouds: Enable returning scene point clouds from the dataset.
        mask_images: Mask the loaded images with segmentation masks.
        mask_depths: Mask the loaded depths with segmentation masks.
        restrict_sequence_name: Restrict the dataset sequences to the ones
            present in the given list of names.
        test_restrict_sequence_id: The ID of the loaded sequence.
            Active for dataset_name='co3d_singlesequence'.
        assert_single_seq: Assert that only frames from a single sequence
            are present in all generated datasets.
        only_test_set: Load only the test set.
        aux_dataset_kwargs: Specifies additional arguments to the
            ImplicitronDataset constructor call.
        path_manager: Optional path manager forwarded to the datasets for
            resolving file paths.

    Returns:
        datasets: A dictionary containing the
            `"dataset_subset_name": torch_dataset_object` key, value pairs.

    Raises:
        ValueError: if `dataset_name` is not a supported dataset, or if
            the annotation files cannot be found under `dataset_root`.
    """
    datasets = {}

    # TODO:
    # - implement loading multiple categories

    if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
        # This maps the common names of the dataset subsets ("train"/"val"/"test")
        # to the names of the subsets in the CO3D dataset.
        set_names_mapping = _get_co3d_set_names_mapping(
            dataset_name,
            test_on_train,
            only_test_set,
        )

        # load the evaluation batches
        task = dataset_name.split("_")[-1]
        batch_indices_path = os.path.join(
            dataset_root,
            category,
            f"eval_batches_{task}.json",
        )
        if not os.path.isfile(batch_indices_path):
            # The batch indices file does not exist.
            # Most probably the user has not specified the root folder.
            raise ValueError("Please specify a correct dataset_root folder.")

        with open(batch_indices_path, "r") as f:
            eval_batch_index = json.load(f)

        if task == "singlesequence":
            assert (
                test_restrict_sequence_id is not None and test_restrict_sequence_id >= 0
            ), (
                "Please specify an integer id 'test_restrict_sequence_id'"
                + " of the sequence considered for 'singlesequence'"
                + " training and evaluation."
            )
            assert len(restrict_sequence_name) == 0, (
                "For the 'singlesequence' task, the restrict_sequence_name has"
                " to be unset while test_restrict_sequence_id has to be set to an"
                " integer defining the order of the evaluation sequence."
            )
            # a sort-stable set() equivalent:
            eval_batches_sequence_names = list(
                {b[0][0]: None for b in eval_batch_index}.keys()
            )
            eval_sequence_name = eval_batches_sequence_names[test_restrict_sequence_id]
            # keep only the eval batches of the selected sequence
            eval_batch_index = [
                b for b in eval_batch_index if b[0][0] == eval_sequence_name
            ]
            # overwrite the restrict_sequence_name
            restrict_sequence_name = [eval_sequence_name]

        for dataset, subsets in set_names_mapping.items():
            frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
            assert os.path.isfile(frame_file)

            sequence_file = os.path.join(
                dataset_root, category, "sequence_annotations.jgz"
            )
            assert os.path.isfile(sequence_file)

            subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
            assert os.path.isfile(subset_lists_file)

            # TODO: maybe directly in param list
            # Frame subsampling only applies to the training split.
            params = {
                **copy.deepcopy(aux_dataset_kwargs),
                "frame_annotations_file": frame_file,
                "sequence_annotations_file": sequence_file,
                "subset_lists_file": subset_lists_file,
                "dataset_root": dataset_root,
                "limit_to": limit_to,
                "limit_sequences_to": limit_sequences_to,
                "n_frames_per_sequence": n_frames_per_sequence
                if dataset == "train"
                else -1,
                "subsets": subsets,
                "load_point_clouds": load_point_clouds,
                "mask_images": mask_images,
                "mask_depths": mask_depths,
                "pick_sequence": restrict_sequence_name,
                "path_manager": path_manager,
            }

            datasets[dataset] = ImplicitronDataset(**params)
            if dataset == "test":
                if len(restrict_sequence_name) > 0:
                    eval_batch_index = [
                        b for b in eval_batch_index if b[0][0] in restrict_sequence_name
                    ]

                # convert (sequence, frame) eval indices to dataset indices
                datasets[dataset].eval_batches = datasets[
                    dataset
                ].seq_frame_index_to_dataset_index(eval_batch_index)

        if assert_single_seq:
            # check theres only one sequence in all datasets
            assert (
                len(
                    {
                        e["frame_annotation"].sequence_name
                        for dset in datasets.values()
                        for e in dset.frame_annots
                    }
                )
                <= 1
            ), "Multiple sequences loaded but expected one"

    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    if test_on_train:
        # val/test alias the training dataset object (shared, not copied)
        datasets["val"] = datasets["train"]
        datasets["test"] = datasets["train"]

    return datasets
|
||||
|
||||
|
||||
# Register dataset_zoo with the implicitron config system so that its
# keyword arguments and defaults are readable as a config node.
enable_get_default_args(dataset_zoo)
|
||||
|
||||
|
||||
def _get_co3d_set_names_mapping(
    dataset_name: str,
    test_on_train: bool,
    only_test: bool,
) -> Dict[str, List[str]]:
    """
    Map the common dataset subset names ("train"/"val"/"test") to the names
    of the corresponding subsets in the CO3D dataset
    ("test_known"/"test_unseen"/"train_known"/"train_unseen").
    """
    is_single_sequence = dataset_name == "co3d_singlesequence"

    mapping: Dict[str, List[str]] = {}
    if not only_test:
        # For the single-sequence task the "train" split is drawn from the
        # test sequence's known frames.
        train_prefix = DATASET_TYPE_TEST if is_single_sequence else DATASET_TYPE_TRAIN
        mapping["train"] = [train_prefix + "_" + DATASET_TYPE_KNOWN]

    if not test_on_train:
        eval_prefixes = [DATASET_TYPE_TEST]
        if not is_single_sequence:
            eval_prefixes.append(DATASET_TYPE_TRAIN)
        eval_subsets = [
            prefix + "_" + frame_type
            for prefix in eval_prefixes
            for frame_type in [DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN]
        ]
        # val and test use the same subsets (independent list objects)
        mapping["val"] = list(eval_subsets)
        mapping["test"] = list(eval_subsets)

    return mapping
|
||||
@@ -4,6 +4,7 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import functools
|
||||
import gzip
|
||||
import hashlib
|
||||
@@ -13,17 +14,12 @@ import os
|
||||
import random
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from itertools import islice
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
@@ -34,270 +30,31 @@ from typing import (
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from PIL import Image
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.io import IO
|
||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
|
||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
|
||||
from . import types
|
||||
from .dataset_base import DatasetBase, FrameData
|
||||
from .utils import is_known_frame_scalar
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FrameData(Mapping[str, Any]):
|
||||
"""
|
||||
A type of the elements returned by indexing the dataset object.
|
||||
It can represent both individual frames and batches of thereof;
|
||||
in this documentation, the sizes of tensors refer to single frames;
|
||||
add the first batch dimension for the collation result.
|
||||
|
||||
Args:
|
||||
frame_number: The number of the frame within its sequence.
|
||||
0-based continuous integers.
|
||||
frame_timestamp: The time elapsed since the start of a sequence in sec.
|
||||
sequence_name: The unique name of the frame's sequence.
|
||||
sequence_category: The object category of the sequence.
|
||||
image_size_hw: The size of the image in pixels; (height, width) tuple.
|
||||
image_path: The qualified path to the loaded image (with dataset_root).
|
||||
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
|
||||
of the frame; elements are floats in [0, 1].
|
||||
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
|
||||
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
|
||||
are a result of zero-padding of the image after cropping around
|
||||
the object bounding box; elements are floats in {0.0, 1.0}.
|
||||
depth_path: The qualified path to the frame's depth map.
|
||||
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
|
||||
of the frame; values correspond to distances from the camera;
|
||||
use `depth_mask` and `mask_crop` to filter for valid pixels.
|
||||
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
|
||||
depth map that are valid for evaluation, they have been checked for
|
||||
consistency across views; elements are floats in {0.0, 1.0}.
|
||||
mask_path: A qualified path to the foreground probability mask.
|
||||
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
|
||||
pixels belonging to the captured object; elements are floats
|
||||
in [0, 1].
|
||||
bbox_xywh: The bounding box capturing the object in the
|
||||
format (x0, y0, width, height).
|
||||
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
|
||||
corrected for cropping if it happened.
|
||||
camera_quality_score: The score proportional to the confidence of the
|
||||
frame's camera estimation (the higher the more accurate).
|
||||
point_cloud_quality_score: The score proportional to the accuracy of the
|
||||
frame's sequence point cloud (the higher the more accurate).
|
||||
sequence_point_cloud_path: The path to the sequence's point cloud.
|
||||
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
|
||||
point cloud corresponding to the frame's sequence. When the object
|
||||
represents a batch of frames, point clouds may be deduplicated;
|
||||
see `sequence_point_cloud_idx`.
|
||||
sequence_point_cloud_idx: Integer indices mapping frame indices to the
|
||||
corresponding point clouds in `sequence_point_cloud`; to get the
|
||||
corresponding point cloud to `image_rgb[i]`, use
|
||||
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
|
||||
frame_type: The type of the loaded frame specified in
|
||||
`subset_lists_file`, if provided.
|
||||
meta: A dict for storing additional frame information.
|
||||
"""
|
||||
|
||||
frame_number: Optional[torch.LongTensor]
|
||||
frame_timestamp: Optional[torch.Tensor]
|
||||
sequence_name: Union[str, List[str]]
|
||||
sequence_category: Union[str, List[str]]
|
||||
image_size_hw: Optional[torch.Tensor] = None
|
||||
image_path: Union[str, List[str], None] = None
|
||||
image_rgb: Optional[torch.Tensor] = None
|
||||
# masks out padding added due to cropping the square bit
|
||||
mask_crop: Optional[torch.Tensor] = None
|
||||
depth_path: Union[str, List[str], None] = None
|
||||
depth_map: Optional[torch.Tensor] = None
|
||||
depth_mask: Optional[torch.Tensor] = None
|
||||
mask_path: Union[str, List[str], None] = None
|
||||
fg_probability: Optional[torch.Tensor] = None
|
||||
bbox_xywh: Optional[torch.Tensor] = None
|
||||
camera: Optional[PerspectiveCameras] = None
|
||||
camera_quality_score: Optional[torch.Tensor] = None
|
||||
point_cloud_quality_score: Optional[torch.Tensor] = None
|
||||
sequence_point_cloud_path: Union[str, List[str], None] = None
|
||||
sequence_point_cloud: Optional[Pointclouds] = None
|
||||
sequence_point_cloud_idx: Optional[torch.Tensor] = None
|
||||
frame_type: Union[str, List[str], None] = None # seen | unseen
|
||||
meta: dict = field(default_factory=lambda: {})
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
new_params = {}
|
||||
for f in fields(self):
|
||||
value = getattr(self, f.name)
|
||||
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
|
||||
new_params[f.name] = value.to(*args, **kwargs)
|
||||
else:
|
||||
new_params[f.name] = value
|
||||
return type(self)(**new_params)
|
||||
|
||||
def cpu(self):
|
||||
return self.to(device=torch.device("cpu"))
|
||||
|
||||
def cuda(self):
|
||||
return self.to(device=torch.device("cuda"))
|
||||
|
||||
# the following functions make sure **frame_data can be passed to functions
|
||||
def __iter__(self):
|
||||
for f in fields(self):
|
||||
yield f.name
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
def __len__(self):
|
||||
return len(fields(self))
|
||||
|
||||
@classmethod
|
||||
def collate(cls, batch):
|
||||
"""
|
||||
Given a list objects `batch` of class `cls`, collates them into a batched
|
||||
representation suitable for processing with deep networks.
|
||||
"""
|
||||
|
||||
elem = batch[0]
|
||||
|
||||
if isinstance(elem, cls):
|
||||
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
|
||||
id_to_idx = defaultdict(list)
|
||||
for i, pc_id in enumerate(pointcloud_ids):
|
||||
id_to_idx[pc_id].append(i)
|
||||
|
||||
sequence_point_cloud = []
|
||||
sequence_point_cloud_idx = -np.ones((len(batch),))
|
||||
for i, ind in enumerate(id_to_idx.values()):
|
||||
sequence_point_cloud_idx[ind] = i
|
||||
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
|
||||
assert (sequence_point_cloud_idx >= 0).all()
|
||||
|
||||
override_fields = {
|
||||
"sequence_point_cloud": sequence_point_cloud,
|
||||
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
|
||||
}
|
||||
# note that the pre-collate value of sequence_point_cloud_idx is unused
|
||||
|
||||
collated = {}
|
||||
for f in fields(elem):
|
||||
list_values = override_fields.get(
|
||||
f.name, [getattr(d, f.name) for d in batch]
|
||||
)
|
||||
collated[f.name] = (
|
||||
cls.collate(list_values)
|
||||
if all(list_value is not None for list_value in list_values)
|
||||
else None
|
||||
)
|
||||
return cls(**collated)
|
||||
|
||||
elif isinstance(elem, Pointclouds):
|
||||
return join_pointclouds_as_batch(batch)
|
||||
|
||||
elif isinstance(elem, CamerasBase):
|
||||
# TODO: don't store K; enforce working in NDC space
|
||||
return join_cameras_as_batch(batch)
|
||||
else:
|
||||
return torch.utils.data._utils.collate.default_collate(batch)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
"""
|
||||
Base class to describe a dataset to be used with Implicitron.
|
||||
|
||||
The dataset is made up of frames, and the frames are grouped into sequences.
|
||||
Each sequence has a name (a string).
|
||||
(A sequence could be a video, or a set of images of one scene.)
|
||||
|
||||
This means they have a __getitem__ which returns an instance of a FrameData,
|
||||
which will describe one frame in one sequence.
|
||||
"""
|
||||
|
||||
# Maps sequence name to the sequence's global frame indices.
|
||||
# It is used for the default implementations of some functions in this class.
|
||||
# Implementations which override them are free to ignore this member.
|
||||
_seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_frame_numbers_and_timestamps(
|
||||
self, idxs: Sequence[int]
|
||||
) -> List[Tuple[int, float]]:
|
||||
"""
|
||||
If the sequences in the dataset are videos rather than
|
||||
unordered views, then the dataset should override this method to
|
||||
return the index and timestamp in their videos of the frames whose
|
||||
indices are given in `idxs`. In addition,
|
||||
the values in _seq_to_idx should be in ascending order.
|
||||
If timestamps are absent, they should be replaced with a constant.
|
||||
|
||||
This is used for letting SceneBatchSampler identify consecutive
|
||||
frames.
|
||||
|
||||
Args:
|
||||
idx: frame index in self
|
||||
|
||||
Returns:
|
||||
tuple of
|
||||
- frame index in video
|
||||
- timestamp of frame in video
|
||||
"""
|
||||
raise ValueError("This dataset does not contain videos.")
|
||||
|
||||
def get_eval_batches(self) -> Optional[List[List[int]]]:
|
||||
return None
|
||||
|
||||
def sequence_names(self) -> Iterable[str]:
|
||||
"""Returns an iterator over sequence names in the dataset."""
|
||||
return self._seq_to_idx.keys()
|
||||
|
||||
def sequence_frames_in_order(
|
||||
self, seq_name: str
|
||||
) -> Iterator[Tuple[float, int, int]]:
|
||||
"""Returns an iterator over the frame indices in a given sequence.
|
||||
We attempt to first sort by timestamp (if they are available),
|
||||
then by frame number.
|
||||
|
||||
Args:
|
||||
seq_name: the name of the sequence.
|
||||
|
||||
Returns:
|
||||
an iterator over triplets `(timestamp, frame_no, dataset_idx)`,
|
||||
where `frame_no` is the index within the sequence, and
|
||||
`dataset_idx` is the index within the dataset.
|
||||
`None` timestamps are replaced with 0s.
|
||||
"""
|
||||
seq_frame_indices = self._seq_to_idx[seq_name]
|
||||
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
||||
|
||||
yield from sorted(
|
||||
[
|
||||
(timestamp, frame_no, idx)
|
||||
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
|
||||
]
|
||||
)
|
||||
|
||||
def sequence_indices_in_order(self, seq_name: str) -> Iterator[int]:
|
||||
"""Same as `sequence_frames_in_order` but returns the iterator over
|
||||
only dataset indices.
|
||||
"""
|
||||
for _, _, idx in self.sequence_frames_in_order(seq_name):
|
||||
yield idx
|
||||
|
||||
|
||||
class FrameAnnotsEntry(TypedDict):
|
||||
subset: Optional[str]
|
||||
frame_annotation: types.FrameAnnotation
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
@registry.register
|
||||
class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
"""
|
||||
A class for the Common Objects in 3D (CO3D) dataset.
|
||||
A dataset with annotations in json files like the Common Objects in 3D
|
||||
(CO3D) dataset.
|
||||
|
||||
Args:
|
||||
frame_annotations_file: A zipped json file containing metadata of the
|
||||
@@ -361,16 +118,16 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
Type[types.FrameAnnotation]
|
||||
] = types.FrameAnnotation
|
||||
|
||||
path_manager: Optional[PathManager] = None
|
||||
path_manager: Any = None
|
||||
frame_annotations_file: str = ""
|
||||
sequence_annotations_file: str = ""
|
||||
subset_lists_file: str = ""
|
||||
subsets: Optional[List[str]] = None
|
||||
limit_to: int = 0
|
||||
limit_sequences_to: int = 0
|
||||
pick_sequence: Sequence[str] = ()
|
||||
exclude_sequence: Sequence[str] = ()
|
||||
limit_category_to: Sequence[int] = ()
|
||||
pick_sequence: Tuple[str, ...] = ()
|
||||
exclude_sequence: Tuple[str, ...] = ()
|
||||
limit_category_to: Tuple[int, ...] = ()
|
||||
dataset_root: str = ""
|
||||
load_images: bool = True
|
||||
load_depths: bool = True
|
||||
@@ -380,21 +137,21 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
max_points: int = 0
|
||||
mask_images: bool = False
|
||||
mask_depths: bool = False
|
||||
image_height: Optional[int] = 256
|
||||
image_width: Optional[int] = 256
|
||||
box_crop: bool = False
|
||||
image_height: Optional[int] = 800
|
||||
image_width: Optional[int] = 800
|
||||
box_crop: bool = True
|
||||
box_crop_mask_thr: float = 0.4
|
||||
box_crop_context: float = 1.0
|
||||
remove_empty_masks: bool = False
|
||||
box_crop_context: float = 0.3
|
||||
remove_empty_masks: bool = True
|
||||
n_frames_per_sequence: int = -1
|
||||
seed: int = 0
|
||||
sort_frames: bool = False
|
||||
eval_batches: Optional[List[List[int]]] = None
|
||||
frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
||||
seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
||||
eval_batches: Any = None
|
||||
# frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
||||
# seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# pyre-fixme[16]: `ImplicitronDataset` has no attribute `subset_to_image_path`.
|
||||
# pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
|
||||
self.subset_to_image_path = None
|
||||
self._load_frames()
|
||||
self._load_sequences()
|
||||
@@ -404,54 +161,174 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
self._filter_db() # also computes sequence indices
|
||||
logger.info(str(self))
|
||||
|
||||
def is_filtered(self):
|
||||
"""
|
||||
Returns `True` in case the dataset has been filtered and thus some frame annotations
|
||||
stored on the disk might be missing in the dataset object.
|
||||
|
||||
Returns:
|
||||
is_filtered: `True` if the dataset has been filtered, else `False`.
|
||||
"""
|
||||
return (
|
||||
self.remove_empty_masks
|
||||
or self.limit_to > 0
|
||||
or self.limit_sequences_to > 0
|
||||
or len(self.pick_sequence) > 0
|
||||
or len(self.exclude_sequence) > 0
|
||||
or len(self.limit_category_to) > 0
|
||||
or self.n_frames_per_sequence > 0
|
||||
)
|
||||
|
||||
def seq_frame_index_to_dataset_index(
|
||||
self,
|
||||
seq_frame_index: Union[
|
||||
List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
|
||||
],
|
||||
) -> List[List[int]]:
|
||||
seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
|
||||
allow_missing_indices: bool = False,
|
||||
remove_missing_indices: bool = False,
|
||||
) -> List[List[Union[Optional[int], int]]]:
|
||||
"""
|
||||
Obtain indices into the dataset object given a list of frames specified as
|
||||
`seq_frame_index = List[List[Tuple[sequence_name:str, frame_number:int]]]`.
|
||||
Obtain indices into the dataset object given a list of frame ids.
|
||||
|
||||
Args:
|
||||
seq_frame_index: The list of frame ids specified as
|
||||
`List[List[Tuple[sequence_name:str, frame_number:int]]]`. Optionally,
|
||||
Image paths relative to the dataset_root can be stored specified as well:
|
||||
`List[List[Tuple[sequence_name:str, frame_number:int, image_path:str]]]`
|
||||
allow_missing_indices: If `False`, throws an IndexError upon reaching the first
|
||||
entry from `seq_frame_index` which is missing in the dataset.
|
||||
Otherwise, depending on `remove_missing_indices`, either returns `None`
|
||||
in place of missing entries or removes the indices of missing entries.
|
||||
remove_missing_indices: Active when `allow_missing_indices=True`.
|
||||
If `False`, returns `None` in place of `seq_frame_index` entries that
|
||||
are not present in the dataset.
|
||||
If `True` removes missing indices from the returned indices.
|
||||
|
||||
Returns:
|
||||
dataset_idx: Indices of dataset entries corresponding to`seq_frame_index`.
|
||||
"""
|
||||
# TODO: check the frame numbers are unique
|
||||
_dataset_seq_frame_n_index = {
|
||||
seq: {
|
||||
# pyre-ignore[16]
|
||||
self.frame_annots[idx]["frame_annotation"].frame_number: idx
|
||||
for idx in seq_idx
|
||||
}
|
||||
# pyre-ignore[16]
|
||||
for seq, seq_idx in self._seq_to_idx.items()
|
||||
}
|
||||
|
||||
def _get_batch_idx(seq_name, frame_no, path=None) -> int:
|
||||
idx = _dataset_seq_frame_n_index[seq_name][frame_no]
|
||||
def _get_dataset_idx(
|
||||
seq_name: str, frame_no: int, path: Optional[str] = None
|
||||
) -> Optional[int]:
|
||||
idx_seq = _dataset_seq_frame_n_index.get(seq_name, None)
|
||||
idx = idx_seq.get(frame_no, None) if idx_seq is not None else None
|
||||
if idx is None:
|
||||
msg = (
|
||||
f"sequence_name={seq_name} / frame_number={frame_no}"
|
||||
" not in the dataset!"
|
||||
)
|
||||
if not allow_missing_indices:
|
||||
raise IndexError(msg)
|
||||
warnings.warn(msg)
|
||||
return idx
|
||||
if path is not None:
|
||||
# Check that the loaded frame path is consistent
|
||||
# with the one stored in self.frame_annots.
|
||||
assert os.path.normpath(
|
||||
# pyre-ignore[16]
|
||||
self.frame_annots[idx]["frame_annotation"].image.path
|
||||
) == os.path.normpath(
|
||||
path
|
||||
), f"Inconsistent batch {seq_name, frame_no, path}."
|
||||
), f"Inconsistent frame indices {seq_name, frame_no, path}."
|
||||
return idx
|
||||
|
||||
batches_idx = [[_get_batch_idx(*b) for b in batch] for batch in seq_frame_index]
|
||||
return batches_idx
|
||||
dataset_idx = [
|
||||
[_get_dataset_idx(*b) for b in batch] # pyre-ignore [6]
|
||||
for batch in seq_frame_index
|
||||
]
|
||||
|
||||
if allow_missing_indices and remove_missing_indices:
|
||||
# remove all None indices, and also batches with only None entries
|
||||
valid_dataset_idx = [
|
||||
[b for b in batch if b is not None] for batch in dataset_idx
|
||||
]
|
||||
return [ # pyre-ignore[7]
|
||||
batch for batch in valid_dataset_idx if len(batch) > 0
|
||||
]
|
||||
|
||||
return dataset_idx
|
||||
|
||||
def subset_from_frame_index(
|
||||
self,
|
||||
frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
|
||||
allow_missing_indices: bool = True,
|
||||
) -> "JsonIndexDataset":
|
||||
# Get the indices into the frame annots.
|
||||
dataset_indices = self.seq_frame_index_to_dataset_index(
|
||||
[frame_index],
|
||||
allow_missing_indices=self.is_filtered() and allow_missing_indices,
|
||||
)[0]
|
||||
valid_dataset_indices = [i for i in dataset_indices if i is not None]
|
||||
|
||||
# Deep copy the whole dataset except frame_annots, which are large so we
|
||||
# deep copy only the requested subset of frame_annots.
|
||||
memo = {id(self.frame_annots): None} # pyre-ignore[16]
|
||||
dataset_new = copy.deepcopy(self, memo)
|
||||
dataset_new.frame_annots = copy.deepcopy(
|
||||
[self.frame_annots[i] for i in valid_dataset_indices]
|
||||
)
|
||||
|
||||
# This will kill all unneeded sequence annotations.
|
||||
dataset_new._invalidate_indexes(filter_seq_annots=True)
|
||||
|
||||
# Finally annotate the frame annotations with the name of the subset
|
||||
# stored in meta.
|
||||
for frame_annot in dataset_new.frame_annots:
|
||||
frame_annotation = frame_annot["frame_annotation"]
|
||||
if frame_annotation.meta is not None:
|
||||
frame_annot["subset"] = frame_annotation.meta.get("frame_type", None)
|
||||
|
||||
# A sanity check - this will crash in case some entries from frame_index are missing
|
||||
# in dataset_new.
|
||||
valid_frame_index = [
|
||||
fi for fi, di in zip(frame_index, dataset_indices) if di is not None
|
||||
]
|
||||
dataset_new.seq_frame_index_to_dataset_index(
|
||||
[valid_frame_index], allow_missing_indices=False
|
||||
)
|
||||
|
||||
return dataset_new
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ImplicitronDataset #frames={len(self.frame_annots)}"
|
||||
# pyre-ignore[16]
|
||||
return f"JsonIndexDataset #frames={len(self.frame_annots)}"
|
||||
|
||||
def __len__(self) -> int:
|
||||
# pyre-ignore[16]
|
||||
return len(self.frame_annots)
|
||||
|
||||
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
|
||||
return entry["subset"]
|
||||
|
||||
def get_all_train_cameras(self) -> CamerasBase:
|
||||
"""
|
||||
Returns the cameras corresponding to all the known frames.
|
||||
"""
|
||||
cameras = []
|
||||
# pyre-ignore[16]
|
||||
for frame_idx, frame_annot in enumerate(self.frame_annots):
|
||||
frame_type = self._get_frame_type(frame_annot)
|
||||
if frame_type is None:
|
||||
raise ValueError("subsets not loaded")
|
||||
if is_known_frame_scalar(frame_type):
|
||||
cameras.append(self[frame_idx].camera)
|
||||
return join_cameras_as_batch(cameras)
|
||||
|
||||
def __getitem__(self, index) -> FrameData:
|
||||
# pyre-ignore[16]
|
||||
if index >= len(self.frame_annots):
|
||||
raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
|
||||
|
||||
entry = self.frame_annots[index]["frame_annotation"]
|
||||
# pyre-ignore[16]
|
||||
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
|
||||
frame_data = FrameData(
|
||||
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
|
||||
@@ -477,6 +354,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
frame_data.mask_path,
|
||||
frame_data.bbox_xywh,
|
||||
clamp_bbox_xyxy,
|
||||
frame_data.crop_bbox_xywh,
|
||||
) = self._load_crop_fg_probability(entry)
|
||||
|
||||
scale = 1.0
|
||||
@@ -524,13 +402,14 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
Optional[str],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
]:
|
||||
fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy = (
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
fg_probability = None
|
||||
full_path = None
|
||||
bbox_xywh = None
|
||||
clamp_bbox_xyxy = None
|
||||
crop_box_xywh = None
|
||||
|
||||
if (self.load_masks or self.box_crop) and entry.mask is not None:
|
||||
full_path = os.path.join(self.dataset_root, entry.mask.path)
|
||||
mask = _load_mask(self._local_path(full_path))
|
||||
@@ -543,11 +422,21 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
|
||||
|
||||
if self.box_crop:
|
||||
clamp_bbox_xyxy = _get_clamp_bbox(bbox_xywh, self.box_crop_context)
|
||||
clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
|
||||
_get_clamp_bbox(
|
||||
bbox_xywh,
|
||||
image_path=entry.image.path,
|
||||
box_crop_context=self.box_crop_context,
|
||||
),
|
||||
image_size_hw=tuple(mask.shape[-2:]),
|
||||
)
|
||||
crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
|
||||
|
||||
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
|
||||
|
||||
fg_probability, _, _ = self._resize_image(mask, mode="nearest")
|
||||
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy
|
||||
|
||||
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy, crop_box_xywh
|
||||
|
||||
def _load_crop_images(
|
||||
self,
|
||||
@@ -686,6 +575,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
)
|
||||
if not frame_annots_list:
|
||||
raise ValueError("Empty dataset!")
|
||||
# pyre-ignore[16]
|
||||
self.frame_annots = [
|
||||
FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
|
||||
]
|
||||
@@ -697,6 +587,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
|
||||
if not seq_annots:
|
||||
raise ValueError("Empty sequences file!")
|
||||
# pyre-ignore[16]
|
||||
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
|
||||
|
||||
def _load_subset_lists(self) -> None:
|
||||
@@ -712,7 +603,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
for subset, frames in subset_to_seq_frame.items()
|
||||
for _, _, path in frames
|
||||
}
|
||||
|
||||
# pyre-ignore[16]
|
||||
for frame in self.frame_annots:
|
||||
frame["subset"] = frame_path_to_subset.get(
|
||||
frame["frame_annotation"].image.path, None
|
||||
@@ -725,6 +616,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
|
||||
def _sort_frames(self) -> None:
|
||||
# Sort frames to have them grouped by sequence, ordered by timestamp
|
||||
# pyre-ignore[16]
|
||||
self.frame_annots = sorted(
|
||||
self.frame_annots,
|
||||
key=lambda f: (
|
||||
@@ -736,6 +628,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
def _filter_db(self) -> None:
|
||||
if self.remove_empty_masks:
|
||||
logger.info("Removing images with empty masks.")
|
||||
# pyre-ignore[16]
|
||||
old_len = len(self.frame_annots)
|
||||
|
||||
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
|
||||
@@ -776,6 +669,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
|
||||
if len(self.limit_category_to) > 0:
|
||||
logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
|
||||
# pyre-ignore[16]
|
||||
self.seq_annots = {
|
||||
name: entry
|
||||
for name, entry in self.seq_annots.items()
|
||||
@@ -813,6 +707,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
if self.n_frames_per_sequence > 0:
|
||||
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
|
||||
keep_idx = []
|
||||
# pyre-ignore[16]
|
||||
for seq, seq_indices in self._seq_to_idx.items():
|
||||
# infer the seed from the sequence name, this is reproducible
|
||||
# and makes the selection differ for different sequences
|
||||
@@ -842,14 +737,20 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
self._invalidate_seq_to_idx()
|
||||
|
||||
if filter_seq_annots:
|
||||
# pyre-ignore[16]
|
||||
self.seq_annots = {
|
||||
k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx
|
||||
k: v
|
||||
for k, v in self.seq_annots.items()
|
||||
# pyre-ignore[16]
|
||||
if k in self._seq_to_idx
|
||||
}
|
||||
|
||||
def _invalidate_seq_to_idx(self) -> None:
|
||||
seq_to_idx = defaultdict(list)
|
||||
# pyre-ignore[16]
|
||||
for idx, entry in enumerate(self.frame_annots):
|
||||
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
|
||||
# pyre-ignore[16]
|
||||
self._seq_to_idx = seq_to_idx
|
||||
|
||||
def _resize_image(
|
||||
@@ -867,16 +768,18 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
)
|
||||
imre = torch.nn.functional.interpolate(
|
||||
torch.from_numpy(image)[None],
|
||||
# pyre-ignore[6]
|
||||
scale_factor=minscale,
|
||||
mode=mode,
|
||||
align_corners=False if mode == "bilinear" else None,
|
||||
recompute_scale_factor=True,
|
||||
)[0]
|
||||
# pyre-fixme[19]: Expected 1 positional argument.
|
||||
imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
|
||||
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
|
||||
# pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
|
||||
# pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
|
||||
mask = torch.zeros(1, self.image_height, self.image_width)
|
||||
mask[:, 0 : imre.shape[1] - 1, 0 : imre.shape[2] - 1] = 1.0
|
||||
mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0
|
||||
return imre_, minscale, mask
|
||||
|
||||
def _local_path(self, path: str) -> str:
|
||||
@@ -889,6 +792,7 @@ class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
) -> List[Tuple[int, float]]:
|
||||
out: List[Tuple[int, float]] = []
|
||||
for idx in idxs:
|
||||
# pyre-ignore[16]
|
||||
frame_annotation = self.frame_annots[idx]["frame_annotation"]
|
||||
out.append(
|
||||
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
|
||||
@@ -929,7 +833,7 @@ def _load_1bit_png_mask(file: str) -> np.ndarray:
|
||||
return mask
|
||||
|
||||
|
||||
def _load_depth_mask(path) -> np.ndarray:
|
||||
def _load_depth_mask(path: str) -> np.ndarray:
|
||||
if not path.lower().endswith(".png"):
|
||||
raise ValueError('unsupported depth mask file name "%s"' % path)
|
||||
m = _load_1bit_png_mask(path)
|
||||
@@ -954,7 +858,7 @@ def _load_mask(path) -> np.ndarray:
|
||||
|
||||
def _get_1d_bounds(arr) -> Tuple[int, int]:
|
||||
nz = np.flatnonzero(arr)
|
||||
return nz[0], nz[-1]
|
||||
return nz[0], nz[-1] + 1
|
||||
|
||||
|
||||
def _get_bbox_from_mask(
|
||||
@@ -975,11 +879,15 @@ def _get_bbox_from_mask(
|
||||
|
||||
|
||||
def _get_clamp_bbox(
|
||||
bbox: torch.Tensor, box_crop_context: float = 0.0, impath: str = ""
|
||||
bbox: torch.Tensor,
|
||||
box_crop_context: float = 0.0,
|
||||
image_path: str = "",
|
||||
) -> torch.Tensor:
|
||||
# box_crop_context: rate of expansion for bbox
|
||||
# returns possibly expanded bbox xyxy as float
|
||||
|
||||
bbox = bbox.clone() # do not edit bbox in place
|
||||
|
||||
# increase box size
|
||||
if box_crop_context > 0.0:
|
||||
c = box_crop_context
|
||||
@@ -991,27 +899,38 @@ def _get_clamp_bbox(
|
||||
|
||||
if (bbox[2:] <= 1.0).any():
|
||||
raise ValueError(
|
||||
f"squashed image {impath}!! The bounding box contains no pixels."
|
||||
f"squashed image {image_path}!! The bounding box contains no pixels."
|
||||
)
|
||||
|
||||
bbox[2:] = torch.clamp(bbox[2:], 2)
|
||||
bbox[2:] += bbox[0:2] + 1 # convert to [xmin, ymin, xmax, ymax]
|
||||
# +1 because upper bound is not inclusive
|
||||
bbox[2:] = torch.clamp(bbox[2:], 2) # set min height, width to 2 along both axes
|
||||
bbox_xyxy = _bbox_xywh_to_xyxy(bbox, clamp_size=2)
|
||||
|
||||
return bbox
|
||||
return bbox_xyxy
|
||||
|
||||
|
||||
def _crop_around_box(tensor, bbox, impath: str = ""):
|
||||
# bbox is xyxy, where the upper bound is corrected with +1
|
||||
bbox[[0, 2]] = torch.clamp(bbox[[0, 2]], 0.0, tensor.shape[-1])
|
||||
bbox[[1, 3]] = torch.clamp(bbox[[1, 3]], 0.0, tensor.shape[-2])
|
||||
bbox = bbox.round().long()
|
||||
bbox = _clamp_box_to_image_bounds_and_round(
|
||||
bbox,
|
||||
image_size_hw=tensor.shape[-2:],
|
||||
)
|
||||
tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
|
||||
assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def _clamp_box_to_image_bounds_and_round(
|
||||
bbox_xyxy: torch.Tensor,
|
||||
image_size_hw: Tuple[int, int],
|
||||
) -> torch.LongTensor:
|
||||
bbox_xyxy = bbox_xyxy.clone()
|
||||
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
|
||||
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
|
||||
if not isinstance(bbox_xyxy, torch.LongTensor):
|
||||
bbox_xyxy = bbox_xyxy.round().long()
|
||||
return bbox_xyxy # pyre-ignore [7]
|
||||
|
||||
|
||||
def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
|
||||
assert bbox is not None
|
||||
assert np.prod(orig_res) > 1e-8
|
||||
@@ -1020,6 +939,22 @@ def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
|
||||
return bbox * rel_size
|
||||
|
||||
|
||||
def _bbox_xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
|
||||
wh = xyxy[2:] - xyxy[:2]
|
||||
xywh = torch.cat([xyxy[:2], wh])
|
||||
return xywh
|
||||
|
||||
|
||||
def _bbox_xywh_to_xyxy(
|
||||
xywh: torch.Tensor, clamp_size: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
xyxy = xywh.clone()
|
||||
if clamp_size is not None:
|
||||
xyxy[2:] = torch.clamp(xyxy[2:], clamp_size)
|
||||
xyxy[2:] += xyxy[:2]
|
||||
return xyxy
|
||||
|
||||
|
||||
def _safe_as_tensor(data, dtype):
|
||||
if data is None:
|
||||
return None
|
||||
326
pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py
Normal file
326
pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py
Normal file
@@ -0,0 +1,326 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
expand_args_fields,
|
||||
registry,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .dataset_map_provider import (
|
||||
DatasetMap,
|
||||
DatasetMapProviderBase,
|
||||
PathManagerFactory,
|
||||
Task,
|
||||
)
|
||||
from .json_index_dataset import JsonIndexDataset
|
||||
|
||||
from .utils import (
|
||||
DATASET_TYPE_KNOWN,
|
||||
DATASET_TYPE_TEST,
|
||||
DATASET_TYPE_TRAIN,
|
||||
DATASET_TYPE_UNKNOWN,
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
CO3D_CATEGORIES: List[str] = list(reversed([
|
||||
"baseballbat", "banana", "bicycle", "microwave", "tv",
|
||||
"cellphone", "toilet", "hairdryer", "couch", "kite", "pizza",
|
||||
"umbrella", "wineglass", "laptop",
|
||||
"hotdog", "stopsign", "frisbee", "baseballglove",
|
||||
"cup", "parkingmeter", "backpack", "toyplane", "toybus",
|
||||
"handbag", "chair", "keyboard", "car", "motorcycle",
|
||||
"carrot", "bottle", "sandwich", "remote", "bowl", "skateboard",
|
||||
"toaster", "mouse", "toytrain", "book", "toytruck",
|
||||
"orange", "broccoli", "plant", "teddybear",
|
||||
"suitcase", "bench", "ball", "cake",
|
||||
"vase", "hydrant", "apple", "donut",
|
||||
]))
|
||||
# fmt: on
|
||||
|
||||
_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
|
||||
|
||||
# _NEED_CONTROL is a list of those elements of JsonIndexDataset which
|
||||
# are not directly specified for it in the config but come from the
|
||||
# DatasetMapProvider.
|
||||
_NEED_CONTROL: Tuple[str, ...] = (
|
||||
"dataset_root",
|
||||
"eval_batches",
|
||||
"n_frames_per_sequence",
|
||||
"path_manager",
|
||||
"pick_sequence",
|
||||
"subsets",
|
||||
"frame_annotations_file",
|
||||
"sequence_annotations_file",
|
||||
"subset_lists_file",
|
||||
)
|
||||
|
||||
|
||||
@registry.register
|
||||
class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
"""
|
||||
Generates the training / validation and testing dataset objects for
|
||||
a dataset laid out on disk like Co3D, with annotations in json files.
|
||||
|
||||
Args:
|
||||
category: The object category of the dataset.
|
||||
task_str: "multisequence" or "singlesequence".
|
||||
dataset_root: The root folder of the dataset.
|
||||
n_frames_per_sequence: Randomly sample #n_frames_per_sequence frames
|
||||
in each sequence.
|
||||
test_on_train: Construct validation and test datasets from
|
||||
the training subset.
|
||||
restrict_sequence_name: Restrict the dataset sequences to the ones
|
||||
present in the given list of names.
|
||||
test_restrict_sequence_id: The ID of the loaded sequence.
|
||||
Active for task_str='singlesequence'.
|
||||
assert_single_seq: Assert that only frames from a single sequence
|
||||
are present in all generated datasets.
|
||||
only_test_set: Load only the test set.
|
||||
dataset_class_type: name of class (JsonIndexDataset or a subclass)
|
||||
to use for the dataset.
|
||||
dataset_X_args (e.g. dataset_JsonIndexDataset_args): arguments passed
|
||||
to all the dataset constructors.
|
||||
path_manager_factory: (Optional) An object that generates an instance of
|
||||
PathManager that can translate provided file paths.
|
||||
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
||||
"""
|
||||
|
||||
category: str
|
||||
task_str: str = "singlesequence"
|
||||
dataset_root: str = _CO3D_DATASET_ROOT
|
||||
n_frames_per_sequence: int = -1
|
||||
test_on_train: bool = False
|
||||
restrict_sequence_name: Tuple[str, ...] = ()
|
||||
test_restrict_sequence_id: int = -1
|
||||
assert_single_seq: bool = False
|
||||
only_test_set: bool = False
|
||||
dataset: JsonIndexDataset
|
||||
dataset_class_type: str = "JsonIndexDataset"
|
||||
path_manager_factory: PathManagerFactory
|
||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||
|
||||
@classmethod
|
||||
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
|
||||
"""
|
||||
Called by get_default_args(JsonIndexDatasetMapProvider) to
|
||||
not expose certain fields of each dataset class.
|
||||
"""
|
||||
with open_dict(args):
|
||||
for key in _NEED_CONTROL:
|
||||
del args[key]
|
||||
|
||||
def create_dataset(self):
|
||||
"""
|
||||
Prevent the member named dataset from being created.
|
||||
"""
|
||||
return
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
run_auto_creation(self)
|
||||
if self.only_test_set and self.test_on_train:
|
||||
raise ValueError("Cannot have only_test_set and test_on_train")
|
||||
|
||||
path_manager = self.path_manager_factory.get()
|
||||
|
||||
# TODO:
|
||||
# - implement loading multiple categories
|
||||
|
||||
frame_file = os.path.join(
|
||||
self.dataset_root, self.category, "frame_annotations.jgz"
|
||||
)
|
||||
sequence_file = os.path.join(
|
||||
self.dataset_root, self.category, "sequence_annotations.jgz"
|
||||
)
|
||||
subset_lists_file = os.path.join(
|
||||
self.dataset_root, self.category, "set_lists.json"
|
||||
)
|
||||
common_kwargs = {
|
||||
"dataset_root": self.dataset_root,
|
||||
"path_manager": path_manager,
|
||||
"frame_annotations_file": frame_file,
|
||||
"sequence_annotations_file": sequence_file,
|
||||
"subset_lists_file": subset_lists_file,
|
||||
**getattr(self, f"dataset_{self.dataset_class_type}_args"),
|
||||
}
|
||||
|
||||
# This maps the common names of the dataset subsets ("train"/"val"/"test")
|
||||
# to the names of the subsets in the CO3D dataset.
|
||||
set_names_mapping = _get_co3d_set_names_mapping(
|
||||
self.get_task(),
|
||||
self.test_on_train,
|
||||
self.only_test_set,
|
||||
)
|
||||
|
||||
# load the evaluation batches
|
||||
batch_indices_path = os.path.join(
|
||||
self.dataset_root,
|
||||
self.category,
|
||||
f"eval_batches_{self.task_str}.json",
|
||||
)
|
||||
if path_manager is not None:
|
||||
batch_indices_path = path_manager.get_local_path(batch_indices_path)
|
||||
if not os.path.isfile(batch_indices_path):
|
||||
# The batch indices file does not exist.
|
||||
# Most probably the user has not specified the root folder.
|
||||
raise ValueError(
|
||||
f"Looking for batch indices in {batch_indices_path}. "
|
||||
+ "Please specify a correct dataset_root folder."
|
||||
)
|
||||
|
||||
with open(batch_indices_path, "r") as f:
|
||||
eval_batch_index = json.load(f)
|
||||
restrict_sequence_name = self.restrict_sequence_name
|
||||
|
||||
if self.get_task() == Task.SINGLE_SEQUENCE:
|
||||
if (
|
||||
self.test_restrict_sequence_id is None
|
||||
or self.test_restrict_sequence_id < 0
|
||||
):
|
||||
raise ValueError(
|
||||
"Please specify an integer id 'test_restrict_sequence_id'"
|
||||
+ " of the sequence considered for 'singlesequence'"
|
||||
+ " training and evaluation."
|
||||
)
|
||||
if len(self.restrict_sequence_name) > 0:
|
||||
raise ValueError(
|
||||
"For the 'singlesequence' task, the restrict_sequence_name has"
|
||||
" to be unset while test_restrict_sequence_id has to be set to an"
|
||||
" integer defining the order of the evaluation sequence."
|
||||
)
|
||||
# a sort-stable set() equivalent:
|
||||
eval_batches_sequence_names = list(
|
||||
{b[0][0]: None for b in eval_batch_index}.keys()
|
||||
)
|
||||
eval_sequence_name = eval_batches_sequence_names[
|
||||
self.test_restrict_sequence_id
|
||||
]
|
||||
eval_batch_index = [
|
||||
b for b in eval_batch_index if b[0][0] == eval_sequence_name
|
||||
]
|
||||
# overwrite the restrict_sequence_name
|
||||
restrict_sequence_name = [eval_sequence_name]
|
||||
|
||||
dataset_type: Type[JsonIndexDataset] = registry.get(
|
||||
JsonIndexDataset, self.dataset_class_type
|
||||
)
|
||||
expand_args_fields(dataset_type)
|
||||
train_dataset = None
|
||||
if not self.only_test_set:
|
||||
train_dataset = dataset_type(
|
||||
n_frames_per_sequence=self.n_frames_per_sequence,
|
||||
subsets=set_names_mapping["train"],
|
||||
pick_sequence=restrict_sequence_name,
|
||||
**common_kwargs,
|
||||
)
|
||||
if self.test_on_train:
|
||||
assert train_dataset is not None
|
||||
val_dataset = test_dataset = train_dataset
|
||||
else:
|
||||
val_dataset = dataset_type(
|
||||
n_frames_per_sequence=-1,
|
||||
subsets=set_names_mapping["val"],
|
||||
pick_sequence=restrict_sequence_name,
|
||||
**common_kwargs,
|
||||
)
|
||||
test_dataset = dataset_type(
|
||||
n_frames_per_sequence=-1,
|
||||
subsets=set_names_mapping["test"],
|
||||
pick_sequence=restrict_sequence_name,
|
||||
**common_kwargs,
|
||||
)
|
||||
if len(restrict_sequence_name) > 0:
|
||||
eval_batch_index = [
|
||||
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
|
||||
]
|
||||
test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
|
||||
eval_batch_index
|
||||
)
|
||||
dataset_map = DatasetMap(
|
||||
train=train_dataset, val=val_dataset, test=test_dataset
|
||||
)
|
||||
|
||||
if self.assert_single_seq:
|
||||
# check there's only one sequence in all datasets
|
||||
sequence_names = {
|
||||
sequence_name
|
||||
for dset in dataset_map.iter_datasets()
|
||||
for sequence_name in dset.sequence_names()
|
||||
}
|
||||
if len(sequence_names) > 1:
|
||||
raise ValueError("Multiple sequences loaded but expected one")
|
||||
|
||||
self.dataset_map = dataset_map
|
||||
|
||||
def get_dataset_map(self) -> DatasetMap:
|
||||
# pyre-ignore[16]
|
||||
return self.dataset_map
|
||||
|
||||
def get_task(self) -> Task:
|
||||
return Task(self.task_str)
|
||||
|
||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||
if Task(self.task_str) == Task.MULTI_SEQUENCE:
|
||||
return None
|
||||
|
||||
# pyre-ignore[16]
|
||||
train_dataset = self.dataset_map.train
|
||||
assert isinstance(train_dataset, JsonIndexDataset)
|
||||
return train_dataset.get_all_train_cameras()
|
||||
|
||||
|
||||
def _get_co3d_set_names_mapping(
|
||||
task: Task,
|
||||
test_on_train: bool,
|
||||
only_test: bool,
|
||||
) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Returns the mapping of the common dataset subset names ("train"/"val"/"test")
|
||||
to the names of the corresponding subsets in the CO3D dataset
|
||||
("test_known"/"test_unseen"/"train_known"/"train_unseen").
|
||||
|
||||
The keys returned will be
|
||||
- train (if not only_test)
|
||||
- val (if not test_on_train)
|
||||
- test (if not test_on_train)
|
||||
"""
|
||||
single_seq = task == Task.SINGLE_SEQUENCE
|
||||
|
||||
if only_test:
|
||||
set_names_mapping = {}
|
||||
else:
|
||||
set_names_mapping = {
|
||||
"train": [
|
||||
(DATASET_TYPE_TEST if single_seq else DATASET_TYPE_TRAIN)
|
||||
+ "_"
|
||||
+ DATASET_TYPE_KNOWN
|
||||
]
|
||||
}
|
||||
if not test_on_train:
|
||||
prefixes = [DATASET_TYPE_TEST]
|
||||
if not single_seq:
|
||||
prefixes.append(DATASET_TYPE_TRAIN)
|
||||
set_names_mapping.update(
|
||||
{
|
||||
dset: [
|
||||
p + "_" + t
|
||||
for p in prefixes
|
||||
for t in [DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN]
|
||||
]
|
||||
for dset in ["val", "test"]
|
||||
}
|
||||
)
|
||||
|
||||
return set_names_mapping
|
||||
@@ -0,0 +1,358 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||
DatasetMap,
|
||||
DatasetMapProviderBase,
|
||||
PathManagerFactory,
|
||||
Task,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
expand_args_fields,
|
||||
registry,
|
||||
run_auto_creation,
|
||||
)
|
||||
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
|
||||
_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@registry.register
|
||||
class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
"""
|
||||
Generates the training, validation, and testing dataset objects for
|
||||
a dataset laid out on disk like CO3Dv2, with annotations in gzipped json files.
|
||||
|
||||
The dataset is organized in the filesystem as follows:
|
||||
```
|
||||
self.dataset_root
|
||||
├── <category_0>
|
||||
│ ├── <sequence_name_0>
|
||||
│ │ ├── depth_masks
|
||||
│ │ ├── depths
|
||||
│ │ ├── images
|
||||
│ │ ├── masks
|
||||
│ │ └── pointcloud.ply
|
||||
│ ├── <sequence_name_1>
|
||||
│ │ ├── depth_masks
|
||||
│ │ ├── depths
|
||||
│ │ ├── images
|
||||
│ │ ├── masks
|
||||
│ │ └── pointcloud.ply
|
||||
│ ├── ...
|
||||
│ ├── <sequence_name_N>
|
||||
│ ├── set_lists
|
||||
│ ├── set_lists_<subset_name_0>.json
|
||||
│ ├── set_lists_<subset_name_1>.json
|
||||
│ ├── ...
|
||||
│ ├── set_lists_<subset_name_M>.json
|
||||
│ ├── eval_batches
|
||||
│ │ ├── eval_batches_<subset_name_0>.json
|
||||
│ │ ├── eval_batches_<subset_name_1>.json
|
||||
│ │ ├── ...
|
||||
│ │ ├── eval_batches_<subset_name_M>.json
|
||||
│ ├── frame_annotations.jgz
|
||||
│ ├── sequence_annotations.jgz
|
||||
├── <category_1>
|
||||
├── ...
|
||||
├── <category_K>
|
||||
```
|
||||
|
||||
The dataset contains sequences named `<sequence_name_i>` from `K` categories with
|
||||
names `<category_j>`. Each category comprises sequence folders
|
||||
`<category_k>/<sequence_name_i>` containing the list of sequence images, depth maps,
|
||||
foreground masks, and valid-depth masks `images`, `depths`, `masks`, and `depth_masks`
|
||||
respectively. Furthermore, `<category_k>/<sequence_name_i>/set_lists/` stores `M`
|
||||
json files `set_lists_<subset_name_l>.json`, each describing a certain sequence subset.
|
||||
|
||||
Users specify the loaded dataset subset by setting `self.subset_name` to one of the
|
||||
available subset names `<subset_name_l>`.
|
||||
|
||||
`frame_annotations.jgz` and `sequence_annotations.jgz` are gzipped json files containing
|
||||
the list of all frames and sequences of the given category stored as lists of
|
||||
`FrameAnnotation` and `SequenceAnnotation` objects respectivelly.
|
||||
|
||||
Each `set_lists_<subset_name_l>.json` file contains the following dictionary:
|
||||
```
|
||||
{
|
||||
"train": [
|
||||
(sequence_name: str, frame_number: int, image_path: str),
|
||||
...
|
||||
],
|
||||
"val": [
|
||||
(sequence_name: str, frame_number: int, image_path: str),
|
||||
...
|
||||
],
|
||||
"test": [
|
||||
(sequence_name: str, frame_number: int, image_path: str),
|
||||
...
|
||||
],
|
||||
]
|
||||
```
|
||||
defining the list of frames (identified with their `sequence_name` and `frame_number`)
|
||||
in the "train", "val", and "test" subsets of the dataset.
|
||||
Note that `frame_number` can be obtained only from `frame_annotations.jgz` and
|
||||
does not necesarrily correspond to the numeric suffix of the corresponding image
|
||||
file name (e.g. a file `<category_0>/<sequence_name_0>/images/frame00005.jpg` can
|
||||
have its frame number set to `20`, not 5).
|
||||
|
||||
Each `eval_batches_<subset_name_l>.json` file contains a list of evaluation examples
|
||||
in the following form:
|
||||
```
|
||||
[
|
||||
[ # batch 1
|
||||
(sequence_name: str, frame_number: int, image_path: str),
|
||||
...
|
||||
],
|
||||
[ # batch 1
|
||||
(sequence_name: str, frame_number: int, image_path: str),
|
||||
...
|
||||
],
|
||||
]
|
||||
```
|
||||
Note that the evaluation examples always come from the `"test"` subset of the dataset.
|
||||
(test frames can repeat across batches).
|
||||
|
||||
Args:
|
||||
category: The object category of the dataset.
|
||||
subset_name: The name of the dataset subset. For CO3Dv2, these include
|
||||
e.g. "manyview_dev_0", "fewview_test", ...
|
||||
dataset_root: The root folder of the dataset.
|
||||
test_on_train: Construct validation and test datasets from
|
||||
the training subset.
|
||||
only_test_set: Load only the test set. Incompatible with `test_on_train`.
|
||||
load_eval_batches: Load the file containing eval batches pointing to the
|
||||
test dataset.
|
||||
dataset_args: Specifies additional arguments to the
|
||||
JsonIndexDataset constructor call.
|
||||
path_manager_factory: (Optional) An object that generates an instance of
|
||||
PathManager that can translate provided file paths.
|
||||
path_manager_factory_class_type: The class type of `path_manager_factory`.
|
||||
"""
|
||||
|
||||
category: str
|
||||
subset_name: str
|
||||
dataset_root: str = _CO3DV2_DATASET_ROOT
|
||||
|
||||
test_on_train: bool = False
|
||||
only_test_set: bool = False
|
||||
load_eval_batches: bool = True
|
||||
|
||||
dataset_class_type: str = "JsonIndexDataset"
|
||||
dataset: JsonIndexDataset
|
||||
|
||||
path_manager_factory: PathManagerFactory
|
||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
run_auto_creation(self)
|
||||
|
||||
if self.only_test_set and self.test_on_train:
|
||||
raise ValueError("Cannot have only_test_set and test_on_train")
|
||||
|
||||
frame_file = os.path.join(
|
||||
self.dataset_root, self.category, "frame_annotations.jgz"
|
||||
)
|
||||
sequence_file = os.path.join(
|
||||
self.dataset_root, self.category, "sequence_annotations.jgz"
|
||||
)
|
||||
|
||||
path_manager = self.path_manager_factory.get()
|
||||
|
||||
# setup the common dataset arguments
|
||||
common_dataset_kwargs = getattr(self, f"dataset_{self.dataset_class_type}_args")
|
||||
common_dataset_kwargs = {
|
||||
**common_dataset_kwargs,
|
||||
"dataset_root": self.dataset_root,
|
||||
"frame_annotations_file": frame_file,
|
||||
"sequence_annotations_file": sequence_file,
|
||||
"subsets": None,
|
||||
"subset_lists_file": "",
|
||||
"path_manager": path_manager,
|
||||
}
|
||||
|
||||
# get the used dataset type
|
||||
dataset_type: Type[JsonIndexDataset] = registry.get(
|
||||
JsonIndexDataset, self.dataset_class_type
|
||||
)
|
||||
expand_args_fields(dataset_type)
|
||||
|
||||
dataset = dataset_type(**common_dataset_kwargs)
|
||||
|
||||
available_subset_names = self._get_available_subset_names()
|
||||
logger.debug(f"Available subset names: {str(available_subset_names)}.")
|
||||
if self.subset_name not in available_subset_names:
|
||||
raise ValueError(
|
||||
f"Unknown subset name {self.subset_name}."
|
||||
+ f" Choose one of available subsets: {str(available_subset_names)}."
|
||||
)
|
||||
|
||||
# load the list of train/val/test frames
|
||||
subset_mapping = self._load_annotation_json(
|
||||
os.path.join(
|
||||
self.category, "set_lists", f"set_lists_{self.subset_name}.json"
|
||||
)
|
||||
)
|
||||
|
||||
# load the evaluation batches
|
||||
if self.load_eval_batches:
|
||||
eval_batch_index = self._load_annotation_json(
|
||||
os.path.join(
|
||||
self.category,
|
||||
"eval_batches",
|
||||
f"eval_batches_{self.subset_name}.json",
|
||||
)
|
||||
)
|
||||
|
||||
train_dataset = None
|
||||
if not self.only_test_set:
|
||||
# load the training set
|
||||
logger.debug("Extracting train dataset.")
|
||||
train_dataset = dataset.subset_from_frame_index(subset_mapping["train"])
|
||||
logger.info(f"Train dataset: {str(train_dataset)}")
|
||||
|
||||
if self.test_on_train:
|
||||
assert train_dataset is not None
|
||||
val_dataset = test_dataset = train_dataset
|
||||
else:
|
||||
# load the val and test sets
|
||||
logger.debug("Extracting val dataset.")
|
||||
val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
|
||||
logger.info(f"Val dataset: {str(val_dataset)}")
|
||||
logger.debug("Extracting test dataset.")
|
||||
test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
|
||||
logger.info(f"Test dataset: {str(test_dataset)}")
|
||||
if self.load_eval_batches:
|
||||
# load the eval batches
|
||||
logger.debug("Extracting eval batches.")
|
||||
try:
|
||||
test_dataset.eval_batches = (
|
||||
test_dataset.seq_frame_index_to_dataset_index(
|
||||
eval_batch_index,
|
||||
)
|
||||
)
|
||||
except IndexError:
|
||||
warnings.warn(
|
||||
"@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n"
|
||||
+ "Some eval batches are missing from the test dataset.\n"
|
||||
+ "The evaluation results will be incomparable to the\n"
|
||||
+ "evaluation results calculated on the original dataset.\n"
|
||||
+ "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
|
||||
)
|
||||
test_dataset.eval_batches = (
|
||||
test_dataset.seq_frame_index_to_dataset_index(
|
||||
eval_batch_index,
|
||||
allow_missing_indices=True,
|
||||
remove_missing_indices=True,
|
||||
)
|
||||
)
|
||||
logger.info(f"# eval batches: {len(test_dataset.eval_batches)}")
|
||||
|
||||
self.dataset_map = DatasetMap(
|
||||
train=train_dataset, val=val_dataset, test=test_dataset
|
||||
)
|
||||
|
||||
def create_dataset(self):
|
||||
# The dataset object is created inside `self.get_dataset_map`
|
||||
pass
|
||||
|
||||
def get_dataset_map(self) -> DatasetMap:
|
||||
return self.dataset_map # pyre-ignore [16]
|
||||
|
||||
def get_category_to_subset_name_list(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Returns a global dataset index containing the available subset names per category
|
||||
as a dictionary.
|
||||
|
||||
Returns:
|
||||
category_to_subset_name_list: A dictionary containing subset names available
|
||||
per category of the following form:
|
||||
```
|
||||
{
|
||||
category_0: [category_0_subset_name_0, category_0_subset_name_1, ...],
|
||||
category_1: [category_1_subset_name_0, category_1_subset_name_1, ...],
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
"""
|
||||
category_to_subset_name_list_json = "category_to_subset_name_list.json"
|
||||
category_to_subset_name_list = self._load_annotation_json(
|
||||
category_to_subset_name_list_json
|
||||
)
|
||||
return category_to_subset_name_list
|
||||
|
||||
def get_task(self) -> Task: # TODO: we plan to get rid of tasks
|
||||
return {
|
||||
"manyview": Task.SINGLE_SEQUENCE,
|
||||
"fewview": Task.MULTI_SEQUENCE,
|
||||
}[self.subset_name.split("_")[0]]
|
||||
|
||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||
# pyre-ignore[16]
|
||||
train_dataset = self.dataset_map.train
|
||||
assert isinstance(train_dataset, JsonIndexDataset)
|
||||
return train_dataset.get_all_train_cameras()
|
||||
|
||||
def _load_annotation_json(self, json_filename: str):
|
||||
full_path = os.path.join(
|
||||
self.dataset_root,
|
||||
json_filename,
|
||||
)
|
||||
logger.info(f"Loading frame index json from {full_path}.")
|
||||
path_manager = self.path_manager_factory.get()
|
||||
if path_manager is not None:
|
||||
full_path = path_manager.get_local_path(full_path)
|
||||
if not os.path.isfile(full_path):
|
||||
# The batch indices file does not exist.
|
||||
# Most probably the user has not specified the root folder.
|
||||
raise ValueError(
|
||||
f"Looking for dataset json file in {full_path}. "
|
||||
+ "Please specify a correct dataset_root folder."
|
||||
)
|
||||
with open(full_path, "r") as f:
|
||||
data = json.load(f)
|
||||
return data
|
||||
|
||||
def _get_available_subset_names(self):
|
||||
path_manager = self.path_manager_factory.get()
|
||||
if path_manager is not None:
|
||||
dataset_root = path_manager.get_local_path(self.dataset_root)
|
||||
else:
|
||||
dataset_root = self.dataset_root
|
||||
return get_available_subset_names(dataset_root, self.category)
|
||||
|
||||
|
||||
def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
|
||||
"""
|
||||
Get the available subset names for a given category folder inside a root dataset
|
||||
folder `dataset_root`.
|
||||
"""
|
||||
category_dir = os.path.join(dataset_root, category)
|
||||
if not os.path.isdir(category_dir):
|
||||
raise ValueError(
|
||||
f"Looking for dataset files in {category_dir}. "
|
||||
+ "Please specify a correct dataset_root folder."
|
||||
)
|
||||
set_list_jsons = os.listdir(os.path.join(category_dir, "set_lists"))
|
||||
return [
|
||||
json_file.replace("set_lists_", "").replace(".json", "")
|
||||
for json_file in set_list_jsons
|
||||
]
|
||||
63
pytorch3d/implicitron/dataset/llff_dataset_map_provider.py
Normal file
63
pytorch3d/implicitron/dataset/llff_dataset_map_provider.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import registry
|
||||
|
||||
from .load_llff import load_llff_data
|
||||
|
||||
from .single_sequence_dataset import (
|
||||
_interpret_blender_cameras,
|
||||
SingleSceneDatasetMapProviderBase,
|
||||
)
|
||||
|
||||
|
||||
@registry.register
|
||||
class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
|
||||
"""
|
||||
Provides data for one scene from the LLFF dataset.
|
||||
|
||||
Members:
|
||||
base_dir: directory holding the data for the scene.
|
||||
object_name: The name of the scene (e.g. "fern"). This is just used as a label.
|
||||
It will typically be equal to the name of the directory self.base_dir.
|
||||
path_manager_factory: Creates path manager which may be used for
|
||||
interpreting paths.
|
||||
n_known_frames_for_test: If set, training frames are included in the val
|
||||
and test datasets, and this many random training frames are added to
|
||||
each test batch. If not set, test batches each contain just a single
|
||||
testing frame.
|
||||
"""
|
||||
|
||||
def _load_data(self) -> None:
|
||||
path_manager = self.path_manager_factory.get()
|
||||
images, poses, _ = load_llff_data(
|
||||
self.base_dir, factor=8, path_manager=path_manager
|
||||
)
|
||||
hwf = poses[0, :3, -1]
|
||||
poses = poses[:, :3, :4]
|
||||
|
||||
i_test = np.arange(images.shape[0])[::8]
|
||||
i_test_index = set(i_test.tolist())
|
||||
i_train = np.array(
|
||||
[i for i in np.arange(images.shape[0]) if i not in i_test_index]
|
||||
)
|
||||
i_split = (i_train, i_test, i_test)
|
||||
H, W, focal = hwf
|
||||
focal_ndc = 2 * focal / min(H, W)
|
||||
images = torch.from_numpy(images).permute(0, 3, 1, 2)
|
||||
poses = torch.from_numpy(poses)
|
||||
|
||||
# pyre-ignore[16]
|
||||
self.poses = _interpret_blender_cameras(poses, focal_ndc)
|
||||
# pyre-ignore[16]
|
||||
self.images = images
|
||||
# pyre-ignore[16]
|
||||
self.fg_probabilities = None
|
||||
# pyre-ignore[16]
|
||||
self.i_split = i_split
|
||||
141
pytorch3d/implicitron/dataset/load_blender.py
Normal file
141
pytorch3d/implicitron/dataset/load_blender.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# @lint-ignore-every LICENSELINT
|
||||
# Adapted from https://github.com/bmild/nerf/blob/master/load_blender.py
|
||||
# Copyright (c) 2020 bmild
|
||||
import json
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def translate_by_t_along_z(t):
|
||||
tform = np.eye(4).astype(np.float32)
|
||||
tform[2][3] = t
|
||||
return tform
|
||||
|
||||
|
||||
def rotate_by_phi_along_x(phi):
|
||||
tform = np.eye(4).astype(np.float32)
|
||||
tform[1, 1] = tform[2, 2] = np.cos(phi)
|
||||
tform[1, 2] = -np.sin(phi)
|
||||
tform[2, 1] = -tform[1, 2]
|
||||
return tform
|
||||
|
||||
|
||||
def rotate_by_theta_along_y(theta):
|
||||
tform = np.eye(4).astype(np.float32)
|
||||
tform[0, 0] = tform[2, 2] = np.cos(theta)
|
||||
tform[0, 2] = -np.sin(theta)
|
||||
tform[2, 0] = -tform[0, 2]
|
||||
return tform
|
||||
|
||||
|
||||
def pose_spherical(theta, phi, radius):
|
||||
c2w = translate_by_t_along_z(radius)
|
||||
c2w = rotate_by_phi_along_x(phi / 180.0 * np.pi) @ c2w
|
||||
c2w = rotate_by_theta_along_y(theta / 180 * np.pi) @ c2w
|
||||
c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
|
||||
return c2w
|
||||
|
||||
|
||||
def _local_path(path_manager, path):
|
||||
if path_manager is None:
|
||||
return path
|
||||
return path_manager.get_local_path(path)
|
||||
|
||||
|
||||
def load_blender_data(
|
||||
basedir,
|
||||
half_res=False,
|
||||
testskip=1,
|
||||
debug=False,
|
||||
path_manager=None,
|
||||
focal_length_in_screen_space=False,
|
||||
):
|
||||
splits = ["train", "val", "test"]
|
||||
metas = {}
|
||||
for s in splits:
|
||||
path = os.path.join(basedir, f"transforms_{s}.json")
|
||||
with open(_local_path(path_manager, path)) as fp:
|
||||
metas[s] = json.load(fp)
|
||||
|
||||
all_imgs = []
|
||||
all_poses = []
|
||||
counts = [0]
|
||||
for s in splits:
|
||||
meta = metas[s]
|
||||
imgs = []
|
||||
poses = []
|
||||
if s == "train" or testskip == 0:
|
||||
skip = 1
|
||||
else:
|
||||
skip = testskip
|
||||
|
||||
for frame in meta["frames"][::skip]:
|
||||
fname = os.path.join(basedir, frame["file_path"] + ".png")
|
||||
imgs.append(np.array(Image.open(_local_path(path_manager, fname))))
|
||||
poses.append(np.array(frame["transform_matrix"]))
|
||||
imgs = (np.array(imgs) / 255.0).astype(np.float32)
|
||||
poses = np.array(poses).astype(np.float32)
|
||||
counts.append(counts[-1] + imgs.shape[0])
|
||||
all_imgs.append(imgs)
|
||||
all_poses.append(poses)
|
||||
|
||||
i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)]
|
||||
|
||||
imgs = np.concatenate(all_imgs, 0)
|
||||
poses = np.concatenate(all_poses, 0)
|
||||
|
||||
H, W = imgs[0].shape[:2]
|
||||
camera_angle_x = float(meta["camera_angle_x"])
|
||||
if focal_length_in_screen_space:
|
||||
focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
|
||||
else:
|
||||
focal = 1 / np.tan(0.5 * camera_angle_x)
|
||||
|
||||
render_poses = torch.stack(
|
||||
[
|
||||
torch.from_numpy(pose_spherical(angle, -30.0, 4.0))
|
||||
for angle in np.linspace(-180, 180, 40 + 1)[:-1]
|
||||
],
|
||||
0,
|
||||
)
|
||||
|
||||
# In debug mode, return extremely tiny images
|
||||
if debug:
|
||||
import cv2
|
||||
|
||||
H = H // 32
|
||||
W = W // 32
|
||||
if focal_length_in_screen_space:
|
||||
focal = focal / 32.0
|
||||
imgs = [
|
||||
torch.from_numpy(
|
||||
cv2.resize(imgs[i], dsize=(25, 25), interpolation=cv2.INTER_AREA)
|
||||
)
|
||||
for i in range(imgs.shape[0])
|
||||
]
|
||||
imgs = torch.stack(imgs, 0)
|
||||
poses = torch.from_numpy(poses)
|
||||
return imgs, poses, render_poses, [H, W, focal], i_split
|
||||
|
||||
if half_res:
|
||||
import cv2
|
||||
|
||||
# TODO: resize images using INTER_AREA (cv2)
|
||||
H = H // 2
|
||||
W = W // 2
|
||||
if focal_length_in_screen_space:
|
||||
focal = focal / 2.0
|
||||
imgs = [
|
||||
torch.from_numpy(
|
||||
cv2.resize(imgs[i], dsize=(400, 400), interpolation=cv2.INTER_AREA)
|
||||
)
|
||||
for i in range(imgs.shape[0])
|
||||
]
|
||||
imgs = torch.stack(imgs, 0)
|
||||
|
||||
poses = torch.from_numpy(poses)
|
||||
|
||||
return imgs, poses, render_poses, [H, W, focal], i_split
|
||||
343
pytorch3d/implicitron/dataset/load_llff.py
Normal file
343
pytorch3d/implicitron/dataset/load_llff.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# @lint-ignore-every LICENSELINT
|
||||
# Adapted from https://github.com/bmild/nerf/blob/master/load_llff.py
|
||||
# Copyright (c) 2020 bmild
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
# Slightly modified version of LLFF data loading code
|
||||
# see https://github.com/Fyusion/LLFF for original
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _minify(basedir, path_manager, factors=(), resolutions=()):
|
||||
needtoload = False
|
||||
for r in factors:
|
||||
imgdir = os.path.join(basedir, "images_{}".format(r))
|
||||
if not _exists(path_manager, imgdir):
|
||||
needtoload = True
|
||||
for r in resolutions:
|
||||
imgdir = os.path.join(basedir, "images_{}x{}".format(r[1], r[0]))
|
||||
if not _exists(path_manager, imgdir):
|
||||
needtoload = True
|
||||
if not needtoload:
|
||||
return
|
||||
assert path_manager is None
|
||||
|
||||
from subprocess import check_output
|
||||
|
||||
imgdir = os.path.join(basedir, "images")
|
||||
imgs = [os.path.join(imgdir, f) for f in sorted(_ls(path_manager, imgdir))]
|
||||
imgs = [
|
||||
f
|
||||
for f in imgs
|
||||
if any([f.endswith(ex) for ex in ["JPG", "jpg", "png", "jpeg", "PNG"]])
|
||||
]
|
||||
imgdir_orig = imgdir
|
||||
|
||||
wd = os.getcwd()
|
||||
|
||||
for r in factors + resolutions:
|
||||
if isinstance(r, int):
|
||||
name = "images_{}".format(r)
|
||||
resizearg = "{}%".format(100.0 / r)
|
||||
else:
|
||||
name = "images_{}x{}".format(r[1], r[0])
|
||||
resizearg = "{}x{}".format(r[1], r[0])
|
||||
imgdir = os.path.join(basedir, name)
|
||||
if os.path.exists(imgdir):
|
||||
continue
|
||||
|
||||
logger.info(f"Minifying {r}, {basedir}")
|
||||
|
||||
os.makedirs(imgdir)
|
||||
check_output("cp {}/* {}".format(imgdir_orig, imgdir), shell=True)
|
||||
|
||||
ext = imgs[0].split(".")[-1]
|
||||
args = " ".join(
|
||||
["mogrify", "-resize", resizearg, "-format", "png", "*.{}".format(ext)]
|
||||
)
|
||||
logger.info(args)
|
||||
os.chdir(imgdir)
|
||||
check_output(args, shell=True)
|
||||
os.chdir(wd)
|
||||
|
||||
if ext != "png":
|
||||
check_output("rm {}/*.{}".format(imgdir, ext), shell=True)
|
||||
logger.info("Removed duplicates")
|
||||
logger.info("Done")
|
||||
|
||||
|
||||
def _load_data(
|
||||
basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
|
||||
):
|
||||
|
||||
poses_arr = np.load(
|
||||
_local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
|
||||
)
|
||||
poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
|
||||
bds = poses_arr[:, -2:].transpose([1, 0])
|
||||
|
||||
img0 = [
|
||||
os.path.join(basedir, "images", f)
|
||||
for f in sorted(_ls(path_manager, os.path.join(basedir, "images")))
|
||||
if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
|
||||
][0]
|
||||
|
||||
def imread(f):
|
||||
return np.array(Image.open(f))
|
||||
|
||||
sh = imread(_local_path(path_manager, img0)).shape
|
||||
|
||||
sfx = ""
|
||||
|
||||
if factor is not None:
|
||||
sfx = "_{}".format(factor)
|
||||
_minify(basedir, path_manager, factors=[factor])
|
||||
factor = factor
|
||||
elif height is not None:
|
||||
factor = sh[0] / float(height)
|
||||
width = int(sh[1] / factor)
|
||||
_minify(basedir, path_manager, resolutions=[[height, width]])
|
||||
sfx = "_{}x{}".format(width, height)
|
||||
elif width is not None:
|
||||
factor = sh[1] / float(width)
|
||||
height = int(sh[0] / factor)
|
||||
_minify(basedir, path_manager, resolutions=[[height, width]])
|
||||
sfx = "_{}x{}".format(width, height)
|
||||
else:
|
||||
factor = 1
|
||||
|
||||
imgdir = os.path.join(basedir, "images" + sfx)
|
||||
if not _exists(path_manager, imgdir):
|
||||
raise ValueError(f"{imgdir} does not exist, returning")
|
||||
|
||||
imgfiles = [
|
||||
_local_path(path_manager, os.path.join(imgdir, f))
|
||||
for f in sorted(_ls(path_manager, imgdir))
|
||||
if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
|
||||
]
|
||||
if poses.shape[-1] != len(imgfiles):
|
||||
raise ValueError(
|
||||
"Mismatch between imgs {} and poses {} !!!!".format(
|
||||
len(imgfiles), poses.shape[-1]
|
||||
)
|
||||
)
|
||||
|
||||
sh = imread(imgfiles[0]).shape
|
||||
poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
|
||||
poses[2, 4, :] = poses[2, 4, :] * 1.0 / factor
|
||||
|
||||
if not load_imgs:
|
||||
return poses, bds
|
||||
|
||||
imgs = imgs = [imread(f)[..., :3] / 255.0 for f in imgfiles]
|
||||
imgs = np.stack(imgs, -1)
|
||||
|
||||
logger.info(f"Loaded image data, shape {imgs.shape}")
|
||||
return poses, bds, imgs
|
||||
|
||||
|
||||
def normalize(x):
|
||||
denom = np.linalg.norm(x)
|
||||
if denom < 0.001:
|
||||
warnings.warn("unsafe normalize()")
|
||||
return x / denom
|
||||
|
||||
|
||||
def viewmatrix(z, up, pos):
|
||||
vec2 = normalize(z)
|
||||
vec1_avg = up
|
||||
vec0 = normalize(np.cross(vec1_avg, vec2))
|
||||
vec1 = normalize(np.cross(vec2, vec0))
|
||||
m = np.stack([vec0, vec1, vec2, pos], 1)
|
||||
return m
|
||||
|
||||
|
||||
def ptstocam(pts, c2w):
    """Transform world-space points *pts* into the camera frame of *c2w*.

    Args:
        pts: [..., 3] points in world coordinates.
        c2w: camera-to-world matrix; only the [:3, :4] part is used.

    Returns:
        [..., 3] points expressed in camera coordinates.
    """
    rot = c2w[:3, :3]
    trans = c2w[:3, 3]
    # Subtract the camera position, then rotate by the inverse rotation
    # (transpose, since the rotation is orthonormal).
    return np.matmul(rot.T, (pts - trans)[..., np.newaxis])[..., 0]
|
||||
|
||||
|
||||
def poses_avg(poses):
    """Compute an "average" pose for a set of [N, 3, 5] camera poses.

    Averages camera centers and view directions over all poses and
    builds a single [3, 5] pose (rotation | position | hwf column).
    """
    # Intrinsics column (taken from the first pose) is carried through.
    hwf = poses[0, :3, -1:]

    mean_center = poses[:, :3, 3].mean(0)
    forward = normalize(poses[:, :3, 2].sum(0))
    up_hint = poses[:, :3, 1].sum(0)
    return np.concatenate([viewmatrix(forward, up_hint, mean_center), hwf], 1)
|
||||
|
||||
|
||||
def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
    """Generate N camera poses along a spiral around the average pose.

    Args:
        c2w: [3, 5] average pose (rotation | position | hwf column).
        up: world-space up vector.
        rads: per-axis radii of the spiral.
        focal: depth of the point all generated cameras look at.
        zdelta: unused here; kept for interface compatibility.
        zrate: frequency multiplier for the z oscillation.
        rots: number of full rotations of the spiral.
        N: number of poses to generate.

    Returns:
        List of N [3, 5] poses.
    """
    radii = np.array(list(rads) + [1.0])
    hwf = c2w[:, 4:5]

    spiral = []
    for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]:
        # Camera position on the spiral, expressed in world coordinates.
        offset = (
            np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0])
            * radii
        )
        cam_pos = np.dot(c2w[:3, :4], offset)
        # Look at a point `focal` units in front of the average pose.
        look_dir = normalize(
            cam_pos - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0]))
        )
        spiral.append(np.concatenate([viewmatrix(look_dir, up, cam_pos), hwf], 1))
    return spiral
|
||||
|
||||
|
||||
def recenter_poses(poses):
    """Transform all poses so that their average pose becomes the identity.

    The hwf (intrinsics) column of each [3, 5] pose is left untouched.
    """
    out = poses + 0  # copy, so the intrinsics column survives untouched
    last_row = np.reshape([0, 0, 0, 1.0], [1, 4])

    # Homogenize the average pose so it can be inverted.
    avg = poses_avg(poses)
    avg_h = np.concatenate([avg[:3, :4], last_row], -2)

    # Homogenize every input pose.
    tiled_last = np.tile(np.reshape(last_row, [1, 1, 4]), [poses.shape[0], 1, 1])
    poses_h = np.concatenate([poses[:, :3, :4], tiled_last], -2)

    # Express all poses in the frame of the average pose.
    recentered = np.linalg.inv(avg_h) @ poses_h
    out[:, :3, :4] = recentered[:, :3, :4]
    return out
|
||||
|
||||
|
||||
def spherify_poses(poses, bds):
    """
    Recenter and rescale camera poses so they lie roughly on a unit sphere,
    and generate a circle of novel render poses.

    Args:
        poses: camera poses; assumed [N, 3, 5] with a trailing hwf
            (intrinsics) column — TODO confirm against the caller.
        bds: scene near/far bounds, rescaled by the same factor as the poses.

    Returns:
        poses_reset: the input poses recentred and rescaled.
        new_poses: 120 poses on a horizontal circle at the cameras' mean height.
        bds: the rescaled bounds.
    """

    def add_row_to_homogenize_transform(p):
        r"""Add the last row to homogenize 3 x 4 transformation matrices."""
        return np.concatenate(
            [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
        )

    # p34_to_44 = lambda p: np.concatenate(
    #     [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
    # )

    p34_to_44 = add_row_to_homogenize_transform

    # Per-camera view directions (3rd rotation column) and positions.
    rays_d = poses[:, :3, 2:3]
    rays_o = poses[:, :3, 3:4]

    def min_line_dist(rays_o, rays_d):
        # Least-squares point minimizing distance to all camera view axes.
        A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
        b_i = -A_i @ rays_o
        pt_mindist = np.squeeze(
            -np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0)
        )
        return pt_mindist

    pt_mindist = min_line_dist(rays_o, rays_d)

    # Use that point as the scene center; "up" points from it towards the
    # mean camera position.
    center = pt_mindist
    up = (poses[:, :3, 3] - center).mean(0)

    # Build an orthonormal world frame around the up direction. The
    # [0.1, 0.2, 0.3] vector is an arbitrary non-parallel seed for the cross
    # product.
    vec0 = normalize(up)
    vec1 = normalize(np.cross([0.1, 0.2, 0.3], vec0))
    vec2 = normalize(np.cross(vec0, vec1))
    pos = center
    c2w = np.stack([vec1, vec2, vec0, pos], 1)

    # Re-express every pose in this new world frame.
    poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])

    # RMS distance of the cameras from the origin.
    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))

    # Rescale so the cameras sit on (approximately) a unit sphere; bounds are
    # rescaled consistently.
    sc = 1.0 / rad
    poses_reset[:, :3, 3] *= sc
    bds *= sc
    rad *= sc

    # Build 120 render poses on a horizontal circle at the cameras' mean
    # height, all looking at the origin.
    centroid = np.mean(poses_reset[:, :3, 3], 0)
    zh = centroid[2]
    radcircle = np.sqrt(rad**2 - zh**2)
    new_poses = []

    for th in np.linspace(0.0, 2.0 * np.pi, 120):

        camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
        up = np.array([0, 0, -1.0])

        vec2 = normalize(camorigin)
        vec0 = normalize(np.cross(vec2, up))
        vec1 = normalize(np.cross(vec2, vec0))
        pos = camorigin
        p = np.stack([vec0, vec1, vec2, pos], 1)

        new_poses.append(p)

    new_poses = np.stack(new_poses, 0)

    # Re-attach the intrinsics (hwf) column, copied from the first input pose.
    new_poses = np.concatenate(
        [new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1
    )
    poses_reset = np.concatenate(
        [
            poses_reset[:, :3, :4],
            np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape),
        ],
        -1,
    )

    return poses_reset, new_poses, bds
|
||||
|
||||
|
||||
def _local_path(path_manager, path):
|
||||
if path_manager is None:
|
||||
return path
|
||||
return path_manager.get_local_path(path)
|
||||
|
||||
|
||||
def _ls(path_manager, path):
|
||||
if path_manager is None:
|
||||
return os.path.listdir(path)
|
||||
return path_manager.ls(path)
|
||||
|
||||
|
||||
def _exists(path_manager, path):
|
||||
if path_manager is None:
|
||||
return os.path.exists(path)
|
||||
return path_manager.exists(path)
|
||||
|
||||
|
||||
def load_llff_data(
    basedir,
    factor=8,
    recenter=True,
    bd_factor=0.75,
    spherify=False,
    path_zflat=False,
    path_manager=None,
):
    """
    Load an LLFF-format dataset from *basedir*.

    Args:
        basedir: directory containing the scene data.
        factor: image downsampling factor (factor=8 downsamples originals 8x).
        recenter: if True, recenter poses around their average pose.
        bd_factor: rescale so bds.min() * bd_factor maps to 1; None disables
            rescaling.
        spherify: if True, spherify the poses (inward-facing capture).
        path_zflat: unused in this function; kept for interface compatibility.
        path_manager: optional path manager for interpreting paths.

    Returns:
        (images, poses, bds) as float32 numpy arrays, batch dim first.
    """
    poses, bds, imgs = _load_data(
        basedir, factor=factor, path_manager=path_manager
    )  # factor=8 downsamples original imgs by 8x
    logger.info(f"Loaded {basedir}, {bds.min()}, {bds.max()}")

    # Correct rotation matrix ordering and move the variable dim to axis 0.
    poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
    images = np.moveaxis(imgs, -1, 0).astype(np.float32)
    bds = np.moveaxis(bds, -1, 0).astype(np.float32)

    # Rescale if bd_factor is provided.
    scale = 1.0 if bd_factor is None else 1.0 / (bds.min() * bd_factor)
    poses[:, :3, 3] *= scale
    bds *= scale

    if recenter:
        poses = recenter_poses(poses)

    if spherify:
        # The generated circular render poses are discarded here.
        poses, _render_poses, bds = spherify_poses(poses, bds)

    images = images.astype(np.float32)
    poses = poses.astype(np.float32)

    return images, poses, bds
|
||||
@@ -12,7 +12,7 @@ from typing import Iterable, Iterator, List, Sequence, Tuple
|
||||
import numpy as np
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
from .implicitron_dataset import ImplicitronDatasetBase
|
||||
from .dataset_base import DatasetBase
|
||||
|
||||
|
||||
@dataclass(eq=False) # TODO: do we need this if not init from config?
|
||||
@@ -22,7 +22,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
||||
of sequences.
|
||||
"""
|
||||
|
||||
dataset: ImplicitronDatasetBase
|
||||
dataset: DatasetBase
|
||||
batch_size: int
|
||||
num_batches: int
|
||||
# the sampler first samples a random element k from this list and then
|
||||
|
||||
205
pytorch3d/implicitron/dataset/single_sequence_dataset.py
Normal file
205
pytorch3d/implicitron/dataset/single_sequence_dataset.py
Normal file
@@ -0,0 +1,205 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# This file defines a base class for dataset map providers which
|
||||
# provide data for a single scene.
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Iterable, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
Configurable,
|
||||
expand_args_fields,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
|
||||
|
||||
from .dataset_base import DatasetBase, FrameData
|
||||
from .dataset_map_provider import (
|
||||
DatasetMap,
|
||||
DatasetMapProviderBase,
|
||||
PathManagerFactory,
|
||||
Task,
|
||||
)
|
||||
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
|
||||
|
||||
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
|
||||
|
||||
|
||||
class SingleSceneDataset(DatasetBase, Configurable):
    """
    A dataset of images which all come from a single scene.
    """

    images: List[torch.Tensor] = field()
    fg_probabilities: Optional[List[torch.Tensor]] = field()
    poses: List[PerspectiveCameras] = field()
    object_name: str = field()
    frame_types: List[str] = field()
    eval_batches: Optional[List[List[int]]] = field()

    def sequence_names(self) -> Iterable[str]:
        # All frames belong to the same, single sequence.
        return [_SINGLE_SEQUENCE_NAME]

    def __len__(self) -> int:
        return len(self.poses)

    def __getitem__(self, index) -> FrameData:
        """Return the FrameData for frame *index*; raises IndexError when
        out of range."""
        if index >= len(self):
            raise IndexError(f"index {index} out of range {len(self)}")
        fg_probability = None
        if self.fg_probabilities is not None:
            fg_probability = self.fg_probabilities[index]
        image = self.images[index]

        return FrameData(
            frame_number=index,
            sequence_name=_SINGLE_SEQUENCE_NAME,
            sequence_category=self.object_name,
            camera=self.poses[index],
            # spatial dims of the CHW image tensor
            image_size_hw=torch.tensor(image.shape[1:]),
            image_rgb=image,
            fg_probability=fg_probability,
            frame_type=self.frame_types[index],
        )

    def get_eval_batches(self) -> Optional[List[List[int]]]:
        """Return the stored evaluation batches, or None if none were set."""
        return self.eval_batches
|
||||
|
||||
|
||||
# pyre-fixme[13]: Uninitialized attribute
|
||||
class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
    """
    Base for provider of data for one scene from LLFF or blender datasets.

    Members:
        base_dir: directory holding the data for the scene.
        object_name: The name of the scene (e.g. "lego"). This is just used as a label.
            It will typically be equal to the name of the directory self.base_dir.
        path_manager_factory: Creates path manager which may be used for
            interpreting paths.
        n_known_frames_for_test: If set, training frames are included in the val
            and test datasets, and this many random training frames are added to
            each test batch. If not set, test batches each contain just a single
            testing frame.
    """

    base_dir: str
    object_name: str
    path_manager_factory: PathManagerFactory
    path_manager_factory_class_type: str = "PathManagerFactory"
    n_known_frames_for_test: Optional[int] = None

    def __post_init__(self) -> None:
        # Build configurable sub-members, then eagerly load the scene data.
        run_auto_creation(self)
        self._load_data()

    def _load_data(self) -> None:
        # This must be defined by each subclass,
        # and should set the following on self.
        # - poses: a list of length-1 camera objects
        # - images: [N, 3, H, W] tensor of rgb images - floats in [0,1]
        # - fg_probabilities: None or [N, 1, H, W] of floats in [0,1]
        # - splits: List[List[int]] of indices for train/val/test subsets.
        raise NotImplementedError()

    def _get_dataset(
        self, split_idx: int, frame_type: str, set_eval_batches: bool = False
    ) -> SingleSceneDataset:
        """Build a SingleSceneDataset for split *split_idx* (0=train, 1=val,
        2=test per get_dataset_map), labeling each frame with *frame_type*.

        When set_eval_batches is True the returned dataset also carries
        per-frame evaluation batches (optionally augmented with known
        training frames — see n_known_frames_for_test).
        """
        expand_args_fields(SingleSceneDataset)
        # pyre-ignore[16]
        split = self.i_split[split_idx]
        frame_types = [frame_type] * len(split)
        fg_probabilities = (
            None
            # pyre-ignore[16]
            if self.fg_probabilities is None
            else self.fg_probabilities[split]
        )
        # One eval batch per frame in this split.
        eval_batches = [[i] for i in range(len(split))]
        if split_idx != 0 and self.n_known_frames_for_test is not None:
            # Non-train splits additionally include all training frames,
            # appended after the split's own frames.
            train_split = self.i_split[0]
            if set_eval_batches:
                # Fixed seed keeps the sampled known frames reproducible.
                generator = np.random.default_rng(seed=0)
                for batch in eval_batches:
                    # using permutation so that changes to n_known_frames_for_test
                    # result in consistent batches.
                    to_add = generator.permutation(len(train_split))[
                        : self.n_known_frames_for_test
                    ]
                    # Offset by len(split): training frames come after the
                    # split frames in the concatenated dataset below.
                    batch.extend((to_add + len(split)).tolist())
            split = np.concatenate([split, train_split])
            frame_types.extend([DATASET_TYPE_KNOWN] * len(train_split))

        # pyre-ignore[28]
        return SingleSceneDataset(
            object_name=self.object_name,
            # pyre-ignore[16]
            images=self.images[split],
            fg_probabilities=fg_probabilities,
            # pyre-ignore[16]
            poses=[self.poses[i] for i in split],
            frame_types=frame_types,
            eval_batches=eval_batches if set_eval_batches else None,
        )

    def get_dataset_map(self) -> DatasetMap:
        # Splits 0/1/2 are train/val/test; only the test dataset carries
        # eval batches.
        return DatasetMap(
            train=self._get_dataset(0, DATASET_TYPE_KNOWN),
            val=self._get_dataset(1, DATASET_TYPE_UNKNOWN),
            test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
        )

    def get_task(self) -> Task:
        return Task.SINGLE_SEQUENCE

    def get_all_train_cameras(self) -> Optional[CamerasBase]:
        """Return all cameras of the training split joined into one batch."""
        # pyre-ignore[16]
        cameras = [self.poses[i] for i in self.i_split[0]]
        return join_cameras_as_batch(cameras)
|
||||
|
||||
|
||||
def _interpret_blender_cameras(
    poses: torch.Tensor, focal: float
) -> List[PerspectiveCameras]:
    """
    Convert 4x4 matrices representing cameras in blender format
    to PyTorch3D format.

    Args:
        poses: N x 3 x 4 camera matrices
        focal: ndc space focal length
    """
    cameras = []
    for matrix in poses:
        extrinsics = matrix[:3, :4]

        # Assemble the homogeneous transform (rotation transposed, camera
        # position in the last row), then invert it.
        mtx = torch.eye(4, dtype=extrinsics.dtype)
        mtx[:3, :3] = extrinsics[:3, :3].t()
        mtx[3, :3] = extrinsics[:, 3]
        mtx = mtx.inverse()

        # flip the XZ coordinates.
        mtx[:, [0, 2]] *= -1.0

        R, T = mtx[:, :3].split([3, 1], dim=0)

        cameras.append(
            PerspectiveCameras(
                focal_length=torch.FloatTensor([[focal, focal]]),
                principal_point=torch.FloatTensor([[0.0, 0.0]]),
                R=R[None],
                T=T,
            )
        )
    return cameras
|
||||
@@ -8,8 +8,8 @@
|
||||
import dataclasses
|
||||
import gzip
|
||||
import json
|
||||
from dataclasses import MISSING, Field, dataclass
|
||||
from typing import IO, Any, Optional, Tuple, Type, TypeVar, Union, cast
|
||||
from dataclasses import dataclass, Field, MISSING
|
||||
from typing import Any, cast, Dict, IO, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
from pytorch3d.common.datatypes import get_args, get_origin
|
||||
@@ -80,6 +80,7 @@ class FrameAnnotation:
|
||||
depth: Optional[DepthAnnotation] = None
|
||||
mask: Optional[MaskAnnotation] = None
|
||||
viewpoint: Optional[ViewpointAnnotation] = None
|
||||
meta: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -169,9 +170,11 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
||||
|
||||
cls = get_origin(typeannot) or typeannot
|
||||
|
||||
if typeannot is Any:
|
||||
return dlist
|
||||
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
|
||||
return dlist
|
||||
elif any(obj is None for obj in dlist):
|
||||
if any(obj is None for obj in dlist):
|
||||
# filter out Nones and recurse on the resulting list
|
||||
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
|
||||
idx, notnone = zip(*idx_notnone)
|
||||
@@ -180,8 +183,13 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
||||
for i, obj in zip(idx, converted):
|
||||
res[i] = obj
|
||||
return res
|
||||
|
||||
is_optional, contained_type = _resolve_optional(typeannot)
|
||||
if is_optional:
|
||||
return _dataclass_list_from_dict_list(dlist, contained_type)
|
||||
|
||||
# otherwise, we dispatch by the type of the provided annotation to convert to
|
||||
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||
# For namedtuple, call the function recursively on the lists of corresponding keys
|
||||
types = cls._field_types.values()
|
||||
dlist_T = zip(*dlist)
|
||||
@@ -218,7 +226,7 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
||||
|
||||
keys = np.split(list(all_keys_res), indices[:-1])
|
||||
vals = np.split(list(all_vals_res), indices[:-1])
|
||||
return [cls(zip(*k, v)) for k, v in zip(keys, vals)]
|
||||
return [cls(zip(k, v)) for k, v in zip(keys, vals)]
|
||||
elif not dataclasses.is_dataclass(typeannot):
|
||||
return dlist
|
||||
|
||||
@@ -240,10 +248,15 @@ def _dataclass_list_from_dict_list(dlist, typeannot):
|
||||
|
||||
|
||||
def _dataclass_from_dict(d, typeannot):
|
||||
cls = get_origin(typeannot) or typeannot
|
||||
if d is None:
|
||||
if d is None or typeannot is Any:
|
||||
return d
|
||||
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||
is_optional, contained_type = _resolve_optional(typeannot)
|
||||
if is_optional:
|
||||
# an Optional not set to None, just use the contents of the Optional.
|
||||
return _dataclass_from_dict(d, contained_type)
|
||||
|
||||
cls = get_origin(typeannot) or typeannot
|
||||
if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
|
||||
types = cls._field_types.values()
|
||||
return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)])
|
||||
elif issubclass(cls, (list, tuple)):
|
||||
@@ -315,3 +328,15 @@ def load_dataclass_jgzip(outfile, cls):
|
||||
"""
|
||||
with gzip.GzipFile(outfile, "rb") as f:
|
||||
return load_dataclass(cast(IO, f), cls, binary=True)
|
||||
|
||||
|
||||
def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
|
||||
"""Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
|
||||
if get_origin(type_) is Union:
|
||||
args = get_args(type_)
|
||||
if len(args) == 2 and args[1] == type(None): # noqa E721
|
||||
return True, args[0]
|
||||
if type_ is Any:
|
||||
return True, Any
|
||||
|
||||
return False, type_
|
||||
|
||||
@@ -16,6 +16,14 @@ DATASET_TYPE_KNOWN = "known"
|
||||
DATASET_TYPE_UNKNOWN = "unseen"
|
||||
|
||||
|
||||
def is_known_frame_scalar(frame_type: str) -> bool:
|
||||
"""
|
||||
Given a single frame type corresponding to a single frame, return whether
|
||||
the frame is a known frame.
|
||||
"""
|
||||
return frame_type.endswith(DATASET_TYPE_KNOWN)
|
||||
|
||||
|
||||
def is_known_frame(
|
||||
frame_type: List[str], device: Optional[str] = None
|
||||
) -> torch.BoolTensor:
|
||||
@@ -23,8 +31,9 @@ def is_known_frame(
|
||||
Given a list `frame_type` of frame types in a batch, return a tensor
|
||||
of boolean flags expressing whether the corresponding frame is a known frame.
|
||||
"""
|
||||
# pyre-fixme[7]: Expected `BoolTensor` but got `Tensor`.
|
||||
return torch.tensor(
|
||||
[ft.endswith(DATASET_TYPE_KNOWN) for ft in frame_type],
|
||||
[is_known_frame_scalar(ft) for ft in frame_type],
|
||||
dtype=torch.bool,
|
||||
device=device,
|
||||
)
|
||||
@@ -37,6 +46,7 @@ def is_train_frame(
|
||||
Given a list `frame_type` of frame types in a batch, return a tensor
|
||||
of boolean flags expressing whether the corresponding frame is a training frame.
|
||||
"""
|
||||
# pyre-fixme[7]: Expected `BoolTensor` but got `Tensor`.
|
||||
return torch.tensor(
|
||||
[ft.startswith(DATASET_TYPE_TRAIN) for ft in frame_type],
|
||||
dtype=torch.bool,
|
||||
|
||||
@@ -10,11 +10,12 @@ import torch
|
||||
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
|
||||
from pytorch3d.structures import Pointclouds
|
||||
|
||||
from .implicitron_dataset import FrameData, ImplicitronDataset
|
||||
from .dataset_base import FrameData
|
||||
from .json_index_dataset import JsonIndexDataset
|
||||
|
||||
|
||||
def get_implicitron_sequence_pointcloud(
|
||||
dataset: ImplicitronDataset,
|
||||
dataset: JsonIndexDataset,
|
||||
sequence_name: Optional[str] = None,
|
||||
mask_points: bool = True,
|
||||
max_frames: int = -1,
|
||||
@@ -43,6 +44,7 @@ def get_implicitron_sequence_pointcloud(
|
||||
sequence_entries = [
|
||||
ei
|
||||
for ei in sequence_entries
|
||||
# pyre-ignore[16]
|
||||
if dataset.frame_annots[ei]["frame_annotation"].sequence_name
|
||||
== sequence_name
|
||||
]
|
||||
@@ -67,7 +69,7 @@ def get_implicitron_sequence_pointcloud(
|
||||
batch_size=len(sequence_dataset),
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
collate_fn=FrameData.collate,
|
||||
collate_fn=dataset.frame_data_type.collate,
|
||||
)
|
||||
|
||||
frame_data = next(iter(loader)) # there's only one batch
|
||||
|
||||
@@ -5,21 +5,17 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import copy
|
||||
import dataclasses
|
||||
import os
|
||||
from typing import cast, Optional
|
||||
from typing import Any, cast, Dict, List, Optional, Tuple
|
||||
|
||||
import lpips
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
ImplicitronDatasetBase,
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
|
||||
CO3D_CATEGORIES,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.utils import is_known_frame
|
||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
||||
aggregate_nvs_results,
|
||||
eval_batch,
|
||||
@@ -47,10 +43,12 @@ def main() -> None:
|
||||
"""
|
||||
|
||||
task_results = {}
|
||||
for task in ("singlesequence", "multisequence"):
|
||||
for task in (Task.SINGLE_SEQUENCE, Task.MULTI_SEQUENCE):
|
||||
task_results[task] = []
|
||||
for category in CO3D_CATEGORIES[: (20 if task == "singlesequence" else 10)]:
|
||||
for single_sequence_id in (0, 1) if task == "singlesequence" else (None,):
|
||||
for category in CO3D_CATEGORIES[: (20 if task == Task.SINGLE_SEQUENCE else 10)]:
|
||||
for single_sequence_id in (
|
||||
(0, 1) if task == Task.SINGLE_SEQUENCE else (None,)
|
||||
):
|
||||
category_result = evaluate_dbir_for_category(
|
||||
category, task=task, single_sequence_id=single_sequence_id
|
||||
)
|
||||
@@ -74,9 +72,9 @@ def main() -> None:
|
||||
|
||||
|
||||
def evaluate_dbir_for_category(
|
||||
category: str = "apple",
|
||||
bg_color: float = 0.0,
|
||||
task: str = "singlesequence",
|
||||
category: str,
|
||||
task: Task,
|
||||
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0),
|
||||
single_sequence_id: Optional[int] = None,
|
||||
num_workers: int = 16,
|
||||
):
|
||||
@@ -90,6 +88,7 @@ def evaluate_dbir_for_category(
|
||||
task: Evaluation task. Either singlesequence or multisequence.
|
||||
single_sequence_id: The ID of the evaluiation sequence for the singlesequence task.
|
||||
num_workers: The number of workers for the employed dataloaders.
|
||||
path_manager: (optional) Used for interpreting paths.
|
||||
|
||||
Returns:
|
||||
category_result: A dictionary of quantitative metrics.
|
||||
@@ -99,46 +98,35 @@ def evaluate_dbir_for_category(
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
if task not in ["multisequence", "singlesequence"]:
|
||||
raise ValueError("'task' has to be either 'multisequence' or 'singlesequence'")
|
||||
|
||||
datasets = dataset_zoo(
|
||||
category=category,
|
||||
dataset_root=os.environ["CO3D_DATASET_ROOT"],
|
||||
assert_single_seq=task == "singlesequence",
|
||||
dataset_name=f"co3d_{task}",
|
||||
test_on_train=False,
|
||||
load_point_clouds=True,
|
||||
test_restrict_sequence_id=single_sequence_id,
|
||||
dataset_map_provider_args = {
|
||||
"category": category,
|
||||
"dataset_root": os.environ["CO3D_DATASET_ROOT"],
|
||||
"assert_single_seq": task == Task.SINGLE_SEQUENCE,
|
||||
"task_str": task.value,
|
||||
"test_on_train": False,
|
||||
"test_restrict_sequence_id": single_sequence_id,
|
||||
"dataset_JsonIndexDataset_args": {"load_point_clouds": True},
|
||||
}
|
||||
data_source = ImplicitronDataSource(
|
||||
dataset_map_provider_JsonIndexDatasetMapProvider_args=dataset_map_provider_args
|
||||
)
|
||||
|
||||
dataloaders = dataloader_zoo(
|
||||
datasets,
|
||||
dataset_name=f"co3d_{task}",
|
||||
)
|
||||
datasets, dataloaders = data_source.get_datasets_and_dataloaders()
|
||||
|
||||
test_dataset = datasets["test"]
|
||||
test_dataloader = dataloaders["test"]
|
||||
test_dataset = datasets.test
|
||||
test_dataloader = dataloaders.test
|
||||
if test_dataset is None or test_dataloader is None:
|
||||
raise ValueError("must have a test dataset.")
|
||||
|
||||
if task == "singlesequence":
|
||||
# all_source_cameras are needed for evaluation of the
|
||||
# target camera difficulty
|
||||
# pyre-fixme[16]: `ImplicitronDataset` has no attribute `frame_annots`.
|
||||
sequence_name = test_dataset.frame_annots[0]["frame_annotation"].sequence_name
|
||||
all_source_cameras = _get_all_source_cameras(
|
||||
test_dataset, sequence_name, num_workers=num_workers
|
||||
)
|
||||
else:
|
||||
all_source_cameras = None
|
||||
|
||||
image_size = cast(ImplicitronDataset, test_dataset).image_width
|
||||
image_size = cast(JsonIndexDataset, test_dataset).image_width
|
||||
|
||||
if image_size is None:
|
||||
raise ValueError("Image size should be set in the dataset")
|
||||
|
||||
# init the simple DBIR model
|
||||
model = ModelDBIR(
|
||||
image_size=image_size,
|
||||
model = ModelDBIR( # pyre-ignore[28]: c’tor implicitly overridden
|
||||
render_image_width=image_size,
|
||||
render_image_height=image_size,
|
||||
bg_color=bg_color,
|
||||
max_points=int(1e5),
|
||||
)
|
||||
@@ -153,25 +141,31 @@ def evaluate_dbir_for_category(
|
||||
for frame_data in tqdm(test_dataloader):
|
||||
frame_data = dataclass_to_cuda_(frame_data)
|
||||
preds = model(**dataclasses.asdict(frame_data))
|
||||
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
|
||||
per_batch_eval_results.append(
|
||||
eval_batch(
|
||||
frame_data,
|
||||
nvs_prediction,
|
||||
preds["implicitron_render"],
|
||||
bg_color=bg_color,
|
||||
lpips_model=lpips_model,
|
||||
source_cameras=all_source_cameras,
|
||||
source_cameras=data_source.all_train_cameras,
|
||||
)
|
||||
)
|
||||
|
||||
if task == Task.SINGLE_SEQUENCE:
|
||||
camera_difficulty_bin_breaks = 0.97, 0.98
|
||||
else:
|
||||
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
|
||||
|
||||
category_result_flat, category_result = summarize_nvs_eval_results(
|
||||
per_batch_eval_results, task
|
||||
per_batch_eval_results, task, camera_difficulty_bin_breaks
|
||||
)
|
||||
|
||||
return category_result["results"]
|
||||
|
||||
|
||||
def _print_aggregate_results(task, task_results) -> None:
|
||||
def _print_aggregate_results(
|
||||
task: Task, task_results: Dict[Task, List[List[Dict[str, Any]]]]
|
||||
) -> None:
|
||||
"""
|
||||
Prints the aggregate metrics for a given task.
|
||||
"""
|
||||
@@ -182,35 +176,5 @@ def _print_aggregate_results(task, task_results) -> None:
|
||||
print("")
|
||||
|
||||
|
||||
def _get_all_source_cameras(
|
||||
dataset: ImplicitronDatasetBase, sequence_name: str, num_workers: int = 8
|
||||
):
|
||||
"""
|
||||
Loads all training cameras of a given sequence.
|
||||
|
||||
The set of all seen cameras is needed for evaluating the viewpoint difficulty
|
||||
for the singlescene evaluation.
|
||||
|
||||
Args:
|
||||
dataset: Co3D dataset object.
|
||||
sequence_name: The name of the sequence.
|
||||
num_workers: The number of for the utilized dataloader.
|
||||
"""
|
||||
|
||||
# load all source cameras of the sequence
|
||||
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
|
||||
dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
|
||||
(all_frame_data,) = torch.utils.data.DataLoader(
|
||||
dataset_for_loader,
|
||||
shuffle=False,
|
||||
batch_size=len(dataset_for_loader),
|
||||
num_workers=num_workers,
|
||||
collate_fn=FrameData.collate,
|
||||
)
|
||||
is_known = is_known_frame(all_frame_data.frame_type)
|
||||
source_cameras = all_frame_data.camera[torch.where(is_known)[0]]
|
||||
return source_cameras
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -9,12 +9,15 @@ import copy
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.implicitron.dataset.data_source import Task
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
|
||||
from pytorch3d.implicitron.models.base_model import ImplicitronRender
|
||||
from pytorch3d.implicitron.tools import vis_utils
|
||||
from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps
|
||||
from pytorch3d.implicitron.tools.image_utils import mask_background
|
||||
@@ -31,18 +34,6 @@ from visdom import Visdom
|
||||
EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]
|
||||
|
||||
|
||||
@dataclass
|
||||
class NewViewSynthesisPrediction:
|
||||
"""
|
||||
Holds the tensors that describe a result of synthesizing new views.
|
||||
"""
|
||||
|
||||
depth_render: Optional[torch.Tensor] = None
|
||||
image_render: Optional[torch.Tensor] = None
|
||||
mask_render: Optional[torch.Tensor] = None
|
||||
camera_distance: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Visualizer:
|
||||
image_render: torch.Tensor
|
||||
@@ -145,14 +136,14 @@ class _Visualizer:
|
||||
|
||||
def eval_batch(
|
||||
frame_data: FrameData,
|
||||
nvs_prediction: NewViewSynthesisPrediction,
|
||||
bg_color: Union[torch.Tensor, str, float] = "black",
|
||||
implicitron_render: ImplicitronRender,
|
||||
bg_color: Union[torch.Tensor, Sequence, str, float] = "black",
|
||||
mask_thr: float = 0.5,
|
||||
lpips_model=None,
|
||||
visualize: bool = False,
|
||||
visualize_visdom_env: str = "eval_debug",
|
||||
break_after_visualising: bool = True,
|
||||
source_cameras: Optional[List[CamerasBase]] = None,
|
||||
source_cameras: Optional[CamerasBase] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Produce performance metrics for a single batch of new-view synthesis
|
||||
@@ -162,14 +153,14 @@ def eval_batch(
|
||||
is True), a new-view synthesis method (NVS) is tasked to generate new views
|
||||
of the scene from the viewpoint of the target views (for which
|
||||
frame_data.frame_type.endswith('known') is False). The resulting
|
||||
synthesized new views, stored in `nvs_prediction`, are compared to the
|
||||
synthesized new views, stored in `implicitron_render`, are compared to the
|
||||
target ground truth in `frame_data` in terms of geometry and appearance
|
||||
resulting in a dictionary of metrics returned by the `eval_batch` function.
|
||||
|
||||
Args:
|
||||
frame_data: A FrameData object containing the input to the new view
|
||||
synthesis method.
|
||||
nvs_prediction: The data describing the synthesized new views.
|
||||
implicitron_render: The data describing the synthesized new views.
|
||||
bg_color: The background color of the generated new views and the
|
||||
ground truth.
|
||||
lpips_model: A pre-trained model for evaluating the LPIPS metric.
|
||||
@@ -184,26 +175,39 @@ def eval_batch(
|
||||
ValueError if frame_data does not have frame_type, camera, or image_rgb
|
||||
ValueError if the batch has a mix of training and test samples
|
||||
ValueError if the batch frames are not [unseen, known, known, ...]
|
||||
ValueError if one of the required fields in nvs_prediction is missing
|
||||
ValueError if one of the required fields in implicitron_render is missing
|
||||
"""
|
||||
REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render", "image_render", "depth_render"]
|
||||
frame_type = frame_data.frame_type
|
||||
if frame_type is None:
|
||||
raise ValueError("Frame type has not been set.")
|
||||
|
||||
# we check that all those fields are not None but Pyre can't infer that properly
|
||||
# TODO: assign to local variables
|
||||
# TODO: assign to local variables and simplify the code.
|
||||
if frame_data.image_rgb is None:
|
||||
raise ValueError("Image is not in the evaluation batch.")
|
||||
|
||||
if frame_data.camera is None:
|
||||
raise ValueError("Camera is not in the evaluation batch.")
|
||||
|
||||
if any(not hasattr(nvs_prediction, k) for k in REQUIRED_NVS_PREDICTION_FIELDS):
|
||||
raise ValueError("One of the required predicted fields is missing")
|
||||
# eval all results in the resolution of the frame_data image
|
||||
image_resol = tuple(frame_data.image_rgb.shape[2:])
|
||||
|
||||
# Post-process the render:
|
||||
# 1) check implicitron_render for Nones,
|
||||
# 2) obtain copies to make sure we dont edit the original data,
|
||||
# 3) take only the 1st (target) image
|
||||
# 4) resize to match ground-truth resolution
|
||||
cloned_render: Dict[str, torch.Tensor] = {}
|
||||
for k in ["mask_render", "image_render", "depth_render"]:
|
||||
field = getattr(implicitron_render, k)
|
||||
if field is None:
|
||||
raise ValueError(f"A required predicted field {k} is missing")
|
||||
|
||||
imode = "bilinear" if k == "image_render" else "nearest"
|
||||
cloned_render[k] = (
|
||||
F.interpolate(field[:1], size=image_resol, mode=imode).detach().clone()
|
||||
)
|
||||
|
||||
# obtain copies to make sure we dont edit the original data
|
||||
nvs_prediction = copy.deepcopy(nvs_prediction)
|
||||
frame_data = copy.deepcopy(frame_data)
|
||||
|
||||
# mask the ground truth depth in case frame_data contains the depth mask
|
||||
@@ -226,9 +230,6 @@ def eval_batch(
|
||||
+ " a target view while the rest should be source views."
|
||||
) # TODO: do we need to enforce this?
|
||||
|
||||
# take only the first (target image)
|
||||
for k in REQUIRED_NVS_PREDICTION_FIELDS:
|
||||
setattr(nvs_prediction, k, getattr(nvs_prediction, k)[:1])
|
||||
for k in [
|
||||
"depth_map",
|
||||
"image_rgb",
|
||||
@@ -242,10 +243,6 @@ def eval_batch(
|
||||
if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0:
|
||||
warnings.warn("Empty or missing depth map in evaluation!")
|
||||
|
||||
# eval all results in the resolution of the frame_data image
|
||||
# pyre-fixme[16]: `Optional` has no attribute `shape`.
|
||||
image_resol = list(frame_data.image_rgb.shape[2:])
|
||||
|
||||
# threshold the masks to make ground truth binary masks
|
||||
mask_fg, mask_crop = [
|
||||
(getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop")
|
||||
@@ -258,29 +255,14 @@ def eval_batch(
|
||||
bg_color=bg_color,
|
||||
)
|
||||
|
||||
# resize to the target resolution
|
||||
for k in REQUIRED_NVS_PREDICTION_FIELDS:
|
||||
imode = "bilinear" if k == "image_render" else "nearest"
|
||||
val = getattr(nvs_prediction, k)
|
||||
setattr(
|
||||
nvs_prediction,
|
||||
k,
|
||||
# pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
|
||||
# `List[typing.Any]`.
|
||||
torch.nn.functional.interpolate(val, size=image_resol, mode=imode),
|
||||
)
|
||||
|
||||
# clamp predicted images
|
||||
# pyre-fixme[16]: `Optional` has no attribute `clamp`.
|
||||
image_render = nvs_prediction.image_render.clamp(0.0, 1.0)
|
||||
image_render = cloned_render["image_render"].clamp(0.0, 1.0)
|
||||
|
||||
if visualize:
|
||||
visualizer = _Visualizer(
|
||||
image_render=image_render,
|
||||
image_rgb_masked=image_rgb_masked,
|
||||
# pyre-fixme[6]: Expected `Tensor` for 3rd param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
depth_render=nvs_prediction.depth_render,
|
||||
depth_render=cloned_render["depth_render"],
|
||||
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
depth_map=frame_data.depth_map,
|
||||
@@ -292,9 +274,7 @@ def eval_batch(
|
||||
results: Dict[str, Any] = {}
|
||||
|
||||
results["iou"] = iou(
|
||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
nvs_prediction.mask_render,
|
||||
cloned_render["mask_render"],
|
||||
mask_fg,
|
||||
mask=mask_crop,
|
||||
)
|
||||
@@ -321,11 +301,9 @@ def eval_batch(
|
||||
if name_postfix == "_fg":
|
||||
# only record depth metrics for the foreground
|
||||
_, abs_ = eval_depth(
|
||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
nvs_prediction.depth_render,
|
||||
# pyre-fixme[6]: Expected `Tensor` for 2nd param but got
|
||||
# `Optional[torch.Tensor]`.
|
||||
cloned_render["depth_render"],
|
||||
# pyre-fixme[6]: For 2nd param expected `Tensor` but got
|
||||
# `Optional[Tensor]`.
|
||||
frame_data.depth_map,
|
||||
get_best_scale=True,
|
||||
mask=loss_mask_now,
|
||||
@@ -336,14 +314,14 @@ def eval_batch(
|
||||
if visualize:
|
||||
visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now)
|
||||
if break_after_visualising:
|
||||
import pdb
|
||||
import pdb # noqa: B602
|
||||
|
||||
pdb.set_trace()
|
||||
|
||||
if lpips_model is not None:
|
||||
im1, im2 = [
|
||||
2.0 * im.clamp(0.0, 1.0) - 1.0
|
||||
for im in (image_rgb_masked, nvs_prediction.image_render)
|
||||
for im in (image_rgb_masked, cloned_render["image_render"])
|
||||
]
|
||||
results["lpips"] = lpips_model.forward(im1, im2).item()
|
||||
|
||||
@@ -426,30 +404,24 @@ def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tenso
|
||||
Returns:
|
||||
single-element Tensor
|
||||
"""
|
||||
# pyre-ignore[16] topk not recognized
|
||||
return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
|
||||
|
||||
|
||||
def get_camera_difficulty_bin_edges(task: str):
|
||||
def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float, float]):
|
||||
"""
|
||||
Get the edges of camera difficulty bins.
|
||||
"""
|
||||
_eps = 1e-5
|
||||
if task == "multisequence":
|
||||
# TODO: extract those to constants
|
||||
diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
|
||||
diff_bin_edges[0] = 0.0 - _eps
|
||||
elif task == "singlesequence":
|
||||
diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
|
||||
else:
|
||||
raise ValueError(f"No such eval task {task}.")
|
||||
lower, upper = camera_difficulty_bin_breaks
|
||||
diff_bin_edges = torch.tensor([0.0 - _eps, lower, upper, 1.0 + _eps]).float()
|
||||
diff_bin_names = ["hard", "medium", "easy"]
|
||||
return diff_bin_edges, diff_bin_names
|
||||
|
||||
|
||||
def summarize_nvs_eval_results(
|
||||
per_batch_eval_results: List[Dict[str, Any]],
|
||||
task: str = "singlesequence",
|
||||
task: Task,
|
||||
camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
|
||||
):
|
||||
"""
|
||||
Compile the per-batch evaluation results `per_batch_eval_results` into
|
||||
@@ -458,7 +430,8 @@ def summarize_nvs_eval_results(
|
||||
Args:
|
||||
per_batch_eval_results: Metrics of each per-batch evaluation.
|
||||
task: The type of the new-view synthesis task.
|
||||
Either 'singlesequence' or 'multisequence'.
|
||||
camera_difficulty_bin_breaks: edge hard-medium and medium-easy
|
||||
|
||||
|
||||
Returns:
|
||||
nvs_results_flat: A flattened dict of all aggregate metrics.
|
||||
@@ -466,10 +439,10 @@ def summarize_nvs_eval_results(
|
||||
"""
|
||||
n_batches = len(per_batch_eval_results)
|
||||
eval_sets: List[Optional[str]] = []
|
||||
if task == "singlesequence":
|
||||
if task == Task.SINGLE_SEQUENCE:
|
||||
eval_sets = [None]
|
||||
# assert n_batches==100
|
||||
elif task == "multisequence":
|
||||
elif task == Task.MULTI_SEQUENCE:
|
||||
eval_sets = ["train", "test"]
|
||||
# assert n_batches==1000
|
||||
else:
|
||||
@@ -485,17 +458,19 @@ def summarize_nvs_eval_results(
|
||||
# init the result database dict
|
||||
results = []
|
||||
|
||||
diff_bin_edges, diff_bin_names = get_camera_difficulty_bin_edges(task)
|
||||
diff_bin_edges, diff_bin_names = _get_camera_difficulty_bin_edges(
|
||||
camera_difficulty_bin_breaks
|
||||
)
|
||||
n_diff_edges = diff_bin_edges.numel()
|
||||
|
||||
# add per set averages
|
||||
for SET in eval_sets:
|
||||
if SET is None:
|
||||
# task=='singlesequence'
|
||||
assert task == Task.SINGLE_SEQUENCE
|
||||
ok_set = torch.ones(n_batches, dtype=torch.bool)
|
||||
set_name = "test"
|
||||
else:
|
||||
# task=='multisequence'
|
||||
assert task == Task.MULTI_SEQUENCE
|
||||
ok_set = is_train == int(SET == "train")
|
||||
set_name = SET
|
||||
|
||||
@@ -520,7 +495,7 @@ def summarize_nvs_eval_results(
|
||||
}
|
||||
)
|
||||
|
||||
if task == "multisequence":
|
||||
if task == Task.MULTI_SEQUENCE:
|
||||
# split based on n_src_views
|
||||
n_src_views = batch_sizes - 1
|
||||
for n_src in EVAL_N_SRC_VIEWS:
|
||||
|
||||
88
pytorch3d/implicitron/models/base_model.py
Normal file
88
pytorch3d/implicitron/models/base_model.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .renderer.base import EvaluationMode
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImplicitronRender:
|
||||
"""
|
||||
Holds the tensors that describe a result of rendering.
|
||||
"""
|
||||
|
||||
depth_render: Optional[torch.Tensor] = None
|
||||
image_render: Optional[torch.Tensor] = None
|
||||
mask_render: Optional[torch.Tensor] = None
|
||||
camera_distance: Optional[torch.Tensor] = None
|
||||
|
||||
def clone(self) -> "ImplicitronRender":
|
||||
def safe_clone(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
||||
return t.detach().clone() if t is not None else None
|
||||
|
||||
return ImplicitronRender(
|
||||
depth_render=safe_clone(self.depth_render),
|
||||
image_render=safe_clone(self.image_render),
|
||||
mask_render=safe_clone(self.mask_render),
|
||||
camera_distance=safe_clone(self.camera_distance),
|
||||
)
|
||||
|
||||
|
||||
class ImplicitronModelBase(ReplaceableBase):
|
||||
"""
|
||||
Replaceable abstract base for all image generation / rendering models.
|
||||
`forward()` method produces a render with a depth map.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*, # force keyword-only arguments
|
||||
image_rgb: Optional[torch.Tensor],
|
||||
camera: CamerasBase,
|
||||
fg_probability: Optional[torch.Tensor],
|
||||
mask_crop: Optional[torch.Tensor],
|
||||
depth_map: Optional[torch.Tensor],
|
||||
sequence_name: Optional[List[str]],
|
||||
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Args:
|
||||
image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images;
|
||||
the first `min(B, n_train_target_views)` images are considered targets and
|
||||
are used to supervise the renders; the rest corresponding to the source
|
||||
viewpoints from which features will be extracted.
|
||||
camera: An instance of CamerasBase containing a batch of `B` cameras corresponding
|
||||
to the viewpoints of target images, from which the rays will be sampled,
|
||||
and source images, which will be used for intersecting with target rays.
|
||||
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
|
||||
foreground masks.
|
||||
mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid
|
||||
regions in the input images (i.e. regions that do not correspond
|
||||
to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
|
||||
"mask_sample", rays will be sampled in the non zero regions.
|
||||
depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
|
||||
sequence_name: A list of `B` strings corresponding to the sequence names
|
||||
from which images `image_rgb` were extracted. They are used to match
|
||||
target frames with relevant source frames.
|
||||
evaluation_mode: one of EvaluationMode.TRAINING or
|
||||
EvaluationMode.EVALUATION which determines the settings used for
|
||||
rendering.
|
||||
|
||||
Returns:
|
||||
preds: A dictionary containing all outputs of the forward pass. All models should
|
||||
output an instance of `ImplicitronRender` in `preds["implicitron_render"]`.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .feature_extractor import FeatureExtractorBase
|
||||
@@ -0,0 +1,44 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||
|
||||
|
||||
class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
|
||||
"""
|
||||
Base class for an extractor of a set of features from images.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def get_feat_dims(self) -> int:
|
||||
"""
|
||||
Returns:
|
||||
total number of feature dimensions of the output.
|
||||
(i.e. sum_i(dim_i))
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
self,
|
||||
imgs: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Dict[Any, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
imgs: A batch of input images of shape `(B, 3, H, W)`.
|
||||
masks: A batch of input masks of shape `(B, 3, H, W)`.
|
||||
|
||||
Returns:
|
||||
out_feats: A dict `{f_i: t_i}` keyed by predicted feature names `f_i`
|
||||
and their corresponding tensors `t_i` of shape `(B, dim_i, H_i, W_i)`.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -4,7 +4,6 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
@@ -12,7 +11,9 @@ from typing import Any, Dict, Optional, Tuple
|
||||
import torch
|
||||
import torch.nn.functional as Fu
|
||||
import torchvision
|
||||
from pytorch3d.implicitron.tools.config import Configurable
|
||||
from pytorch3d.implicitron.tools.config import registry
|
||||
|
||||
from . import FeatureExtractorBase
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -32,7 +33,8 @@ _RESNET_MEAN = [0.485, 0.456, 0.406]
|
||||
_RESNET_STD = [0.229, 0.224, 0.225]
|
||||
|
||||
|
||||
class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
||||
@registry.register
|
||||
class ResNetFeatureExtractor(FeatureExtractorBase):
|
||||
"""
|
||||
Implements an image feature extractor. Depending on the settings allows
|
||||
to extract:
|
||||
@@ -141,14 +143,15 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
||||
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
|
||||
return (img - self._resnet_mean) / self._resnet_std
|
||||
|
||||
def get_feat_dims(self, size_dict: bool = False):
|
||||
if size_dict:
|
||||
return copy.deepcopy(self._feat_dim)
|
||||
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.values)[[Na...
|
||||
def get_feat_dims(self) -> int:
|
||||
# pyre-fixme[29]
|
||||
return sum(self._feat_dim.values())
|
||||
|
||||
def forward(
|
||||
self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None
|
||||
self,
|
||||
imgs: Optional[torch.Tensor],
|
||||
masks: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> Dict[Any, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
@@ -163,23 +166,22 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
||||
out_feats = {}
|
||||
|
||||
imgs_input = imgs
|
||||
if self.image_rescale != 1.0:
|
||||
if self.image_rescale != 1.0 and imgs_input is not None:
|
||||
imgs_resized = Fu.interpolate(
|
||||
imgs_input,
|
||||
# pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but
|
||||
# got `float`.
|
||||
scale_factor=self.image_rescale,
|
||||
mode="bilinear",
|
||||
)
|
||||
else:
|
||||
imgs_resized = imgs_input
|
||||
|
||||
if self.normalize_image:
|
||||
imgs_normed = self._resnet_normalize_image(imgs_resized)
|
||||
else:
|
||||
imgs_normed = imgs_resized
|
||||
|
||||
if len(self.stages) > 0:
|
||||
assert imgs_resized is not None
|
||||
|
||||
if self.normalize_image:
|
||||
imgs_normed = self._resnet_normalize_image(imgs_resized)
|
||||
else:
|
||||
imgs_normed = imgs_resized
|
||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]`
|
||||
# is not a function.
|
||||
feats = self.stem(imgs_normed)
|
||||
@@ -206,7 +208,7 @@ class ResNetFeatureExtractor(Configurable, torch.nn.Module):
|
||||
out_feats[MASK_FEATURE_NAME] = masks
|
||||
|
||||
if self.add_images:
|
||||
assert imgs_input is not None
|
||||
assert imgs_resized is not None
|
||||
out_feats[IMAGE_FEATURE_NAME] = imgs_resized
|
||||
|
||||
if self.feature_rescale != 1.0:
|
||||
@@ -5,26 +5,39 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
# Note: The #noqa comments below are for unused imports of pluggable implementations
|
||||
# which are part of implicitron. They ensure that the registry is prepopulated.
|
||||
|
||||
import logging
|
||||
import math
|
||||
import warnings
|
||||
from dataclasses import field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import tqdm
|
||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
||||
NewViewSynthesisPrediction,
|
||||
from pytorch3d.implicitron.models.metrics import ( # noqa
|
||||
RegularizationMetrics,
|
||||
RegularizationMetricsBase,
|
||||
ViewMetrics,
|
||||
ViewMetricsBase,
|
||||
)
|
||||
from pytorch3d.implicitron.tools import image_utils, vis_utils
|
||||
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
expand_args_fields,
|
||||
registry,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
|
||||
from pytorch3d.implicitron.tools.utils import cat_dataclass
|
||||
from pytorch3d.implicitron.tools.utils import cat_dataclass, setattr_if_hasattr
|
||||
from pytorch3d.renderer import RayBundle, utils as rend_utils
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from visdom import Visdom
|
||||
|
||||
from .autodecoder import Autodecoder
|
||||
from .base_model import ImplicitronModelBase, ImplicitronRender
|
||||
from .feature_extractor import FeatureExtractorBase
|
||||
from .feature_extractor.resnet_feature_extractor import ResNetFeatureExtractor # noqa
|
||||
from .global_encoder.global_encoder import GlobalEncoderBase
|
||||
from .implicit_function.base import ImplicitFunctionBase
|
||||
from .implicit_function.idr_feature_field import IdrFeatureField # noqa
|
||||
from .implicit_function.neural_radiance_field import ( # noqa
|
||||
@@ -35,7 +48,7 @@ from .implicit_function.scene_representation_networks import ( # noqa
|
||||
SRNHyperNetImplicitFunction,
|
||||
SRNImplicitFunction,
|
||||
)
|
||||
from .metrics import ViewMetrics
|
||||
|
||||
from .renderer.base import (
|
||||
BaseRenderer,
|
||||
EvaluationMode,
|
||||
@@ -45,19 +58,16 @@ from .renderer.base import (
|
||||
)
|
||||
from .renderer.lstm_renderer import LSTMRenderer # noqa
|
||||
from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
|
||||
from .renderer.ray_sampler import RaySampler
|
||||
from .renderer.ray_sampler import RaySamplerBase
|
||||
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
|
||||
from .resnet_feature_extractor import ResNetFeatureExtractor
|
||||
from .view_pooling.feature_aggregation import FeatureAggregatorBase
|
||||
from .view_pooling.view_sampling import ViewSampler
|
||||
from .view_pooler.view_pooler import ViewPooler
|
||||
|
||||
|
||||
STD_LOG_VARS = ["objective", "epoch", "sec/it"]
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# pyre-ignore: 13
|
||||
class GenericModel(Configurable, torch.nn.Module):
|
||||
@registry.register
|
||||
class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||
"""
|
||||
GenericModel is a wrapper for the neural implicit
|
||||
rendering and reconstruction pipeline which consists
|
||||
@@ -98,6 +108,7 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
------------------
|
||||
Evaluate the implicit function(s) at the sampled ray points
|
||||
(optionally pass in the aggregated image features from (4)).
|
||||
(also optionally pass in a global encoding from global_encoder).
|
||||
│
|
||||
▼
|
||||
(6) Rendering
|
||||
@@ -116,7 +127,7 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
this sequence of steps. Currently, steps 1, 3, 4, 5, 6
|
||||
can be customized by intializing a subclass of the appropriate
|
||||
baseclass and adding the newly created module to the registry.
|
||||
Please see https://github.com/fairinternal/pytorch3d/blob/co3d/projects/implicitron_trainer/README.md#custom-plugins
|
||||
Please see https://github.com/facebookresearch/pytorch3d/blob/main/projects/implicitron_trainer/README.md#custom-plugins
|
||||
for more details on how to create and register a custom component.
|
||||
|
||||
In the config .yaml files for experiments, the parameters below are
|
||||
@@ -138,9 +149,6 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
|
||||
splatting onto an image grid. Default: False.
|
||||
bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0)
|
||||
view_pool: If True, features are sampled from the source image(s)
|
||||
at the projected 2d locations of the sampled 3d ray points from the target
|
||||
view(s), i.e. this activates step (3) above.
|
||||
num_passes: The specified implicit_function is initialized num_passes
|
||||
times and run sequentially.
|
||||
chunk_size_grid: The total number of points which can be rendered
|
||||
@@ -155,37 +163,49 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
sampling_mode_training: The sampling method to use during training. Must be
|
||||
a value from the RenderSamplingMode Enum.
|
||||
sampling_mode_evaluation: Same as above but for evaluation.
|
||||
sequence_autodecoder: An instance of `Autodecoder`. This is used to generate an encoding
|
||||
global_encoder_class_type: The name of the class to use for global_encoder,
|
||||
which must be available in the registry. Or `None` to disable global encoder.
|
||||
global_encoder: An instance of `GlobalEncoder`. This is used to generate an encoding
|
||||
of the image (referred to as the global_code) that can be used to model aspects of
|
||||
the scene such as multiple objects or morphing objects. It is up to the implicit
|
||||
function definition how to use it, but the most typical way is to broadcast and
|
||||
concatenate to the other inputs for the implicit function.
|
||||
raysampler_class_type: The name of the raysampler class which is available
|
||||
in the global registry.
|
||||
raysampler: An instance of RaySampler which is used to emit
|
||||
rays from the target view(s).
|
||||
renderer_class_type: The name of the renderer class which is available in the global
|
||||
registry.
|
||||
renderer: A renderer class which inherits from BaseRenderer. This is used to
|
||||
generate the images from the target view(s).
|
||||
image_feature_extractor_class_type: If a str, constructs and enables
|
||||
the `image_feature_extractor` object of this type. Or None if not needed.
|
||||
image_feature_extractor: A module for extrating features from an input image.
|
||||
view_sampler: An instance of ViewSampler which is used for sampling of
|
||||
view_pooler_enabled: If `True`, constructs and enables the `view_pooler` object.
|
||||
This means features are sampled from the source image(s)
|
||||
at the projected 2d locations of the sampled 3d ray points from the target
|
||||
view(s), i.e. this activates step (3) above.
|
||||
view_pooler: An instance of ViewPooler which is used for sampling of
|
||||
image-based features at the 2D projections of a set
|
||||
of 3D points.
|
||||
feature_aggregator_class_type: The name of the feature aggregator class which
|
||||
is available in the global registry.
|
||||
feature_aggregator: A feature aggregator class which inherits from
|
||||
FeatureAggregatorBase. Typically, the aggregated features and their
|
||||
masks are output by a `ViewSampler` which samples feature tensors extracted
|
||||
from a set of source images. FeatureAggregator executes step (4) above.
|
||||
of 3D points and aggregating the sampled features.
|
||||
implicit_function_class_type: The type of implicit function to use which
|
||||
is available in the global registry.
|
||||
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
|
||||
are initialised to be in self._implicit_functions.
|
||||
view_metrics: An instance of ViewMetricsBase used to compute loss terms which
|
||||
are independent of the model's parameters.
|
||||
view_metrics_class_type: The type of view metrics to use, must be available in
|
||||
the global registry.
|
||||
regularization_metrics: An instance of RegularizationMetricsBase used to compute
|
||||
regularization terms which can depend on the model's parameters.
|
||||
regularization_metrics_class_type: The type of regularization metrics to use,
|
||||
must be available in the global registry.
|
||||
loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
|
||||
for `ViewMetrics` class for available loss functions.
|
||||
log_vars: A list of variable names which should be logged.
|
||||
The names should correspond to a subset of the keys of the
|
||||
dict `preds` output by the `forward` function.
|
||||
"""
|
||||
""" # noqa: B950
|
||||
|
||||
mask_images: bool = True
|
||||
mask_depths: bool = True
|
||||
@@ -194,7 +214,6 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
mask_threshold: float = 0.5
|
||||
output_rasterized_mc: bool = False
|
||||
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||
view_pool: bool = False
|
||||
num_passes: int = 1
|
||||
chunk_size_grid: int = 4096
|
||||
render_features_dimensions: int = 3
|
||||
@@ -204,23 +223,25 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
sampling_mode_training: str = "mask_sample"
|
||||
sampling_mode_evaluation: str = "full_grid"
|
||||
|
||||
# ---- autodecoder settings
|
||||
sequence_autodecoder: Autodecoder
|
||||
# ---- global encoder settings
|
||||
global_encoder_class_type: Optional[str] = None
|
||||
global_encoder: Optional[GlobalEncoderBase]
|
||||
|
||||
# ---- raysampler
|
||||
raysampler: RaySampler
|
||||
raysampler_class_type: str = "AdaptiveRaySampler"
|
||||
raysampler: RaySamplerBase
|
||||
|
||||
# ---- renderer configs
|
||||
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
|
||||
renderer: BaseRenderer
|
||||
|
||||
# ---- view sampling settings - used if view_pool=True
|
||||
# (This is only created if view_pool is False)
|
||||
image_feature_extractor: ResNetFeatureExtractor
|
||||
view_sampler: ViewSampler
|
||||
# ---- ---- view sampling feature aggregator settings
|
||||
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
|
||||
feature_aggregator: FeatureAggregatorBase
|
||||
# ---- image feature extractor settings
|
||||
# (This is only created if view_pooler is enabled)
|
||||
image_feature_extractor: Optional[FeatureExtractorBase]
|
||||
image_feature_extractor_class_type: Optional[str] = None
|
||||
# ---- view pooler settings
|
||||
view_pooler_enabled: bool = False
|
||||
view_pooler: Optional[ViewPooler]
|
||||
|
||||
# ---- implicit function settings
|
||||
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
|
||||
@@ -228,6 +249,13 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
# The actual implicit functions live in self._implicit_functions
|
||||
implicit_function: ImplicitFunctionBase
|
||||
|
||||
# ----- metrics
|
||||
view_metrics: ViewMetricsBase
|
||||
view_metrics_class_type: str = "ViewMetrics"
|
||||
|
||||
regularization_metrics: RegularizationMetricsBase
|
||||
regularization_metrics_class_type: str = "RegularizationMetrics"
|
||||
|
||||
# ---- loss weights
|
||||
loss_weights: Dict[str, float] = field(
|
||||
default_factory=lambda: {
|
||||
@@ -259,19 +287,21 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
"loss_prev_stage_rgb_psnr_fg",
|
||||
"loss_prev_stage_rgb_psnr",
|
||||
"loss_prev_stage_mask_bce",
|
||||
*STD_LOG_VARS,
|
||||
# basic metrics
|
||||
"objective",
|
||||
"epoch",
|
||||
"sec/it",
|
||||
]
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
self.view_metrics = ViewMetrics()
|
||||
|
||||
self._check_and_preprocess_renderer_configs()
|
||||
self.raysampler_args["sampling_mode_training"] = self.sampling_mode_training
|
||||
self.raysampler_args["sampling_mode_evaluation"] = self.sampling_mode_evaluation
|
||||
self.raysampler_args["image_width"] = self.render_image_width
|
||||
self.raysampler_args["image_height"] = self.render_image_height
|
||||
if self.view_pooler_enabled:
|
||||
if self.image_feature_extractor_class_type is None:
|
||||
raise ValueError(
|
||||
"image_feature_extractor must be present for view pooling."
|
||||
)
|
||||
run_auto_creation(self)
|
||||
|
||||
self._implicit_functions = self._construct_implicit_functions()
|
||||
@@ -283,10 +313,11 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
*, # force keyword-only arguments
|
||||
image_rgb: Optional[torch.Tensor],
|
||||
camera: CamerasBase,
|
||||
fg_probability: Optional[torch.Tensor],
|
||||
mask_crop: Optional[torch.Tensor],
|
||||
depth_map: Optional[torch.Tensor],
|
||||
sequence_name: Optional[List[str]],
|
||||
fg_probability: Optional[torch.Tensor] = None,
|
||||
mask_crop: Optional[torch.Tensor] = None,
|
||||
depth_map: Optional[torch.Tensor] = None,
|
||||
sequence_name: Optional[List[str]] = None,
|
||||
frame_timestamp: Optional[torch.Tensor] = None,
|
||||
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
@@ -309,6 +340,8 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
sequence_name: A list of `B` strings corresponding to the sequence names
|
||||
from which images `image_rgb` were extracted. They are used to match
|
||||
target frames with relevant source frames.
|
||||
frame_timestamp: Optionally a tensor of shape `(B,)` containing a batch
|
||||
of frame timestamps.
|
||||
evaluation_mode: one of EvaluationMode.TRAINING or
|
||||
EvaluationMode.EVALUATION which determines the settings used for
|
||||
rendering.
|
||||
@@ -333,6 +366,13 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
else min(self.n_train_target_views, batch_size)
|
||||
)
|
||||
|
||||
# A helper function for selecting n_target first elements from the input
|
||||
# where the latter can be None.
|
||||
def safe_slice_targets(
|
||||
tensor: Optional[Union[torch.Tensor, List[str]]],
|
||||
) -> Optional[Union[torch.Tensor, List[str]]]:
|
||||
return None if tensor is None else tensor[:n_targets]
|
||||
|
||||
# Select the target cameras.
|
||||
target_cameras = camera[list(range(n_targets))]
|
||||
|
||||
@@ -344,7 +384,7 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
)
|
||||
|
||||
# (1) Sample rendering rays with the ray sampler.
|
||||
ray_bundle: RayBundle = self.raysampler(
|
||||
ray_bundle: RayBundle = self.raysampler( # pyre-fixme[29]
|
||||
target_cameras,
|
||||
evaluation_mode,
|
||||
mask=mask_crop[:n_targets]
|
||||
@@ -355,38 +395,37 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
# custom_args hold additional arguments to the implicit function.
|
||||
custom_args = {}
|
||||
|
||||
if self.view_pool:
|
||||
if sequence_name is None:
|
||||
raise ValueError("sequence_name must be provided for view pooling")
|
||||
if self.image_feature_extractor is not None:
|
||||
# (2) Extract features for the image
|
||||
img_feats = self.image_feature_extractor(image_rgb, fg_probability)
|
||||
else:
|
||||
img_feats = None
|
||||
|
||||
# (3) Sample features and masks at the ray points
|
||||
curried_view_sampler = lambda pts: self.view_sampler( # noqa: E731
|
||||
pts=pts,
|
||||
seq_id_pts=sequence_name[:n_targets],
|
||||
camera=camera,
|
||||
seq_id_camera=sequence_name,
|
||||
feats=img_feats,
|
||||
masks=mask_crop,
|
||||
) # returns feats_sampled, masks_sampled
|
||||
if self.view_pooler_enabled:
|
||||
if sequence_name is None:
|
||||
raise ValueError("sequence_name must be provided for view pooling")
|
||||
assert img_feats is not None
|
||||
|
||||
# (4) Aggregate features from multiple views
|
||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||
curried_view_pool = lambda pts: self.feature_aggregator( # noqa: E731
|
||||
*curried_view_sampler(pts=pts),
|
||||
pts=pts,
|
||||
camera=camera,
|
||||
) # TODO: do we need to pass a callback rather than compute here?
|
||||
# precomputing will be faster for 2 passes
|
||||
# -> but this is important for non-nerf
|
||||
custom_args["fun_viewpool"] = curried_view_pool
|
||||
# (3-4) Sample features and masks at the ray points.
|
||||
# Aggregate features from multiple views.
|
||||
def curried_viewpooler(pts):
|
||||
return self.view_pooler(
|
||||
pts=pts,
|
||||
seq_id_pts=sequence_name[:n_targets],
|
||||
camera=camera,
|
||||
seq_id_camera=sequence_name,
|
||||
feats=img_feats,
|
||||
masks=mask_crop,
|
||||
)
|
||||
|
||||
custom_args["fun_viewpool"] = curried_viewpooler
|
||||
|
||||
global_code = None
|
||||
if self.sequence_autodecoder.n_instances > 0:
|
||||
if sequence_name is None:
|
||||
raise ValueError("sequence_name must be provided for autodecoder.")
|
||||
global_code = self.sequence_autodecoder(sequence_name[:n_targets])
|
||||
if self.global_encoder is not None:
|
||||
global_code = self.global_encoder( # pyre-fixme[29]
|
||||
sequence_name=safe_slice_targets(sequence_name),
|
||||
frame_timestamp=safe_slice_targets(frame_timestamp),
|
||||
)
|
||||
custom_args["global_code"] = global_code
|
||||
|
||||
# pyre-fixme[29]:
|
||||
@@ -422,15 +461,26 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
for func in self._implicit_functions:
|
||||
func.unbind_args()
|
||||
|
||||
preds = self._get_view_metrics(
|
||||
raymarched=rendered,
|
||||
xys=ray_bundle.xys,
|
||||
image_rgb=None if image_rgb is None else image_rgb[:n_targets],
|
||||
depth_map=None if depth_map is None else depth_map[:n_targets],
|
||||
fg_probability=None
|
||||
if fg_probability is None
|
||||
else fg_probability[:n_targets],
|
||||
mask_crop=None if mask_crop is None else mask_crop[:n_targets],
|
||||
# A dict to store losses as well as rendering results.
|
||||
preds: Dict[str, Any] = {}
|
||||
|
||||
preds.update(
|
||||
self.view_metrics(
|
||||
results=preds,
|
||||
raymarched=rendered,
|
||||
xys=ray_bundle.xys,
|
||||
image_rgb=safe_slice_targets(image_rgb),
|
||||
depth_map=safe_slice_targets(depth_map),
|
||||
fg_probability=safe_slice_targets(fg_probability),
|
||||
mask_crop=safe_slice_targets(mask_crop),
|
||||
)
|
||||
)
|
||||
|
||||
preds.update(
|
||||
self.regularization_metrics(
|
||||
results=preds,
|
||||
model=self,
|
||||
)
|
||||
)
|
||||
|
||||
if sampling_mode == RenderSamplingMode.MASK_SAMPLE:
|
||||
@@ -452,7 +502,7 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2)
|
||||
preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2)
|
||||
|
||||
preds["nvs_prediction"] = NewViewSynthesisPrediction(
|
||||
preds["implicitron_render"] = ImplicitronRender(
|
||||
image_render=preds["images_render"],
|
||||
depth_render=preds["depths_render"],
|
||||
mask_render=preds["masks_render"],
|
||||
@@ -460,11 +510,6 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
else:
|
||||
raise AssertionError("Unreachable state")
|
||||
|
||||
# calc the AD penalty, returns None if autodecoder is not active
|
||||
ad_penalty = self.sequence_autodecoder.calc_squared_encoding_norm()
|
||||
if ad_penalty is not None:
|
||||
preds["loss_autodecoder_norm"] = ad_penalty
|
||||
|
||||
# (7) Compute losses
|
||||
# finally get the optimization objective using self.loss_weights
|
||||
objective = self._get_objective(preds)
|
||||
@@ -559,38 +604,65 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _get_viewpooled_feature_dim(self):
|
||||
return (
|
||||
self.feature_aggregator.get_aggregated_feature_dim(
|
||||
self.image_feature_extractor.get_feat_dims()
|
||||
)
|
||||
if self.view_pool
|
||||
else 0
|
||||
def _get_global_encoder_encoding_dim(self) -> int:
|
||||
if self.global_encoder is None:
|
||||
return 0
|
||||
return self.global_encoder.get_encoding_dim()
|
||||
|
||||
def _get_viewpooled_feature_dim(self) -> int:
|
||||
if self.view_pooler is None:
|
||||
return 0
|
||||
assert self.image_feature_extractor is not None
|
||||
return self.view_pooler.get_aggregated_feature_dim(
|
||||
self.image_feature_extractor.get_feat_dims()
|
||||
)
|
||||
|
||||
def _check_and_preprocess_renderer_configs(self):
|
||||
def create_raysampler(self):
|
||||
raysampler_args = getattr(
|
||||
self, "raysampler_" + self.raysampler_class_type + "_args"
|
||||
)
|
||||
setattr_if_hasattr(
|
||||
raysampler_args, "sampling_mode_training", self.sampling_mode_training
|
||||
)
|
||||
setattr_if_hasattr(
|
||||
raysampler_args, "sampling_mode_evaluation", self.sampling_mode_evaluation
|
||||
)
|
||||
setattr_if_hasattr(raysampler_args, "image_width", self.render_image_width)
|
||||
setattr_if_hasattr(raysampler_args, "image_height", self.render_image_height)
|
||||
self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
|
||||
**raysampler_args
|
||||
)
|
||||
|
||||
def create_renderer(self):
|
||||
raysampler_args = getattr(
|
||||
self, "raysampler_" + self.raysampler_class_type + "_args"
|
||||
)
|
||||
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
|
||||
"stratified_sampling_coarse_training"
|
||||
] = self.raysampler_args["stratified_point_sampling_training"]
|
||||
] = raysampler_args["stratified_point_sampling_training"]
|
||||
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
|
||||
"stratified_sampling_coarse_evaluation"
|
||||
] = self.raysampler_args["stratified_point_sampling_evaluation"]
|
||||
] = raysampler_args["stratified_point_sampling_evaluation"]
|
||||
self.renderer_SignedDistanceFunctionRenderer_args[
|
||||
"render_features_dimensions"
|
||||
] = self.render_features_dimensions
|
||||
self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
|
||||
"object_bounding_sphere"
|
||||
] = self.raysampler_args["scene_extent"]
|
||||
|
||||
def create_image_feature_extractor(self):
|
||||
"""
|
||||
Custom creation function called by run_auto_creation so that the
|
||||
image_feature_extractor is not created if it is not be needed.
|
||||
"""
|
||||
if self.view_pool:
|
||||
self.image_feature_extractor = ResNetFeatureExtractor(
|
||||
**self.image_feature_extractor_args
|
||||
)
|
||||
if self.renderer_class_type == "SignedDistanceFunctionRenderer":
|
||||
if "scene_extent" not in raysampler_args:
|
||||
raise ValueError(
|
||||
"SignedDistanceFunctionRenderer requires"
|
||||
+ " a raysampler that defines the 'scene_extent' field"
|
||||
+ " (this field is supported by, e.g., the adaptive raysampler - "
|
||||
+ " self.raysampler_class_type='AdaptiveRaySampler')."
|
||||
)
|
||||
self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
|
||||
"object_bounding_sphere"
|
||||
] = self.raysampler_AdaptiveRaySampler_args["scene_extent"]
|
||||
|
||||
renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
|
||||
self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
|
||||
**renderer_args
|
||||
)
|
||||
|
||||
def create_implicit_function(self) -> None:
|
||||
"""
|
||||
@@ -613,8 +685,7 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
|
||||
nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
|
||||
nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
|
||||
self._get_viewpooled_feature_dim()
|
||||
+ self.sequence_autodecoder.get_encoding_dim()
|
||||
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
|
||||
)
|
||||
nerf_args["color_dim"] = nerformer_args[
|
||||
"color_dim"
|
||||
@@ -623,27 +694,25 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
# idr preprocessing
|
||||
idr = self.implicit_function_IdrFeatureField_args
|
||||
idr["feature_vector_size"] = self.render_features_dimensions
|
||||
idr["encoding_dim"] = self.sequence_autodecoder.get_encoding_dim()
|
||||
idr["encoding_dim"] = self._get_global_encoder_encoding_dim()
|
||||
|
||||
# srn preprocessing
|
||||
srn = self.implicit_function_SRNImplicitFunction_args
|
||||
srn.raymarch_function_args.latent_dim = (
|
||||
self._get_viewpooled_feature_dim()
|
||||
+ self.sequence_autodecoder.get_encoding_dim()
|
||||
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
|
||||
)
|
||||
|
||||
# srn_hypernet preprocessing
|
||||
srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
|
||||
srn_hypernet_args = srn_hypernet.hypernet_args
|
||||
srn_hypernet_args.latent_dim_hypernet = (
|
||||
self.sequence_autodecoder.get_encoding_dim()
|
||||
)
|
||||
srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim()
|
||||
srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
|
||||
|
||||
# check that for srn, srn_hypernet, idr we have self.num_passes=1
|
||||
implicit_function_type = registry.get(
|
||||
ImplicitFunctionBase, self.implicit_function_class_type
|
||||
)
|
||||
expand_args_fields(implicit_function_type)
|
||||
if self.num_passes != 1 and not implicit_function_type.allows_multiple_passes():
|
||||
raise ValueError(
|
||||
self.implicit_function_class_type
|
||||
@@ -651,10 +720,9 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
)
|
||||
|
||||
if implicit_function_type.requires_pooling_without_aggregation():
|
||||
has_aggregation = hasattr(self.feature_aggregator, "reduction_functions")
|
||||
if not self.view_pool or has_aggregation:
|
||||
if self.view_pooler_enabled and self.view_pooler.has_aggregation():
|
||||
raise ValueError(
|
||||
"Chosen implicit function requires view pooling without aggregation."
|
||||
"The chosen implicit function requires view pooling without aggregation."
|
||||
)
|
||||
config_name = f"implicit_function_{self.implicit_function_class_type}_args"
|
||||
config = getattr(self, config_name, None)
|
||||
@@ -697,6 +765,17 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
Returns:
|
||||
Modified image_rgb, fg_mask, depth_map
|
||||
"""
|
||||
if image_rgb is not None and image_rgb.ndim == 3:
|
||||
# The FrameData object is used for both frames and batches of frames,
|
||||
# and a user might get this error if those were confused.
|
||||
# Perhaps a user has a FrameData `fd` representing a single frame and
|
||||
# wrote something like `model(**fd)` instead of
|
||||
# `model(**fd.collate([fd]))`.
|
||||
raise ValueError(
|
||||
"Model received unbatched inputs. "
|
||||
+ "Perhaps they came from a FrameData which had not been collated."
|
||||
)
|
||||
|
||||
fg_mask = fg_probability
|
||||
if fg_mask is not None and self.mask_threshold > 0.0:
|
||||
# threshold masks
|
||||
@@ -720,45 +799,6 @@ class GenericModel(Configurable, torch.nn.Module):
|
||||
|
||||
return image_rgb, fg_mask, depth_map
|
||||
|
||||
def _get_view_metrics(
|
||||
self,
|
||||
raymarched: RendererOutput,
|
||||
xys: torch.Tensor,
|
||||
image_rgb: Optional[torch.Tensor] = None,
|
||||
depth_map: Optional[torch.Tensor] = None,
|
||||
fg_probability: Optional[torch.Tensor] = None,
|
||||
mask_crop: Optional[torch.Tensor] = None,
|
||||
keys_prefix: str = "loss_",
|
||||
):
|
||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||
metrics = self.view_metrics(
|
||||
image_sampling_grid=xys,
|
||||
images_pred=raymarched.features,
|
||||
images=image_rgb,
|
||||
depths_pred=raymarched.depths,
|
||||
depths=depth_map,
|
||||
masks_pred=raymarched.masks,
|
||||
masks=fg_probability,
|
||||
masks_crop=mask_crop,
|
||||
keys_prefix=keys_prefix,
|
||||
**raymarched.aux,
|
||||
)
|
||||
|
||||
if raymarched.prev_stage:
|
||||
metrics.update(
|
||||
self._get_view_metrics(
|
||||
raymarched.prev_stage,
|
||||
xys,
|
||||
image_rgb,
|
||||
depth_map,
|
||||
fg_probability,
|
||||
mask_crop,
|
||||
keys_prefix=(keys_prefix + "prev_stage_"),
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
@torch.no_grad()
|
||||
def _rasterize_mc_samples(
|
||||
self,
|
||||
5
pytorch3d/implicitron/models/global_encoder/__init__.py
Normal file
5
pytorch3d/implicitron/models/global_encoder/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
@@ -12,10 +12,9 @@ import torch
|
||||
from pytorch3d.implicitron.tools.config import Configurable
|
||||
|
||||
|
||||
# TODO: probabilistic embeddings?
|
||||
class Autodecoder(Configurable, torch.nn.Module):
|
||||
"""
|
||||
Autodecoder module
|
||||
Autodecoder which maps a list of integer or string keys to optimizable embeddings.
|
||||
|
||||
Settings:
|
||||
encoding_dim: Embedding dimension for the decoder.
|
||||
@@ -43,37 +42,37 @@ class Autodecoder(Configurable, torch.nn.Module):
|
||||
# weight has been initialised from Normal(0, 1)
|
||||
self._autodecoder_codes.weight *= self.init_scale
|
||||
|
||||
self._sequence_map = self._build_sequence_map()
|
||||
self._key_map = self._build_key_map()
|
||||
# Make sure to register hooks for correct handling of saving/loading
|
||||
# the module's _sequence_map.
|
||||
self._register_load_state_dict_pre_hook(self._load_sequence_map_hook)
|
||||
self._register_state_dict_hook(_save_sequence_map_hook)
|
||||
# the module's _key_map.
|
||||
self._register_load_state_dict_pre_hook(self._load_key_map_hook)
|
||||
self._register_state_dict_hook(_save_key_map_hook)
|
||||
|
||||
def _build_sequence_map(
|
||||
self, sequence_map_dict: Optional[Dict[str, int]] = None
|
||||
def _build_key_map(
|
||||
self, key_map_dict: Optional[Dict[str, int]] = None
|
||||
) -> Dict[str, int]:
|
||||
"""
|
||||
Args:
|
||||
sequence_map_dict: A dictionary used to initialize the sequence_map.
|
||||
key_map_dict: A dictionary used to initialize the key_map.
|
||||
|
||||
Returns:
|
||||
sequence_map: a dictionary of key: id pairs.
|
||||
key_map: a dictionary of key: id pairs.
|
||||
"""
|
||||
# increments the counter when asked for a new value
|
||||
sequence_map = defaultdict(iter(range(self.n_instances)).__next__)
|
||||
if sequence_map_dict is not None:
|
||||
# Assign all keys from the loaded sequence_map_dict to self._sequence_map.
|
||||
key_map = defaultdict(iter(range(self.n_instances)).__next__)
|
||||
if key_map_dict is not None:
|
||||
# Assign all keys from the loaded key_map_dict to self._key_map.
|
||||
# Since this is done in the original order, it should generate
|
||||
# the same set of key:id pairs. We check this with an assert to be sure.
|
||||
for x, x_id in sequence_map_dict.items():
|
||||
x_id_ = sequence_map[x]
|
||||
for x, x_id in key_map_dict.items():
|
||||
x_id_ = key_map[x]
|
||||
assert x_id == x_id_
|
||||
return sequence_map
|
||||
return key_map
|
||||
|
||||
def calc_squared_encoding_norm(self):
|
||||
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||
if self.n_instances <= 0:
|
||||
return None
|
||||
return (self._autodecoder_codes.weight ** 2).mean()
|
||||
return (self._autodecoder_codes.weight**2).mean() # pyre-ignore[16]
|
||||
|
||||
def get_encoding_dim(self) -> int:
|
||||
if self.n_instances <= 0:
|
||||
@@ -83,13 +82,13 @@ class Autodecoder(Configurable, torch.nn.Module):
|
||||
def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x: A batch of `N` sequence identifiers. Either a long tensor of size
|
||||
x: A batch of `N` identifiers. Either a long tensor of size
|
||||
`(N,)` keys in [0, n_instances), or a list of `N` string keys that
|
||||
are hashed to codes (without collisions).
|
||||
|
||||
Returns:
|
||||
codes: A tensor of shape `(N, self.encoding_dim)` containing the
|
||||
sequence-specific autodecoder codes.
|
||||
key-specific autodecoder codes.
|
||||
"""
|
||||
if self.n_instances == 0:
|
||||
return None
|
||||
@@ -99,19 +98,21 @@ class Autodecoder(Configurable, torch.nn.Module):
|
||||
|
||||
if isinstance(x[0], str):
|
||||
try:
|
||||
# pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
|
||||
# `Tensor`.
|
||||
x = torch.tensor(
|
||||
# pyre-ignore[29]
|
||||
[self._sequence_map[elem] for elem in x],
|
||||
[self._key_map[elem] for elem in x],
|
||||
dtype=torch.long,
|
||||
device=next(self.parameters()).device,
|
||||
)
|
||||
except StopIteration:
|
||||
raise ValueError("Not enough n_instances in the autodecoder")
|
||||
raise ValueError("Not enough n_instances in the autodecoder") from None
|
||||
|
||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||
return self._autodecoder_codes(x)
|
||||
|
||||
def _load_sequence_map_hook(
|
||||
def _load_key_map_hook(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
@@ -140,20 +141,18 @@ class Autodecoder(Configurable, torch.nn.Module):
|
||||
:meth:`~torch.nn.Module.load_state_dict`
|
||||
|
||||
Returns:
|
||||
Constructed sequence_map if it exists in the state_dict
|
||||
Constructed key_map if it exists in the state_dict
|
||||
else raises a warning only.
|
||||
"""
|
||||
sequence_map_key = prefix + "_sequence_map"
|
||||
if sequence_map_key in state_dict:
|
||||
sequence_map_dict = state_dict.pop(sequence_map_key)
|
||||
self._sequence_map = self._build_sequence_map(
|
||||
sequence_map_dict=sequence_map_dict
|
||||
)
|
||||
key_map_key = prefix + "_key_map"
|
||||
if key_map_key in state_dict:
|
||||
key_map_dict = state_dict.pop(key_map_key)
|
||||
self._key_map = self._build_key_map(key_map_dict=key_map_dict)
|
||||
else:
|
||||
warnings.warn("No sequence map in Autodecoder state dict!")
|
||||
warnings.warn("No key map in Autodecoder state dict!")
|
||||
|
||||
|
||||
def _save_sequence_map_hook(
|
||||
def _save_key_map_hook(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
@@ -167,6 +166,6 @@ def _save_sequence_map_hook(
|
||||
module
|
||||
local_metadata (dict): a dict containing the metadata for this module.
|
||||
"""
|
||||
sequence_map_key = prefix + "_sequence_map"
|
||||
sequence_map_dict = dict(self._sequence_map.items())
|
||||
state_dict[sequence_map_key] = sequence_map_dict
|
||||
key_map_key = prefix + "_key_map"
|
||||
key_map_dict = dict(self._key_map.items())
|
||||
state_dict[key_map_key] = key_map_dict
|
||||
111
pytorch3d/implicitron/models/global_encoder/global_encoder.py
Normal file
111
pytorch3d/implicitron/models/global_encoder/global_encoder.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
registry,
|
||||
ReplaceableBase,
|
||||
run_auto_creation,
|
||||
)
|
||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||
|
||||
from .autodecoder import Autodecoder
|
||||
|
||||
|
||||
class GlobalEncoderBase(ReplaceableBase):
|
||||
"""
|
||||
A base class for implementing encoders of global frame-specific quantities.
|
||||
|
||||
The latter includes e.g. the harmonic encoding of a frame timestamp
|
||||
(`HarmonicTimeEncoder`), or an autodecoder encoding of the frame's sequence
|
||||
(`SequenceAutodecoder`).
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def get_encoding_dim(self):
|
||||
"""
|
||||
Returns the dimensionality of the returned encoding.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Calculates the squared norm of the encoding to report as the
|
||||
`autodecoder_norm` loss of the model, as a zero dimensional tensor.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Given a set of inputs to encode, generates a tensor containing the encoding.
|
||||
|
||||
Returns:
|
||||
encoding: The tensor containing the global encoding.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
# TODO: probabilistic embeddings?
|
||||
@registry.register
|
||||
class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 13
|
||||
"""
|
||||
A global encoder implementation which provides an autodecoder encoding
|
||||
of the frame's sequence identifier.
|
||||
"""
|
||||
|
||||
autodecoder: Autodecoder
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
run_auto_creation(self)
|
||||
|
||||
def get_encoding_dim(self):
|
||||
return self.autodecoder.get_encoding_dim()
|
||||
|
||||
def forward(
|
||||
self, sequence_name: Union[torch.LongTensor, List[str]], **kwargs
|
||||
) -> torch.Tensor:
|
||||
|
||||
# run dtype checks and pass sequence_name to self.autodecoder
|
||||
return self.autodecoder(sequence_name)
|
||||
|
||||
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||
return self.autodecoder.calculate_squared_encoding_norm()
|
||||
|
||||
|
||||
@registry.register
|
||||
class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
|
||||
"""
|
||||
A global encoder implementation which provides harmonic embeddings
|
||||
of each frame's timestamp.
|
||||
"""
|
||||
|
||||
n_harmonic_functions: int = 10
|
||||
append_input: bool = True
|
||||
time_divisor: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
self._harmonic_embedding = HarmonicEmbedding(
|
||||
n_harmonic_functions=self.n_harmonic_functions,
|
||||
append_input=self.append_input,
|
||||
)
|
||||
|
||||
def get_encoding_dim(self):
|
||||
return self._harmonic_embedding.get_output_dim(1)
|
||||
|
||||
def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
if frame_timestamp.shape[-1] != 1:
|
||||
raise ValueError("Frame timestamp's last dimensions should be one.")
|
||||
time = frame_timestamp / self.time_divisor
|
||||
return self._harmonic_embedding(time) # pyre-ignore: 29
|
||||
|
||||
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||
return None
|
||||
@@ -3,7 +3,7 @@
|
||||
# implicit_differentiable_renderer.py
|
||||
# Copyright (c) 2020 Lior Yariv
|
||||
import math
|
||||
from typing import Sequence
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import registry
|
||||
@@ -15,13 +15,48 @@ from .base import ImplicitFunctionBase
|
||||
|
||||
@registry.register
|
||||
class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
||||
"""
|
||||
Implicit function as used in http://github.com/lioryariv/idr.
|
||||
|
||||
Members:
|
||||
d_in: dimension of the input point.
|
||||
n_harmonic_functions_xyz: If -1, do not embed the point.
|
||||
If >=0, use a harmonic embedding with this number of
|
||||
harmonic functions. (The harmonic embedding includes the input
|
||||
itself, so a value of 0 means the point is used but without
|
||||
any harmonic functions.)
|
||||
d_out and feature_vector_size: Sum of these is the output
|
||||
dimension. This implicit function thus returns a concatenation
|
||||
of `d_out` signed distance function values and `feature_vector_size`
|
||||
features (such as colors). When used in `GenericModel`,
|
||||
`feature_vector_size` corresponds is automatically set to
|
||||
`render_features_dimensions`.
|
||||
dims: list of hidden layer sizes.
|
||||
geometric_init: whether to use custom weight initialization
|
||||
in linear layers. If False, pytorch default (uniform sampling)
|
||||
is used.
|
||||
bias: if geometric_init=True, initial value for bias subtracted
|
||||
in the last layer.
|
||||
skip_in: List of indices of layers that receive as input the initial
|
||||
value concatenated with the output of the previous layers.
|
||||
weight_norm: whether to apply weight normalization to each layer.
|
||||
pooled_feature_dim: If view pooling is in use (provided as
|
||||
fun_viewpool to forward()) this must be its number of features.
|
||||
Otherwise this must be set to 0. (If used from GenericModel,
|
||||
this config value will be overridden automatically.)
|
||||
encoding_dim: If global coding is in use (provided as global_code
|
||||
to forward()) this must be its number of featuress.
|
||||
Otherwise this must be set to 0. (If used from GenericModel,
|
||||
this config value will be overridden automatically.)
|
||||
"""
|
||||
|
||||
feature_vector_size: int = 3
|
||||
d_in: int = 3
|
||||
d_out: int = 1
|
||||
dims: Sequence[int] = (512, 512, 512, 512, 512, 512, 512, 512)
|
||||
dims: Tuple[int, ...] = (512, 512, 512, 512, 512, 512, 512, 512)
|
||||
geometric_init: bool = True
|
||||
bias: float = 1.0
|
||||
skip_in: Sequence[int] = ()
|
||||
skip_in: Tuple[int, ...] = ()
|
||||
weight_norm: bool = True
|
||||
n_harmonic_functions_xyz: int = 0
|
||||
pooled_feature_dim: int = 0
|
||||
@@ -33,7 +68,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
||||
dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
|
||||
|
||||
self.embed_fn = None
|
||||
if self.n_harmonic_functions_xyz > 0:
|
||||
if self.n_harmonic_functions_xyz >= 0:
|
||||
self.embed_fn = HarmonicEmbedding(
|
||||
self.n_harmonic_functions_xyz, append_input=True
|
||||
)
|
||||
@@ -59,23 +94,23 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
||||
if layer_idx == self.num_layers - 2:
|
||||
torch.nn.init.normal_(
|
||||
lin.weight,
|
||||
mean=math.pi ** 0.5 / dims[layer_idx] ** 0.5,
|
||||
mean=math.pi**0.5 / dims[layer_idx] ** 0.5,
|
||||
std=0.0001,
|
||||
)
|
||||
torch.nn.init.constant_(lin.bias, -self.bias)
|
||||
elif self.n_harmonic_functions_xyz > 0 and layer_idx == 0:
|
||||
elif self.n_harmonic_functions_xyz >= 0 and layer_idx == 0:
|
||||
torch.nn.init.constant_(lin.bias, 0.0)
|
||||
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
|
||||
torch.nn.init.normal_(
|
||||
lin.weight[:, :3], 0.0, 2 ** 0.5 / out_dim ** 0.5
|
||||
lin.weight[:, :3], 0.0, 2**0.5 / out_dim**0.5
|
||||
)
|
||||
elif self.n_harmonic_functions_xyz > 0 and layer_idx in self.skip_in:
|
||||
elif self.n_harmonic_functions_xyz >= 0 and layer_idx in self.skip_in:
|
||||
torch.nn.init.constant_(lin.bias, 0.0)
|
||||
torch.nn.init.normal_(lin.weight, 0.0, 2 ** 0.5 / out_dim ** 0.5)
|
||||
torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5)
|
||||
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
|
||||
else:
|
||||
torch.nn.init.constant_(lin.bias, 0.0)
|
||||
torch.nn.init.normal_(lin.weight, 0.0, 2 ** 0.5 / out_dim ** 0.5)
|
||||
torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5)
|
||||
|
||||
if self.weight_norm:
|
||||
lin = nn.utils.weight_norm(lin)
|
||||
@@ -103,38 +138,43 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
||||
self.embed_fn is None and fun_viewpool is None and global_code is None
|
||||
):
|
||||
return torch.tensor(
|
||||
[], device=rays_points_world.device, dtype=rays_points_world.dtype
|
||||
[],
|
||||
device=rays_points_world.device,
|
||||
dtype=rays_points_world.dtype
|
||||
# pyre-fixme[6]: For 2nd param expected `int` but got `Union[Module,
|
||||
# Tensor]`.
|
||||
).view(0, self.out_dim)
|
||||
|
||||
embedding = None
|
||||
embeddings = []
|
||||
if self.embed_fn is not None:
|
||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||
embedding = self.embed_fn(rays_points_world)
|
||||
embeddings.append(self.embed_fn(rays_points_world))
|
||||
|
||||
if fun_viewpool is not None:
|
||||
assert rays_points_world.ndim == 2
|
||||
pooled_feature = fun_viewpool(rays_points_world[None])
|
||||
# TODO: pooled features are 4D!
|
||||
embedding = torch.cat((embedding, pooled_feature), dim=-1)
|
||||
embeddings.append(pooled_feature)
|
||||
|
||||
if global_code is not None:
|
||||
assert embedding.ndim == 2
|
||||
assert global_code.shape[0] == 1 # TODO: generalize to batches!
|
||||
# This will require changing raytracer code
|
||||
# embedding = embedding[None].expand(global_code.shape[0], *embedding.shape)
|
||||
embedding = torch.cat(
|
||||
(embedding, global_code[0, None, :].expand(*embedding.shape[:-1], -1)),
|
||||
dim=-1,
|
||||
embeddings.append(
|
||||
global_code[0, None, :].expand(rays_points_world.shape[0], -1)
|
||||
)
|
||||
|
||||
embedding = torch.cat(embeddings, dim=-1)
|
||||
x = embedding
|
||||
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase.__s...
|
||||
for layer_idx in range(self.num_layers - 1):
|
||||
if layer_idx in self.skip_in:
|
||||
x = torch.cat([x, embedding], dim=-1) / 2 ** 0.5
|
||||
x = torch.cat([x, embedding], dim=-1) / 2**0.5
|
||||
|
||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||
x = self.linear_layers[layer_idx](x)
|
||||
|
||||
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase...
|
||||
if layer_idx < self.num_layers - 2:
|
||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||
x = self.softplus(x)
|
||||
|
||||
@@ -5,8 +5,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from dataclasses import field
|
||||
from typing import List, Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
||||
@@ -206,7 +205,7 @@ class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
|
||||
transformer_dim_down_factor: float = 1.0
|
||||
n_hidden_neurons_xyz: int = 256
|
||||
n_layers_xyz: int = 8
|
||||
append_xyz: List[int] = field(default_factory=lambda: [5])
|
||||
append_xyz: Tuple[int, ...] = (5,)
|
||||
|
||||
def _construct_xyz_encoder(self, input_dim: int):
|
||||
return MLPWithInputSkips(
|
||||
@@ -224,7 +223,7 @@ class NeRFormerImplicitFunction(NeuralRadianceFieldBase):
|
||||
transformer_dim_down_factor: float = 2.0
|
||||
n_hidden_neurons_xyz: int = 80
|
||||
n_layers_xyz: int = 2
|
||||
append_xyz: List[int] = field(default_factory=lambda: [1])
|
||||
append_xyz: Tuple[int, ...] = (1,)
|
||||
|
||||
def _construct_xyz_encoder(self, input_dim: int):
|
||||
return TransformerWithInputSkips(
|
||||
@@ -286,7 +285,7 @@ class MLPWithInputSkips(torch.nn.Module):
|
||||
output_dim: int = 256,
|
||||
skip_dim: int = 39,
|
||||
hidden_dim: int = 256,
|
||||
input_skips: List[int] = [5],
|
||||
input_skips: Tuple[int, ...] = (5,),
|
||||
skip_affine_trans: bool = False,
|
||||
no_last_relu=False,
|
||||
):
|
||||
@@ -362,7 +361,7 @@ class TransformerWithInputSkips(torch.nn.Module):
|
||||
output_dim: int = 256,
|
||||
skip_dim: int = 39,
|
||||
hidden_dim: int = 64,
|
||||
input_skips: List[int] = [5],
|
||||
input_skips: Tuple[int, ...] = (5,),
|
||||
dim_down_factor: float = 1,
|
||||
):
|
||||
"""
|
||||
@@ -386,7 +385,7 @@ class TransformerWithInputSkips(torch.nn.Module):
|
||||
layers_pool, layers_ray = [], []
|
||||
dimout = 0
|
||||
for layeri in range(n_layers):
|
||||
dimin = int(round(hidden_dim / (dim_down_factor ** layeri)))
|
||||
dimin = int(round(hidden_dim / (dim_down_factor**layeri)))
|
||||
dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
|
||||
logger.info(f"Tr: {dimin} -> {dimout}")
|
||||
for _i, l in enumerate((layers_pool, layers_ray)):
|
||||
@@ -406,6 +405,8 @@ class TransformerWithInputSkips(torch.nn.Module):
|
||||
self.last = torch.nn.Linear(dimout, output_dim)
|
||||
_xavier_init(self.last)
|
||||
|
||||
# pyre-fixme[8]: Attribute has type `Tuple[ModuleList, ModuleList]`; used as
|
||||
# `ModuleList`.
|
||||
self.layers_pool, self.layers_ray = (
|
||||
torch.nn.ModuleList(layers_pool),
|
||||
torch.nn.ModuleList(layers_ray),
|
||||
|
||||
@@ -6,63 +6,180 @@
|
||||
|
||||
|
||||
import warnings
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools import metric_utils as utils
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.renderer import utils as rend_utils
|
||||
|
||||
from .renderer.base import RendererOutput
|
||||
|
||||
|
||||
class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
|
||||
"""
|
||||
Replaceable abstract base for regularization metrics.
|
||||
`forward()` method produces regularization metrics and (unlike ViewMetrics) can
|
||||
depend on the model's parameters.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self, model: Any, keys_prefix: str = "loss_", **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculates various regularization terms useful for supervising differentiable
|
||||
rendering pipelines.
|
||||
|
||||
Args:
|
||||
model: A model instance. Useful, for example, to implement
|
||||
weights-based regularization.
|
||||
keys_prefix: A common prefix for all keys in the output dictionary
|
||||
containing all regularization metrics.
|
||||
|
||||
Returns:
|
||||
A dictionary with the resulting regularization metrics. The items
|
||||
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||
names of the output metrics `metric_name_i` with their corresponding
|
||||
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ViewMetricsBase(ReplaceableBase, torch.nn.Module):
|
||||
"""
|
||||
Replaceable abstract base for model metrics.
|
||||
`forward()` method produces losses and other metrics.
|
||||
"""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
class ViewMetrics(torch.nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
image_sampling_grid: torch.Tensor,
|
||||
images: Optional[torch.Tensor] = None,
|
||||
images_pred: Optional[torch.Tensor] = None,
|
||||
depths: Optional[torch.Tensor] = None,
|
||||
depths_pred: Optional[torch.Tensor] = None,
|
||||
masks: Optional[torch.Tensor] = None,
|
||||
masks_pred: Optional[torch.Tensor] = None,
|
||||
masks_crop: Optional[torch.Tensor] = None,
|
||||
grad_theta: Optional[torch.Tensor] = None,
|
||||
density_grid: Optional[torch.Tensor] = None,
|
||||
raymarched: RendererOutput,
|
||||
xys: torch.Tensor,
|
||||
image_rgb: Optional[torch.Tensor] = None,
|
||||
depth_map: Optional[torch.Tensor] = None,
|
||||
fg_probability: Optional[torch.Tensor] = None,
|
||||
mask_crop: Optional[torch.Tensor] = None,
|
||||
keys_prefix: str = "loss_",
|
||||
mask_renders_by_pred: bool = False,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculates various metrics and loss functions useful for supervising
|
||||
differentiable rendering pipelines. Any additional parameters can be passed
|
||||
in the `raymarched.aux` dictionary.
|
||||
|
||||
Args:
|
||||
results: A dictionary with the resulting view metrics. The items
|
||||
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||
names of the output metrics `metric_name_i` with their corresponding
|
||||
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||
raymarched: Output of the renderer.
|
||||
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
|
||||
the predictions are defined. All ground truth inputs are sampled at
|
||||
these locations in order to extract values that correspond to the
|
||||
predictions.
|
||||
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
|
||||
values.
|
||||
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
|
||||
values.
|
||||
fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
|
||||
foreground masks.
|
||||
keys_prefix: A common prefix for all keys in the output dictionary
|
||||
containing all view metrics.
|
||||
|
||||
Returns:
|
||||
A dictionary with the resulting view metrics. The items
|
||||
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||
names of the output metrics `metric_name_i` with their corresponding
|
||||
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@registry.register
|
||||
class RegularizationMetrics(RegularizationMetricsBase):
|
||||
def forward(
|
||||
self, model: Any, keys_prefix: str = "loss_", **kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculates the AD penalty, or returns an empty dict if the model's autoencoder
|
||||
is inactive.
|
||||
|
||||
Args:
|
||||
model: A model instance.
|
||||
keys_prefix: A common prefix for all keys in the output dictionary
|
||||
containing all regularization metrics.
|
||||
|
||||
Returns:
|
||||
A dictionary with the resulting regularization metrics. The items
|
||||
will have form `{metric_name_i: metric_value_i}` keyed by the
|
||||
names of the output metrics `metric_name_i` with their corresponding
|
||||
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||
|
||||
The calculated metric is:
|
||||
autoencoder_norm: Autoencoder weight norm regularization term.
|
||||
"""
|
||||
metrics = {}
|
||||
if getattr(model, "sequence_autodecoder", None) is not None:
|
||||
ad_penalty = model.sequence_autodecoder.calculate_squared_encoding_norm()
|
||||
if ad_penalty is not None:
|
||||
metrics["autodecoder_norm"] = ad_penalty
|
||||
|
||||
if keys_prefix is not None:
|
||||
metrics = {(keys_prefix + k): v for k, v in metrics.items()}
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
@registry.register
|
||||
class ViewMetrics(ViewMetricsBase):
|
||||
def forward(
|
||||
self,
|
||||
raymarched: RendererOutput,
|
||||
xys: torch.Tensor,
|
||||
image_rgb: Optional[torch.Tensor] = None,
|
||||
depth_map: Optional[torch.Tensor] = None,
|
||||
fg_probability: Optional[torch.Tensor] = None,
|
||||
mask_crop: Optional[torch.Tensor] = None,
|
||||
keys_prefix: str = "loss_",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculates various differentiable metrics useful for supervising
|
||||
differentiable rendering pipelines.
|
||||
|
||||
Args:
|
||||
image_sampling_grid: A tensor of shape `(B, ..., 2)` containing 2D
|
||||
image locations at which the predictions are defined.
|
||||
All ground truth inputs are sampled at these
|
||||
locations in order to extract values that correspond
|
||||
to the predictions.
|
||||
images: A tensor of shape `(B, H, W, 3)` containing ground truth
|
||||
rgb values.
|
||||
images_pred: A tensor of shape `(B, ..., 3)` containing predicted
|
||||
rgb values.
|
||||
depths: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth
|
||||
depth values.
|
||||
depths_pred: A tensor of shape `(B, ..., 1)` containing predicted
|
||||
depth values.
|
||||
masks: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
|
||||
foreground masks.
|
||||
masks_pred: A tensor of shape `(B, ..., 1)` containing predicted
|
||||
foreground masks.
|
||||
grad_theta: A tensor of shape `(B, ..., 3)` containing an evaluation
|
||||
of a gradient of a signed distance function w.r.t.
|
||||
results: A dict to store the results in.
|
||||
raymarched.features: Predicted rgb or feature values.
|
||||
raymarched.depths: A tensor of shape `(B, ..., 1)` containing
|
||||
predicted depth values.
|
||||
raymarched.masks: A tensor of shape `(B, ..., 1)` containing
|
||||
predicted foreground masks.
|
||||
raymarched.aux["grad_theta"]: A tensor of shape `(B, ..., 3)` containing an
|
||||
evaluation of a gradient of a signed distance function w.r.t.
|
||||
input 3D coordinates used to compute the eikonal loss.
|
||||
density_grid: A tensor of shape `(B, Hg, Wg, Dg, 1)` containing a
|
||||
`Hg x Wg x Dg` voxel grid of density values.
|
||||
raymarched.aux["density_grid"]: A tensor of shape `(B, Hg, Wg, Dg, 1)`
|
||||
containing a `Hg x Wg x Dg` voxel grid of density values.
|
||||
xys: A tensor of shape `(B, ..., 2)` containing 2D image locations at which
|
||||
the predictions are defined. All ground truth inputs are sampled at
|
||||
these locations in order to extract values that correspond to the
|
||||
predictions.
|
||||
image_rgb: A tensor of shape `(B, H, W, 3)` containing ground truth rgb
|
||||
values.
|
||||
depth_map: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth depth
|
||||
values.
|
||||
fg_probability: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
|
||||
foreground masks.
|
||||
keys_prefix: A common prefix for all keys in the output dictionary
|
||||
containing all metrics.
|
||||
mask_renders_by_pred: If `True`, masks rendered images by the predicted
|
||||
`masks_pred` prior to computing all rgb metrics.
|
||||
containing all view metrics.
|
||||
|
||||
Returns:
|
||||
metrics: A dictionary `{metric_name_i: metric_value_i}` keyed by the
|
||||
A dictionary `{metric_name_i: metric_value_i}` keyed by the
|
||||
names of the output metrics `metric_name_i` with their corresponding
|
||||
values `metric_value_i` represented as 0-dimensional float tensors.
|
||||
|
||||
@@ -90,109 +207,142 @@ class ViewMetrics(torch.nn.Module):
|
||||
depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
|
||||
predicted depth values.
|
||||
"""
|
||||
metrics = self._calculate_stage(
|
||||
raymarched,
|
||||
xys,
|
||||
image_rgb,
|
||||
depth_map,
|
||||
fg_probability,
|
||||
mask_crop,
|
||||
keys_prefix,
|
||||
)
|
||||
|
||||
if raymarched.prev_stage:
|
||||
metrics.update(
|
||||
self(
|
||||
raymarched.prev_stage,
|
||||
xys,
|
||||
image_rgb,
|
||||
depth_map,
|
||||
fg_probability,
|
||||
mask_crop,
|
||||
keys_prefix=(keys_prefix + "prev_stage_"),
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
def _calculate_stage(
|
||||
self,
|
||||
raymarched: RendererOutput,
|
||||
xys: torch.Tensor,
|
||||
image_rgb: Optional[torch.Tensor] = None,
|
||||
depth_map: Optional[torch.Tensor] = None,
|
||||
fg_probability: Optional[torch.Tensor] = None,
|
||||
mask_crop: Optional[torch.Tensor] = None,
|
||||
keys_prefix: str = "loss_",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Calculate metrics for the current stage.
|
||||
"""
|
||||
# TODO: extract functions
|
||||
|
||||
# reshape from B x ... x DIM to B x DIM x -1 x 1
|
||||
images_pred, masks_pred, depths_pred = [
|
||||
_reshape_nongrid_var(x) for x in [images_pred, masks_pred, depths_pred]
|
||||
image_rgb_pred, fg_probability_pred, depth_map_pred = [
|
||||
_reshape_nongrid_var(x)
|
||||
for x in [raymarched.features, raymarched.masks, raymarched.depths]
|
||||
]
|
||||
# reshape the sampling grid as well
|
||||
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
|
||||
# now that we use rend_utils.ndc_grid_sample
|
||||
image_sampling_grid = image_sampling_grid.reshape(
|
||||
image_sampling_grid.shape[0], -1, 1, 2
|
||||
)
|
||||
xys = xys.reshape(xys.shape[0], -1, 1, 2)
|
||||
|
||||
# closure with the given image_sampling_grid
|
||||
# closure with the given xys
|
||||
def sample(tensor, mode):
|
||||
if tensor is None:
|
||||
return tensor
|
||||
return rend_utils.ndc_grid_sample(tensor, image_sampling_grid, mode=mode)
|
||||
return rend_utils.ndc_grid_sample(tensor, xys, mode=mode)
|
||||
|
||||
# eval all results in this size
|
||||
images = sample(images, mode="bilinear")
|
||||
depths = sample(depths, mode="nearest")
|
||||
masks = sample(masks, mode="nearest")
|
||||
masks_crop = sample(masks_crop, mode="nearest")
|
||||
if masks_crop is None and images_pred is not None:
|
||||
masks_crop = torch.ones_like(images_pred[:, :1])
|
||||
if masks_crop is None and depths_pred is not None:
|
||||
masks_crop = torch.ones_like(depths_pred[:, :1])
|
||||
image_rgb = sample(image_rgb, mode="bilinear")
|
||||
depth_map = sample(depth_map, mode="nearest")
|
||||
fg_probability = sample(fg_probability, mode="nearest")
|
||||
mask_crop = sample(mask_crop, mode="nearest")
|
||||
if mask_crop is None and image_rgb_pred is not None:
|
||||
mask_crop = torch.ones_like(image_rgb_pred[:, :1])
|
||||
if mask_crop is None and depth_map_pred is not None:
|
||||
mask_crop = torch.ones_like(depth_map_pred[:, :1])
|
||||
|
||||
preds = {}
|
||||
if images is not None and images_pred is not None:
|
||||
# TODO: mask_renders_by_pred is always false; simplify
|
||||
preds.update(
|
||||
metrics = {}
|
||||
if image_rgb is not None and image_rgb_pred is not None:
|
||||
metrics.update(
|
||||
_rgb_metrics(
|
||||
images,
|
||||
images_pred,
|
||||
masks,
|
||||
masks_pred,
|
||||
masks_crop,
|
||||
mask_renders_by_pred,
|
||||
image_rgb,
|
||||
image_rgb_pred,
|
||||
fg_probability,
|
||||
fg_probability_pred,
|
||||
mask_crop,
|
||||
)
|
||||
)
|
||||
|
||||
if masks_pred is not None:
|
||||
preds["mask_beta_prior"] = utils.beta_prior(masks_pred)
|
||||
if masks is not None and masks_pred is not None:
|
||||
preds["mask_neg_iou"] = utils.neg_iou_loss(
|
||||
masks_pred, masks, mask=masks_crop
|
||||
if fg_probability_pred is not None:
|
||||
metrics["mask_beta_prior"] = utils.beta_prior(fg_probability_pred)
|
||||
if fg_probability is not None and fg_probability_pred is not None:
|
||||
metrics["mask_neg_iou"] = utils.neg_iou_loss(
|
||||
fg_probability_pred, fg_probability, mask=mask_crop
|
||||
)
|
||||
metrics["mask_bce"] = utils.calc_bce(
|
||||
fg_probability_pred, fg_probability, mask=mask_crop
|
||||
)
|
||||
preds["mask_bce"] = utils.calc_bce(masks_pred, masks, mask=masks_crop)
|
||||
|
||||
if depths is not None and depths_pred is not None:
|
||||
assert masks_crop is not None
|
||||
if depth_map is not None and depth_map_pred is not None:
|
||||
assert mask_crop is not None
|
||||
_, abs_ = utils.eval_depth(
|
||||
depths_pred, depths, get_best_scale=True, mask=masks_crop, crop=0
|
||||
depth_map_pred, depth_map, get_best_scale=True, mask=mask_crop, crop=0
|
||||
)
|
||||
preds["depth_abs"] = abs_.mean()
|
||||
metrics["depth_abs"] = abs_.mean()
|
||||
|
||||
if masks is not None:
|
||||
mask = masks * masks_crop
|
||||
if fg_probability is not None:
|
||||
mask = fg_probability * mask_crop
|
||||
_, abs_ = utils.eval_depth(
|
||||
depths_pred, depths, get_best_scale=True, mask=mask, crop=0
|
||||
depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
|
||||
)
|
||||
preds["depth_abs_fg"] = abs_.mean()
|
||||
metrics["depth_abs_fg"] = abs_.mean()
|
||||
|
||||
# regularizers
|
||||
grad_theta = raymarched.aux.get("grad_theta")
|
||||
if grad_theta is not None:
|
||||
preds["eikonal"] = _get_eikonal_loss(grad_theta)
|
||||
metrics["eikonal"] = _get_eikonal_loss(grad_theta)
|
||||
|
||||
density_grid = raymarched.aux.get("density_grid")
|
||||
if density_grid is not None:
|
||||
preds["density_tv"] = _get_grid_tv_loss(density_grid)
|
||||
metrics["density_tv"] = _get_grid_tv_loss(density_grid)
|
||||
|
||||
if depths_pred is not None:
|
||||
preds["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depths_pred)
|
||||
if depth_map_pred is not None:
|
||||
metrics["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depth_map_pred)
|
||||
|
||||
if keys_prefix is not None:
|
||||
preds = {(keys_prefix + k): v for k, v in preds.items()}
|
||||
metrics = {(keys_prefix + k): v for k, v in metrics.items()}
|
||||
|
||||
return preds
|
||||
return metrics
|
||||
|
||||
|
||||
def _rgb_metrics(
|
||||
images, images_pred, masks, masks_pred, masks_crop, mask_renders_by_pred
|
||||
):
|
||||
def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
|
||||
assert masks_crop is not None
|
||||
if mask_renders_by_pred:
|
||||
images = images[..., masks_pred.reshape(-1), :]
|
||||
masks_crop = masks_crop[..., masks_pred.reshape(-1), :]
|
||||
masks = masks is not None and masks[..., masks_pred.reshape(-1), :]
|
||||
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
|
||||
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
|
||||
crop_mass = masks_crop.sum().clamp(1.0)
|
||||
preds = {
|
||||
results = {
|
||||
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
|
||||
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
|
||||
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
|
||||
}
|
||||
if masks is not None:
|
||||
masks = masks_crop * masks
|
||||
preds["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
|
||||
preds["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
|
||||
return preds
|
||||
results["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
|
||||
results["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
|
||||
return results
|
||||
|
||||
|
||||
def _get_eikonal_loss(grad_theta):
|
||||
|
||||
@@ -5,13 +5,11 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.utils import is_known_frame
|
||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
||||
NewViewSynthesisPrediction,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import registry
|
||||
from pytorch3d.implicitron.tools.point_cloud_utils import (
|
||||
get_rgbd_point_cloud,
|
||||
render_point_cloud_pytorch3d,
|
||||
@@ -19,41 +17,43 @@ from pytorch3d.implicitron.tools.point_cloud_utils import (
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from pytorch3d.structures import Pointclouds
|
||||
|
||||
from .base_model import ImplicitronModelBase, ImplicitronRender
|
||||
from .renderer.base import EvaluationMode
|
||||
|
||||
class ModelDBIR(torch.nn.Module):
|
||||
|
||||
@registry.register
|
||||
class ModelDBIR(ImplicitronModelBase, torch.nn.Module):
|
||||
"""
|
||||
A simple depth-based image rendering model.
|
||||
|
||||
Args:
|
||||
render_image_width: The width of the rendered rectangular images.
|
||||
render_image_height: The height of the rendered rectangular images.
|
||||
bg_color: The color of the background.
|
||||
max_points: Maximum number of points in the point cloud
|
||||
formed by unprojecting all source view depths.
|
||||
If more points are present, they are randomly subsampled
|
||||
to this number of points without replacement.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size: int = 256,
|
||||
bg_color: float = 0.0,
|
||||
max_points: int = -1,
|
||||
):
|
||||
"""
|
||||
Initializes a simple DBIR model.
|
||||
|
||||
Args:
|
||||
image_size: The size of the rendered rectangular images.
|
||||
bg_color: The color of the background.
|
||||
max_points: Maximum number of points in the point cloud
|
||||
formed by unprojecting all source view depths.
|
||||
If more points are present, they are randomly subsampled
|
||||
to #max_size points without replacement.
|
||||
"""
|
||||
render_image_width: int = 256
|
||||
render_image_height: int = 256
|
||||
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||
max_points: int = -1
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
self.image_size = image_size
|
||||
self.bg_color = bg_color
|
||||
self.max_points = max_points
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*, # force keyword-only arguments
|
||||
image_rgb: Optional[torch.Tensor],
|
||||
camera: CamerasBase,
|
||||
image_rgb: torch.Tensor,
|
||||
depth_map: torch.Tensor,
|
||||
fg_probability: torch.Tensor,
|
||||
fg_probability: Optional[torch.Tensor],
|
||||
mask_crop: Optional[torch.Tensor],
|
||||
depth_map: Optional[torch.Tensor],
|
||||
sequence_name: Optional[List[str]],
|
||||
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||
frame_type: List[str],
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]: # TODO: return a namedtuple or dataclass
|
||||
@@ -72,26 +72,39 @@ class ModelDBIR(torch.nn.Module):
|
||||
|
||||
Returns:
|
||||
preds: A dict with the following fields:
|
||||
nvs_prediction: The rendered colors, depth and mask
|
||||
implicitron_render: The rendered colors, depth and mask
|
||||
of the target views.
|
||||
point_cloud: The point cloud of the scene. It's renders are
|
||||
stored in `nvs_prediction`.
|
||||
stored in `implicitron_render`.
|
||||
"""
|
||||
|
||||
if image_rgb is None:
|
||||
raise ValueError("ModelDBIR needs image input")
|
||||
|
||||
if fg_probability is None:
|
||||
raise ValueError("ModelDBIR needs foreground mask input")
|
||||
|
||||
if depth_map is None:
|
||||
raise ValueError("ModelDBIR needs depth map input")
|
||||
|
||||
is_known = is_known_frame(frame_type)
|
||||
is_known_idx = torch.where(is_known)[0]
|
||||
|
||||
mask_fg = (fg_probability > 0.5).type_as(image_rgb)
|
||||
|
||||
point_cloud = get_rgbd_point_cloud(
|
||||
# pyre-fixme[6]: For 1st param expected `Union[List[int], int,
|
||||
# LongTensor]` but got `Tensor`.
|
||||
camera[is_known_idx],
|
||||
image_rgb[is_known_idx],
|
||||
depth_map[is_known_idx],
|
||||
mask_fg[is_known_idx],
|
||||
)
|
||||
|
||||
pcl_size = int(point_cloud.num_points_per_cloud())
|
||||
pcl_size = point_cloud.num_points_per_cloud().item()
|
||||
if (self.max_points > 0) and (pcl_size > self.max_points):
|
||||
# pyre-fixme[6]: For 1st param expected `int` but got `Union[bool,
|
||||
# float, int]`.
|
||||
prm = torch.randperm(pcl_size)[: self.max_points]
|
||||
point_cloud = Pointclouds(
|
||||
point_cloud.points_padded()[:, prm, :],
|
||||
@@ -108,7 +121,7 @@ class ModelDBIR(torch.nn.Module):
|
||||
_image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d(
|
||||
camera[int(tgt_idx)],
|
||||
point_cloud,
|
||||
render_size=(self.image_size, self.image_size),
|
||||
render_size=(self.render_image_height, self.render_image_width),
|
||||
point_radius=1e-2,
|
||||
topk=10,
|
||||
bg_color=self.bg_color,
|
||||
@@ -121,7 +134,7 @@ class ModelDBIR(torch.nn.Module):
|
||||
image_render.append(_image_render)
|
||||
mask_render.append(_mask_render)
|
||||
|
||||
nvs_prediction = NewViewSynthesisPrediction(
|
||||
implicitron_render = ImplicitronRender(
|
||||
**{
|
||||
k: torch.cat(v, dim=0)
|
||||
for k, v in zip(
|
||||
@@ -132,7 +145,7 @@ class ModelDBIR(torch.nn.Module):
|
||||
)
|
||||
|
||||
preds = {
|
||||
"nvs_prediction": nvs_prediction,
|
||||
"implicitron_render": implicitron_render,
|
||||
"point_cloud": point_cloud,
|
||||
}
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ class RendererOutput:
|
||||
prev_stage: Optional[RendererOutput] = None
|
||||
normals: Optional[torch.Tensor] = None
|
||||
points: Optional[torch.Tensor] = None # TODO: redundant with depths
|
||||
weights: Optional[torch.Tensor] = None
|
||||
aux: Dict[str, Any] = field(default_factory=lambda: {})
|
||||
|
||||
|
||||
@@ -87,7 +88,7 @@ class BaseRenderer(ABC, ReplaceableBase):
|
||||
ray_bundle,
|
||||
implicit_functions: List[ImplicitFunctionWrapper],
|
||||
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> RendererOutput:
|
||||
"""
|
||||
Each Renderer should implement its own forward function
|
||||
|
||||
@@ -21,6 +21,8 @@ logger = logging.getLogger(__name__)
|
||||
class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
"""
|
||||
Implements the learnable LSTM raymarching function from SRN [1].
|
||||
This requires there to be one implicit function, and it is expected to be
|
||||
like SRNImplicitFunction or SRNHyperNetImplicitFunction.
|
||||
|
||||
Settings:
|
||||
num_raymarch_steps: The number of LSTM raymarching steps.
|
||||
@@ -32,6 +34,11 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
hidden_size: The dimensionality of the LSTM's hidden state.
|
||||
n_feature_channels: The number of feature channels returned by the
|
||||
implicit_function evaluated at each raymarching step.
|
||||
bg_color: If supplied, used as the background color. Otherwise the pixel
|
||||
generator is used everywhere. This has to have length either 1
|
||||
(for a constant value for all output channels) or equal to the number
|
||||
of output channels (which is `out_features` on the pixel generator,
|
||||
typically 3.)
|
||||
verbose: If `True`, logs raymarching debug info.
|
||||
|
||||
References:
|
||||
@@ -45,6 +52,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
init_depth_noise_std: float = 5e-4
|
||||
hidden_size: int = 16
|
||||
n_feature_channels: int = 256
|
||||
bg_color: Optional[List[float]] = None
|
||||
verbose: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -117,13 +125,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
msg = (
|
||||
f"{t}: mu={float(signed_distance.mean()):1.2e};"
|
||||
+ f" std={float(signed_distance.std()):1.2e};"
|
||||
# pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
|
||||
# typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
|
||||
# param but got `Tensor`.
|
||||
+ f" mu_d={float(ray_bundle_t.lengths.mean()):1.2e};"
|
||||
# pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
|
||||
# typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
|
||||
# param but got `Tensor`.
|
||||
+ f" std_d={float(ray_bundle_t.lengths.std()):1.2e};"
|
||||
)
|
||||
logger.info(msg)
|
||||
@@ -153,6 +155,10 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
|
||||
if self.bg_color is not None:
|
||||
background = features.new_tensor(self.bg_color)
|
||||
features = torch.lerp(background, features, mask)
|
||||
|
||||
return RendererOutput(
|
||||
features=features[..., 0, :],
|
||||
depths=depth,
|
||||
|
||||
@@ -4,18 +4,21 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Tuple
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import registry
|
||||
from pytorch3d.implicitron.tools.config import registry, run_auto_creation
|
||||
from pytorch3d.renderer import RayBundle
|
||||
|
||||
from .base import BaseRenderer, EvaluationMode, RendererOutput
|
||||
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
|
||||
from .ray_point_refiner import RayPointRefiner
|
||||
from .raymarcher import GenericRaymarcher
|
||||
from .raymarcher import RaymarcherBase
|
||||
|
||||
|
||||
@registry.register
|
||||
class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
|
||||
class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
||||
BaseRenderer, torch.nn.Module
|
||||
):
|
||||
"""
|
||||
Implements the multi-pass rendering function, in particular,
|
||||
with emission-absorption ray marching used in NeRF [1]. First, it evaluates
|
||||
@@ -33,7 +36,17 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
|
||||
```
|
||||
and the final rendered quantities are computed by a dot-product of ray values
|
||||
with the weights, e.g. `features = sum_n(weight_n * ray_features_n)`.
|
||||
See below for possible values of `cap_fn` and `weight_fn`.
|
||||
|
||||
By default, for the EA raymarcher from [1] (
|
||||
activated with `self.raymarcher_class_type="EmissionAbsorptionRaymarcher"`
|
||||
):
|
||||
```
|
||||
cap_fn(x) = 1 - exp(-x),
|
||||
weight_fn(x) = w * x.
|
||||
```
|
||||
Note that the latter can altered by changing `self.raymarcher_class_type`,
|
||||
e.g. to "CumsumRaymarcher" which implements the cumulative-sum raymarcher
|
||||
from NeuralVolumes [2].
|
||||
|
||||
Settings:
|
||||
n_pts_per_ray_fine_training: The number of points sampled per ray for the
|
||||
@@ -46,42 +59,33 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
|
||||
evaluation.
|
||||
append_coarse_samples_to_fine: Add the fine ray points to the coarse points
|
||||
after sampling.
|
||||
bg_color: The background color. A tuple of either 1 element or of D elements,
|
||||
where D matches the feature dimensionality; it is broadcasted when necessary.
|
||||
density_noise_std_train: Standard deviation of the noise added to the
|
||||
opacity field.
|
||||
capping_function: The capping function of the raymarcher.
|
||||
Options:
|
||||
- "exponential" (`cap_fn(x) = 1 - exp(-x)`)
|
||||
- "cap1" (`cap_fn(x) = min(x, 1)`)
|
||||
Set to "exponential" for the standard Emission Absorption raymarching.
|
||||
weight_function: The weighting function of the raymarcher.
|
||||
Options:
|
||||
- "product" (`weight_fn(w, x) = w * x`)
|
||||
- "minimum" (`weight_fn(w, x) = min(w, x)`)
|
||||
Set to "product" for the standard Emission Absorption raymarching.
|
||||
background_opacity: The raw opacity value (i.e. before exponentiation)
|
||||
of the background.
|
||||
blend_output: If `True`, alpha-blends the output renders with the
|
||||
background color using the rendered opacity mask.
|
||||
return_weights: Enables returning the rendering weights of the EA raymarcher.
|
||||
Setting to `True` can lead to a prohibitivelly large memory consumption.
|
||||
raymarcher_class_type: The type of self.raymarcher corresponding to
|
||||
a child of `RaymarcherBase` in the registry.
|
||||
raymarcher: The raymarcher object used to convert per-point features
|
||||
and opacities to a feature render.
|
||||
|
||||
References:
|
||||
[1] Mildenhall, Ben, et al. "Nerf: Representing scenes as neural radiance
|
||||
fields for view synthesis." ECCV 2020.
|
||||
[1] Mildenhall, Ben, et al. "Nerf: Representing Scenes as Neural Radiance
|
||||
Fields for View Synthesis." ECCV 2020.
|
||||
[2] Lombardi, Stephen, et al. "Neural Volumes: Learning Dynamic Renderable
|
||||
Volumes from Images." SIGGRAPH 2019.
|
||||
|
||||
"""
|
||||
|
||||
raymarcher_class_type: str = "EmissionAbsorptionRaymarcher"
|
||||
raymarcher: RaymarcherBase
|
||||
|
||||
n_pts_per_ray_fine_training: int = 64
|
||||
n_pts_per_ray_fine_evaluation: int = 64
|
||||
stratified_sampling_coarse_training: bool = True
|
||||
stratified_sampling_coarse_evaluation: bool = False
|
||||
append_coarse_samples_to_fine: bool = True
|
||||
bg_color: Tuple[float, ...] = (0.0,)
|
||||
density_noise_std_train: float = 0.0
|
||||
capping_function: str = "exponential" # exponential | cap1
|
||||
weight_function: str = "product" # product | minimum
|
||||
background_opacity: float = 1e10
|
||||
blend_output: bool = False
|
||||
return_weights: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
@@ -97,22 +101,14 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
|
||||
add_input_samples=self.append_coarse_samples_to_fine,
|
||||
),
|
||||
}
|
||||
|
||||
self._raymarcher = GenericRaymarcher(
|
||||
1,
|
||||
self.bg_color,
|
||||
capping_function=self.capping_function,
|
||||
weight_function=self.weight_function,
|
||||
background_opacity=self.background_opacity,
|
||||
blend_output=self.blend_output,
|
||||
)
|
||||
run_auto_creation(self)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ray_bundle,
|
||||
implicit_functions=[],
|
||||
ray_bundle: RayBundle,
|
||||
implicit_functions: List[ImplicitFunctionWrapper],
|
||||
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
) -> RendererOutput:
|
||||
"""
|
||||
Args:
|
||||
@@ -149,14 +145,16 @@ class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
|
||||
else 0.0
|
||||
)
|
||||
|
||||
features, depth, mask, weights, aux = self._raymarcher(
|
||||
output = self.raymarcher(
|
||||
*implicit_functions[0](ray_bundle),
|
||||
ray_lengths=ray_bundle.lengths,
|
||||
density_noise_std=density_noise_std,
|
||||
)
|
||||
output = RendererOutput(
|
||||
features=features, depths=depth, masks=mask, aux=aux, prev_stage=prev_stage
|
||||
)
|
||||
output.prev_stage = prev_stage
|
||||
|
||||
weights = output.weights
|
||||
if not self.return_weights:
|
||||
output.weights = None
|
||||
|
||||
# we may need to make a recursive call
|
||||
if len(implicit_functions) > 1:
|
||||
|
||||
@@ -4,21 +4,52 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools import camera_utils
|
||||
from pytorch3d.implicitron.tools.config import Configurable
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .base import EvaluationMode, RenderSamplingMode
|
||||
|
||||
|
||||
class RaySampler(Configurable, torch.nn.Module):
|
||||
class RaySamplerBase(ReplaceableBase):
|
||||
"""
|
||||
Base class for ray samplers.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
cameras: CamerasBase,
|
||||
evaluation_mode: EvaluationMode,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
) -> RayBundle:
|
||||
"""
|
||||
Args:
|
||||
cameras: A batch of `batch_size` cameras from which the rays are emitted.
|
||||
evaluation_mode: one of `EvaluationMode.TRAINING` or
|
||||
`EvaluationMode.EVALUATION` which determines the sampling mode
|
||||
that is used.
|
||||
mask: Active for the `RenderSamplingMode.MASK_SAMPLE` sampling mode.
|
||||
Defines a non-negative mask of shape
|
||||
`(batch_size, image_height, image_width)` where each per-pixel
|
||||
value is proportional to the probability of sampling the
|
||||
corresponding pixel's ray.
|
||||
|
||||
Returns:
|
||||
ray_bundle: A `RayBundle` object containing the parametrizations of the
|
||||
sampled rendering rays.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
|
||||
"""
|
||||
Samples a fixed number of points along rays which are in turn sampled for
|
||||
each camera in a batch.
|
||||
|
||||
@@ -29,46 +60,19 @@ class RaySampler(Configurable, torch.nn.Module):
|
||||
for training and evaluation by setting `self.sampling_mode_training`
|
||||
and `self.sampling_mode_training` accordingly.
|
||||
|
||||
The class allows two modes of sampling points along the rays:
|
||||
1) Sampling between fixed near and far z-planes:
|
||||
Active when `self.scene_extent <= 0`, samples points along each ray
|
||||
with approximately uniform spacing of z-coordinates between
|
||||
the minimum depth `self.min_depth` and the maximum depth `self.max_depth`.
|
||||
This sampling is useful for rendering scenes where the camera is
|
||||
in a constant distance from the focal point of the scene.
|
||||
2) Adaptive near/far plane estimation around the world scene center:
|
||||
Active when `self.scene_extent > 0`. Samples points on each
|
||||
ray between near and far planes whose depths are determined based on
|
||||
the distance from the camera center to a predefined scene center.
|
||||
More specifically,
|
||||
`min_depth = max(
|
||||
(self.scene_center-camera_center).norm() - self.scene_extent, eps
|
||||
)` and
|
||||
`max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.
|
||||
This sampling is ideal for object-centric scenes whose contents are
|
||||
centered around a known `self.scene_center` and fit into a bounding sphere
|
||||
with a radius of `self.scene_extent`.
|
||||
|
||||
Similar to the sampling mode, the sampling parameters can be set separately
|
||||
for training and evaluation.
|
||||
The class allows to adjust the sampling points along rays by overwriting the
|
||||
`AbstractMaskRaySampler._get_min_max_depth_bounds` function which returns
|
||||
the near/far planes (`min_depth`/`max_depth`) `NDCMultinomialRaysampler`.
|
||||
|
||||
Settings:
|
||||
image_width: The horizontal size of the image grid.
|
||||
image_height: The vertical size of the image grid.
|
||||
scene_center: The xyz coordinates of the center of the scene used
|
||||
along with `scene_extent` to compute the min and max depth planes
|
||||
for sampling ray-points.
|
||||
scene_extent: The radius of the scene bounding sphere centered at `scene_center`.
|
||||
If `scene_extent <= 0`, the raysampler samples points between
|
||||
`self.min_depth` and `self.max_depth` depths instead.
|
||||
sampling_mode_training: The ray sampling mode for training. This should be a str
|
||||
option from the RenderSamplingMode Enum
|
||||
sampling_mode_evaluation: Same as above but for evaluation.
|
||||
n_pts_per_ray_training: The number of points sampled along each ray during training.
|
||||
n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
|
||||
n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid
|
||||
min_depth: The minimum depth of a ray-point. Active when `self.scene_extent > 0`.
|
||||
max_depth: The maximum depth of a ray-point. Active when `self.scene_extent > 0`.
|
||||
stratified_point_sampling_training: if set, performs stratified random sampling
|
||||
along the ray; otherwise takes ray points at deterministic offsets.
|
||||
stratified_point_sampling_evaluation: Same as above but for evaluation.
|
||||
@@ -77,24 +81,17 @@ class RaySampler(Configurable, torch.nn.Module):
|
||||
|
||||
image_width: int = 400
|
||||
image_height: int = 400
|
||||
scene_center: Tuple[float, float, float] = field(
|
||||
default_factory=lambda: (0.0, 0.0, 0.0)
|
||||
)
|
||||
scene_extent: float = 0.0
|
||||
sampling_mode_training: str = "mask_sample"
|
||||
sampling_mode_evaluation: str = "full_grid"
|
||||
n_pts_per_ray_training: int = 64
|
||||
n_pts_per_ray_evaluation: int = 64
|
||||
n_rays_per_image_sampled_from_mask: int = 1024
|
||||
min_depth: float = 0.1
|
||||
max_depth: float = 8.0
|
||||
# stratified sampling vs taking points at deterministic offsets
|
||||
stratified_point_sampling_training: bool = True
|
||||
stratified_point_sampling_evaluation: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
self.scene_center = torch.FloatTensor(self.scene_center)
|
||||
|
||||
self._sampling_mode = {
|
||||
EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
|
||||
@@ -108,8 +105,8 @@ class RaySampler(Configurable, torch.nn.Module):
|
||||
image_width=self.image_width,
|
||||
image_height=self.image_height,
|
||||
n_pts_per_ray=self.n_pts_per_ray_training,
|
||||
min_depth=self.min_depth,
|
||||
max_depth=self.max_depth,
|
||||
min_depth=0.0,
|
||||
max_depth=0.0,
|
||||
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
|
||||
if self._sampling_mode[EvaluationMode.TRAINING]
|
||||
== RenderSamplingMode.MASK_SAMPLE
|
||||
@@ -121,8 +118,8 @@ class RaySampler(Configurable, torch.nn.Module):
|
||||
image_width=self.image_width,
|
||||
image_height=self.image_height,
|
||||
n_pts_per_ray=self.n_pts_per_ray_evaluation,
|
||||
min_depth=self.min_depth,
|
||||
max_depth=self.max_depth,
|
||||
min_depth=0.0,
|
||||
max_depth=0.0,
|
||||
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
|
||||
if self._sampling_mode[EvaluationMode.EVALUATION]
|
||||
== RenderSamplingMode.MASK_SAMPLE
|
||||
@@ -132,6 +129,9 @@ class RaySampler(Configurable, torch.nn.Module):
|
||||
),
|
||||
}
|
||||
|
||||
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
cameras: CamerasBase,
|
||||
@@ -163,18 +163,11 @@ class RaySampler(Configurable, torch.nn.Module):
|
||||
):
|
||||
sample_mask = torch.nn.functional.interpolate(
|
||||
mask,
|
||||
# pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
|
||||
# `List[int]`.
|
||||
size=[self.image_height, self.image_width],
|
||||
mode="nearest",
|
||||
)[:, 0]
|
||||
|
||||
if self.scene_extent > 0.0:
|
||||
# Override the min/max depth set in initialization based on the
|
||||
# input cameras.
|
||||
min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
|
||||
cameras, self.scene_center, self.scene_extent
|
||||
)
|
||||
min_depth, max_depth = self._get_min_max_depth_bounds(cameras)
|
||||
|
||||
# pyre-fixme[29]:
|
||||
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
|
||||
@@ -183,8 +176,75 @@ class RaySampler(Configurable, torch.nn.Module):
|
||||
ray_bundle = self._raysamplers[evaluation_mode](
|
||||
cameras=cameras,
|
||||
mask=sample_mask,
|
||||
min_depth=float(min_depth[0]) if self.scene_extent > 0.0 else None,
|
||||
max_depth=float(max_depth[0]) if self.scene_extent > 0.0 else None,
|
||||
min_depth=min_depth,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
|
||||
return ray_bundle
|
||||
|
||||
|
||||
@registry.register
|
||||
class AdaptiveRaySampler(AbstractMaskRaySampler):
|
||||
"""
|
||||
Adaptively samples points on each ray between near and far planes whose
|
||||
depths are determined based on the distance from the camera center
|
||||
to a predefined scene center.
|
||||
|
||||
More specifically,
|
||||
`min_depth = max(
|
||||
(self.scene_center-camera_center).norm() - self.scene_extent, eps
|
||||
)` and
|
||||
`max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.
|
||||
|
||||
This sampling is ideal for object-centric scenes whose contents are
|
||||
centered around a known `self.scene_center` and fit into a bounding sphere
|
||||
with a radius of `self.scene_extent`.
|
||||
|
||||
Args:
|
||||
scene_center: The xyz coordinates of the center of the scene used
|
||||
along with `scene_extent` to compute the min and max depth planes
|
||||
for sampling ray-points.
|
||||
scene_extent: The radius of the scene bounding box centered at `scene_center`.
|
||||
"""
|
||||
|
||||
scene_extent: float = 8.0
|
||||
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
if self.scene_extent <= 0.0:
|
||||
raise ValueError("Adaptive raysampler requires self.scene_extent > 0.")
|
||||
self._scene_center = torch.FloatTensor(self.scene_center)
|
||||
|
||||
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
|
||||
"""
|
||||
Returns the adaptivelly calculated near/far planes.
|
||||
"""
|
||||
min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
|
||||
cameras, self._scene_center, self.scene_extent
|
||||
)
|
||||
return float(min_depth[0]), float(max_depth[0])
|
||||
|
||||
|
||||
@registry.register
|
||||
class NearFarRaySampler(AbstractMaskRaySampler):
|
||||
"""
|
||||
Samples a fixed number of points between fixed near and far z-planes.
|
||||
Specifically, samples points along each ray with approximately uniform spacing
|
||||
of z-coordinates between the minimum depth `self.min_depth` and the maximum depth
|
||||
`self.max_depth`. This sampling is useful for rendering scenes where the camera is
|
||||
in a constant distance from the focal point of the scene.
|
||||
|
||||
Args:
|
||||
min_depth: The minimum depth of a ray-point.
|
||||
max_depth: The maximum depth of a ray-point.
|
||||
"""
|
||||
|
||||
min_depth: float = 0.1
|
||||
max_depth: float = 8.0
|
||||
|
||||
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
|
||||
"""
|
||||
Returns the stored near/far planes.
|
||||
"""
|
||||
return self.min_depth, self.max_depth
|
||||
|
||||
@@ -123,12 +123,12 @@ class RayTracing(Configurable, nn.Module):
|
||||
|
||||
ray_directions = ray_directions.reshape(-1, 3)
|
||||
mask_intersect = mask_intersect.reshape(-1)
|
||||
# pyre-fixme[9]: object_mask has type `BoolTensor`; used as `Tensor`.
|
||||
object_mask = object_mask.reshape(-1)
|
||||
|
||||
in_mask = ~network_object_mask & object_mask & ~sampler_mask
|
||||
out_mask = ~object_mask & ~sampler_mask
|
||||
|
||||
# pyre-fixme[16]: `Tensor` has no attribute `__invert__`.
|
||||
mask_left_out = (in_mask | out_mask) & ~mask_intersect
|
||||
if (
|
||||
mask_left_out.sum() > 0
|
||||
@@ -295,7 +295,7 @@ class RayTracing(Configurable, nn.Module):
|
||||
) and not_proj_iters < self.line_step_iters:
|
||||
# Step backwards
|
||||
acc_start_dis[not_projected_start] -= (
|
||||
(1 - self.line_search_step) / (2 ** not_proj_iters)
|
||||
(1 - self.line_search_step) / (2**not_proj_iters)
|
||||
) * curr_sdf_start[not_projected_start]
|
||||
curr_start_points[not_projected_start] = (
|
||||
cam_loc
|
||||
@@ -303,7 +303,7 @@ class RayTracing(Configurable, nn.Module):
|
||||
).reshape(-1, 3)[not_projected_start]
|
||||
|
||||
acc_end_dis[not_projected_end] += (
|
||||
(1 - self.line_search_step) / (2 ** not_proj_iters)
|
||||
(1 - self.line_search_step) / (2**not_proj_iters)
|
||||
) * curr_sdf_end[not_projected_end]
|
||||
curr_end_points[not_projected_end] = (
|
||||
cam_loc
|
||||
@@ -410,10 +410,17 @@ class RayTracing(Configurable, nn.Module):
|
||||
if n_p_out > 0:
|
||||
out_pts_idx = torch.argmin(sdf_val[p_out_mask, :], -1)
|
||||
sampler_pts[mask_intersect_idx[p_out_mask]] = points[p_out_mask, :, :][
|
||||
torch.arange(n_p_out), out_pts_idx, :
|
||||
# pyre-fixme[6]: For 1st param expected `Union[bool, float, int]`
|
||||
# but got `Tensor`.
|
||||
torch.arange(n_p_out),
|
||||
out_pts_idx,
|
||||
:,
|
||||
]
|
||||
sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[
|
||||
p_out_mask, :
|
||||
p_out_mask,
|
||||
:
|
||||
# pyre-fixme[6]: For 1st param expected `Union[bool, float, int]` but
|
||||
# got `Tensor`.
|
||||
][torch.arange(n_p_out), out_pts_idx]
|
||||
|
||||
# Get Network object mask
|
||||
@@ -434,10 +441,16 @@ class RayTracing(Configurable, nn.Module):
|
||||
secant_pts
|
||||
]
|
||||
z_low = pts_intervals[secant_pts][
|
||||
torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1
|
||||
# pyre-fixme[6]: For 1st param expected `Union[bool, float, int]`
|
||||
# but got `Tensor`.
|
||||
torch.arange(n_secant_pts),
|
||||
sampler_pts_ind[secant_pts] - 1,
|
||||
]
|
||||
sdf_low = sdf_val[secant_pts][
|
||||
torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1
|
||||
# pyre-fixme[6]: For 1st param expected `Union[bool, float, int]`
|
||||
# but got `Tensor`.
|
||||
torch.arange(n_secant_pts),
|
||||
sampler_pts_ind[secant_pts] - 1,
|
||||
]
|
||||
cam_loc_secant = cam_loc.reshape(-1, 3)[mask_intersect_idx[secant_pts]]
|
||||
ray_directions_secant = ray_directions.reshape((-1, 3))[
|
||||
@@ -514,6 +527,7 @@ class RayTracing(Configurable, nn.Module):
|
||||
mask_max_dis = max_dis[mask].unsqueeze(-1)
|
||||
mask_min_dis = min_dis[mask].unsqueeze(-1)
|
||||
steps = (
|
||||
# pyre-fixme[6]: For 1st param expected `int` but got `Tensor`.
|
||||
steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis)
|
||||
+ mask_min_dis
|
||||
)
|
||||
@@ -533,8 +547,13 @@ class RayTracing(Configurable, nn.Module):
|
||||
mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n)
|
||||
min_vals, min_idx = mask_sdf_all.min(-1)
|
||||
min_mask_points = mask_points_all.reshape(-1, n, 3)[
|
||||
torch.arange(0, n_mask_points), min_idx
|
||||
# pyre-fixme[6]: For 2nd param expected `Union[bool, float, int]` but
|
||||
# got `Tensor`.
|
||||
torch.arange(0, n_mask_points),
|
||||
min_idx,
|
||||
]
|
||||
# pyre-fixme[6]: For 2nd param expected `Union[bool, float, int]` but got
|
||||
# `Tensor`.
|
||||
min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx]
|
||||
|
||||
return min_mask_points, min_mask_dist
|
||||
@@ -553,7 +572,8 @@ def _get_sphere_intersection(
|
||||
# cam_loc = cam_loc.unsqueeze(-1)
|
||||
# ray_cam_dot = torch.bmm(ray_directions, cam_loc).squeeze()
|
||||
ray_cam_dot = (ray_directions * cam_loc).sum(-1) # n_images x n_rays
|
||||
under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, dim=-1) ** 2 - r ** 2)
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
under_sqrt = ray_cam_dot**2 - (cam_loc.norm(2, dim=-1) ** 2 - r**2)
|
||||
|
||||
under_sqrt = under_sqrt.reshape(-1)
|
||||
mask_intersect = under_sqrt > 0
|
||||
|
||||
@@ -4,51 +4,99 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.renderer.base import RendererOutput
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.renderer.implicit.raymarching import _check_raymarcher_inputs
|
||||
|
||||
|
||||
_TTensor = torch.Tensor
|
||||
|
||||
|
||||
class GenericRaymarcher(torch.nn.Module):
|
||||
class RaymarcherBase(ReplaceableBase):
|
||||
"""
|
||||
Defines a base class for raymarchers. Specifically, a raymarcher is responsible
|
||||
for taking a set of features and density descriptors along rendering rays
|
||||
and marching along them in order to generate a feature render.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
rays_densities: torch.Tensor,
|
||||
rays_features: torch.Tensor,
|
||||
aux: Dict[str, Any],
|
||||
) -> RendererOutput:
|
||||
"""
|
||||
Args:
|
||||
rays_densities: Per-ray density values represented with a tensor
|
||||
of shape `(..., n_points_per_ray, 1)`.
|
||||
rays_features: Per-ray feature values represented with a tensor
|
||||
of shape `(..., n_points_per_ray, feature_dim)`.
|
||||
aux: a dictionary with extra information.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
|
||||
"""
|
||||
This generalizes the `pytorch3d.renderer.EmissionAbsorptionRaymarcher`
|
||||
and NeuralVolumes' Accumulative ray marcher. It additionally returns
|
||||
and NeuralVolumes' cumsum ray marcher. It additionally returns
|
||||
the rendering weights that can be used in the NVS pipeline to carry out
|
||||
the importance ray-sampling in the refining pass.
|
||||
Different from `EmissionAbsorptionRaymarcher`, it takes raw
|
||||
Different from `pytorch3d.renderer.EmissionAbsorptionRaymarcher`, it takes raw
|
||||
(non-exponentiated) densities.
|
||||
|
||||
Args:
|
||||
bg_color: background_color. Must be of shape (1,) or (feature_dim,)
|
||||
surface_thickness: The thickness of the raymarched surface.
|
||||
bg_color: The background color. A tuple of either 1 element or of D elements,
|
||||
where D matches the feature dimensionality; it is broadcast when necessary.
|
||||
background_opacity: The raw opacity value (i.e. before exponentiation)
|
||||
of the background.
|
||||
density_relu: If `True`, passes the input density through ReLU before
|
||||
raymarching.
|
||||
blend_output: If `True`, alpha-blends the output renders with the
|
||||
background color using the rendered opacity mask.
|
||||
|
||||
capping_function: The capping function of the raymarcher.
|
||||
Options:
|
||||
- "exponential" (`cap_fn(x) = 1 - exp(-x)`)
|
||||
- "cap1" (`cap_fn(x) = min(x, 1)`)
|
||||
Set to "exponential" for the standard Emission Absorption raymarching.
|
||||
weight_function: The weighting function of the raymarcher.
|
||||
Options:
|
||||
- "product" (`weight_fn(w, x) = w * x`)
|
||||
- "minimum" (`weight_fn(w, x) = min(w, x)`)
|
||||
Set to "product" for the standard Emission Absorption raymarching.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
surface_thickness: int = 1,
|
||||
bg_color: Union[Tuple[float, ...], _TTensor] = (0.0,),
|
||||
capping_function: str = "exponential", # exponential | cap1
|
||||
weight_function: str = "product", # product | minimum
|
||||
background_opacity: float = 0.0,
|
||||
density_relu: bool = True,
|
||||
blend_output: bool = True,
|
||||
):
|
||||
surface_thickness: int = 1
|
||||
bg_color: Tuple[float, ...] = (0.0,)
|
||||
background_opacity: float = 0.0
|
||||
density_relu: bool = True
|
||||
blend_output: bool = False
|
||||
|
||||
@property
|
||||
def capping_function_type(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def weight_function_type(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def __post_init__(self):
|
||||
"""
|
||||
Args:
|
||||
surface_thickness: Denotes the overlap between the absorption
|
||||
function and the density function.
|
||||
"""
|
||||
super().__init__()
|
||||
self.surface_thickness = surface_thickness
|
||||
self.density_relu = density_relu
|
||||
self.background_opacity = background_opacity
|
||||
self.blend_output = blend_output
|
||||
if not isinstance(bg_color, torch.Tensor):
|
||||
bg_color = torch.tensor(bg_color)
|
||||
|
||||
bg_color = torch.tensor(self.bg_color)
|
||||
if bg_color.ndim != 1:
|
||||
raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
|
||||
|
||||
@@ -57,12 +105,12 @@ class GenericRaymarcher(torch.nn.Module):
|
||||
self._capping_function: Callable[[_TTensor], _TTensor] = {
|
||||
"exponential": lambda x: 1.0 - torch.exp(-x),
|
||||
"cap1": lambda x: x.clamp(max=1.0),
|
||||
}[capping_function]
|
||||
}[self.capping_function_type]
|
||||
|
||||
self._weight_function: Callable[[_TTensor, _TTensor], _TTensor] = {
|
||||
"product": lambda curr, acc: curr * acc,
|
||||
"minimum": lambda curr, acc: torch.minimum(curr, acc),
|
||||
}[weight_function]
|
||||
}[self.weight_function_type]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -71,7 +119,8 @@ class GenericRaymarcher(torch.nn.Module):
|
||||
aux: Dict[str, Any],
|
||||
ray_lengths: torch.Tensor,
|
||||
density_noise_std: float = 0.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
|
||||
**kwargs,
|
||||
) -> RendererOutput:
|
||||
"""
|
||||
Args:
|
||||
rays_densities: Per-ray density values represented with a tensor
|
||||
@@ -87,7 +136,7 @@ class GenericRaymarcher(torch.nn.Module):
|
||||
features: A tensor of shape `(..., feature_dim)` containing
|
||||
the rendered features for each ray.
|
||||
depth: A tensor of shape `(..., 1)` containing estimated depth.
|
||||
opacities: A tensor of shape `(..., 1)` containing rendered opacsities.
|
||||
opacities: A tensor of shape `(..., 1)` containing rendered opacities.
|
||||
weights: A tensor of shape `(..., n_points_per_ray)` containing
|
||||
the ray-specific non-negative opacity weights. In general, they
|
||||
don't sum to 1 but do not overcome it, i.e.
|
||||
@@ -113,16 +162,15 @@ class GenericRaymarcher(torch.nn.Module):
|
||||
rays_densities = rays_densities[..., 0]
|
||||
|
||||
if density_noise_std > 0.0:
|
||||
rays_densities = (
|
||||
rays_densities + torch.randn_like(rays_densities) * density_noise_std
|
||||
)
|
||||
noise: _TTensor = torch.randn_like(rays_densities).mul(density_noise_std)
|
||||
rays_densities = rays_densities + noise
|
||||
if self.density_relu:
|
||||
rays_densities = torch.relu(rays_densities)
|
||||
|
||||
weighted_densities = deltas * rays_densities
|
||||
capped_densities = self._capping_function(weighted_densities)
|
||||
capped_densities = self._capping_function(weighted_densities) # pyre-ignore: 29
|
||||
|
||||
rays_opacities = self._capping_function(
|
||||
rays_opacities = self._capping_function( # pyre-ignore: 29
|
||||
torch.cumsum(weighted_densities, dim=-1)
|
||||
)
|
||||
opacities = rays_opacities[..., -1:]
|
||||
@@ -131,7 +179,9 @@ class GenericRaymarcher(torch.nn.Module):
|
||||
)
|
||||
absorption_shifted[..., : self.surface_thickness] = 1.0
|
||||
|
||||
weights = self._weight_function(capped_densities, absorption_shifted)
|
||||
weights = self._weight_function( # pyre-ignore: 29
|
||||
capped_densities, absorption_shifted
|
||||
)
|
||||
features = (weights[..., None] * rays_features).sum(dim=-2)
|
||||
depth = (weights * ray_lengths)[..., None].sum(dim=-2)
|
||||
|
||||
@@ -140,4 +190,42 @@ class GenericRaymarcher(torch.nn.Module):
|
||||
raise ValueError("Wrong number of background color channels.")
|
||||
features = alpha * features + (1 - opacities) * self._bg_color
|
||||
|
||||
return features, depth, opacities, weights, aux
|
||||
return RendererOutput(
|
||||
features=features,
|
||||
depths=depth,
|
||||
masks=opacities,
|
||||
weights=weights,
|
||||
aux=aux,
|
||||
)
|
||||
|
||||
|
||||
@registry.register
|
||||
class EmissionAbsorptionRaymarcher(AccumulativeRaymarcherBase):
|
||||
"""
|
||||
Implements the EmissionAbsorption raymarcher.
|
||||
"""
|
||||
|
||||
background_opacity: float = 1e10
|
||||
|
||||
@property
|
||||
def capping_function_type(self) -> str:
|
||||
return "exponential"
|
||||
|
||||
@property
|
||||
def weight_function_type(self) -> str:
|
||||
return "product"
|
||||
|
||||
|
||||
@registry.register
|
||||
class CumsumRaymarcher(AccumulativeRaymarcherBase):
|
||||
"""
|
||||
Implements the NeuralVolumes' cumulative-sum raymarcher.
|
||||
"""
|
||||
|
||||
@property
|
||||
def capping_function_type(self) -> str:
|
||||
return "cap1"
|
||||
|
||||
@property
|
||||
def weight_function_type(self) -> str:
|
||||
return "minimum"
|
||||
|
||||
@@ -16,6 +16,30 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayNormalColoringNetwork(torch.nn.Module):
|
||||
"""
|
||||
Members:
|
||||
d_in and feature_vector_size: Sum of these is the input
|
||||
dimension. These must add up to the sum of
|
||||
- 3 [for the points]
|
||||
- 3 unless mode=no_normal [for the normals]
|
||||
- 3 unless mode=no_view_dir [for view directions]
|
||||
- the feature size, [number of channels in feature_vectors]
|
||||
|
||||
d_out: dimension of output.
|
||||
mode: One of "idr", "no_view_dir" or "no_normal" to allow omitting
|
||||
part of the network input.
|
||||
dims: list of hidden layer sizes.
|
||||
weight_norm: whether to apply weight normalization to each layer.
|
||||
n_harmonic_functions_dir:
|
||||
If >0, use a harmonic embedding with this number of
|
||||
harmonic functions for the view direction. Otherwise view directions
|
||||
are fed without embedding, unless mode is `no_view_dir`.
|
||||
pooled_feature_dim: If a pooling function is in use (provided as
|
||||
pooling_fn to forward()) this must be its number of features.
|
||||
Otherwise this must be set to 0. (If used from GenericModel,
|
||||
this will be set automatically.)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_vector_size: int = 3,
|
||||
|
||||
@@ -132,7 +132,11 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
eik_bounding_box: float = self.object_bounding_sphere
|
||||
n_eik_points = batch_size * num_pixels // 2
|
||||
eikonal_points = torch.empty(
|
||||
n_eik_points, 3, device=self._bg_color.device
|
||||
n_eik_points,
|
||||
3,
|
||||
# pyre-fixme[6]: For 3rd param expected `Union[None, str, device]`
|
||||
# but got `Union[device, Tensor, Module]`.
|
||||
device=self._bg_color.device,
|
||||
).uniform_(-eik_bounding_box, eik_bounding_box)
|
||||
eikonal_pixel_points = points.clone()
|
||||
eikonal_pixel_points = eikonal_pixel_points.detach()
|
||||
@@ -196,7 +200,9 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||
pooling_fn=None, # TODO
|
||||
)
|
||||
mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
|
||||
-self.soft_mask_alpha * sdf_output[~surface_mask]
|
||||
# pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
|
||||
-self.soft_mask_alpha
|
||||
* sdf_output[~surface_mask]
|
||||
)
|
||||
|
||||
# scatter points with surface_mask
|
||||
|
||||
5
pytorch3d/implicitron/models/view_pooler/__init__.py
Normal file
5
pytorch3d/implicitron/models/view_pooler/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
@@ -6,11 +6,11 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
from typing import Dict, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.implicitron.models.view_pooling.view_sampling import (
|
||||
from pytorch3d.implicitron.models.view_pooler.view_sampler import (
|
||||
cameras_points_cartesian_product,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
@@ -82,6 +82,33 @@ class FeatureAggregatorBase(ABC, ReplaceableBase):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def get_aggregated_feature_dim(
|
||||
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||
):
|
||||
"""
|
||||
Returns the final dimensionality of the output aggregated features.
|
||||
|
||||
Args:
|
||||
feats_or_feats_dim: Either a `dict` of sampled features `{f_i: t_i}` corresponding
|
||||
to the `feats_sampled` argument of `forward`,
|
||||
or an `int` representing the sum of dimensionalities of each `t_i`.
|
||||
|
||||
Returns:
|
||||
aggregated_feature_dim: The final dimensionality of the output
|
||||
aggregated features.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def has_aggregation(self) -> bool:
|
||||
"""
|
||||
Specifies whether the aggregator reduces the output `reduce_dim` dimension to 1.
|
||||
|
||||
Returns:
|
||||
has_aggregation: `True` if `reduce_dim==1`, else `False`.
|
||||
"""
|
||||
return hasattr(self, "reduction_functions")
|
||||
|
||||
|
||||
@registry.register
|
||||
class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
||||
@@ -94,8 +121,10 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
|
||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
||||
return _get_reduction_aggregator_feature_dim(feats, [])
|
||||
def get_aggregated_feature_dim(
|
||||
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||
):
|
||||
return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, [])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -147,7 +176,7 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
||||
the stack of source-view-specific features to a single feature.
|
||||
"""
|
||||
|
||||
reduction_functions: Sequence[ReductionFunction] = (
|
||||
reduction_functions: Tuple[ReductionFunction, ...] = (
|
||||
ReductionFunction.AVG,
|
||||
ReductionFunction.STD,
|
||||
)
|
||||
@@ -155,8 +184,12 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
|
||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
||||
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
|
||||
def get_aggregated_feature_dim(
|
||||
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||
):
|
||||
return _get_reduction_aggregator_feature_dim(
|
||||
feats_or_feats_dim, self.reduction_functions
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -236,7 +269,7 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
|
||||
used when calculating the angle-based aggregation weights.
|
||||
"""
|
||||
|
||||
reduction_functions: Sequence[ReductionFunction] = (
|
||||
reduction_functions: Tuple[ReductionFunction, ...] = (
|
||||
ReductionFunction.AVG,
|
||||
ReductionFunction.STD,
|
||||
)
|
||||
@@ -246,8 +279,12 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
|
||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
||||
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
|
||||
def get_aggregated_feature_dim(
|
||||
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||
):
|
||||
return _get_reduction_aggregator_feature_dim(
|
||||
feats_or_feats_dim, self.reduction_functions
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -345,8 +382,10 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
|
||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
||||
return _get_reduction_aggregator_feature_dim(feats, [])
|
||||
def get_aggregated_feature_dim(
|
||||
self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int]
|
||||
):
|
||||
return _get_reduction_aggregator_feature_dim(feats_or_feats_dim, [])
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -511,7 +550,6 @@ def _get_ray_dir_dot_prods(camera: CamerasBase, pts: torch.Tensor):
|
||||
# torch.Tensor, torch.nn.modules.module.Module]` is not a function.
|
||||
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N...
|
||||
camera_rep.T[:, None],
|
||||
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N...
|
||||
camera_rep.R.permute(0, 2, 1),
|
||||
).reshape(-1, *([1] * (pts.ndim - 2)), 3)
|
||||
# cam_centers_rep = camera_rep.get_camera_center().reshape(
|
||||
@@ -610,6 +648,7 @@ def _avgmaxstd_reduction_function(
|
||||
x_aggr = torch.cat(pooled_features, dim=-1)
|
||||
|
||||
# zero out features that were all masked out
|
||||
# pyre-fixme[16]: `bool` has no attribute `type_as`.
|
||||
any_active = (w.max(dim=dim, keepdim=True).values > 1e-4).type_as(x_aggr)
|
||||
x_aggr = x_aggr * any_active[..., None]
|
||||
|
||||
@@ -637,6 +676,7 @@ def _std_reduction_function(
|
||||
):
|
||||
if mu is None:
|
||||
mu = _avg_reduction_function(x, w, dim=dim)
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
std = wmean((x - mu) ** 2, w, dim=dim, eps=1e-2).clamp(1e-4).sqrt()
|
||||
# FIXME: somehow this is extremely heavy in mem?
|
||||
return std
|
||||
128
pytorch3d/implicitron/models/view_pooler/view_pooler.py
Normal file
128
pytorch3d/implicitron/models/view_pooler/view_pooler.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import Configurable, run_auto_creation
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .feature_aggregator import FeatureAggregatorBase
|
||||
from .view_sampler import ViewSampler
|
||||
|
||||
|
||||
# pyre-ignore: 13
|
||||
class ViewPooler(Configurable, torch.nn.Module):
|
||||
"""
|
||||
Implements sampling of image-based features at the 2d projections of a set
|
||||
of 3D points, and a subsequent aggregation of the resulting set of features
|
||||
per-point.
|
||||
|
||||
Args:
|
||||
view_sampler: An instance of ViewSampler which is used for sampling of
|
||||
image-based features at the 2D projections of a set
|
||||
of 3D points.
|
||||
feature_aggregator_class_type: The name of the feature aggregator class which
|
||||
is available in the global registry.
|
||||
feature_aggregator: A feature aggregator class which inherits from
|
||||
FeatureAggregatorBase. Typically, the aggregated features and their
|
||||
masks are output by a `ViewSampler` which samples feature tensors extracted
|
||||
from a set of source images. FeatureAggregator executes step (4) above.
|
||||
"""
|
||||
|
||||
view_sampler: ViewSampler
|
||||
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
|
||||
feature_aggregator: FeatureAggregatorBase
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
run_auto_creation(self)
|
||||
|
||||
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
|
||||
"""
|
||||
Returns the final dimensionality of the output aggregated features.
|
||||
|
||||
Args:
|
||||
feats: Either a `dict` of sampled features `{f_i: t_i}` corresponding
|
||||
to the `feats_sampled` argument of `feature_aggregator,forward`,
|
||||
or an `int` representing the sum of dimensionalities of each `t_i`.
|
||||
|
||||
Returns:
|
||||
aggregated_feature_dim: The final dimensionality of the output
|
||||
aggregated features.
|
||||
"""
|
||||
return self.feature_aggregator.get_aggregated_feature_dim(feats)
|
||||
|
||||
def has_aggregation(self):
|
||||
"""
|
||||
Specifies whether the `feature_aggregator` reduces the output `reduce_dim`
|
||||
dimension to 1.
|
||||
|
||||
Returns:
|
||||
has_aggregation: `True` if `reduce_dim==1`, else `False`.
|
||||
"""
|
||||
return self.feature_aggregator.has_aggregation()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*, # force kw args
|
||||
pts: torch.Tensor,
|
||||
seq_id_pts: Union[List[int], List[str], torch.LongTensor],
|
||||
camera: CamerasBase,
|
||||
seq_id_camera: Union[List[int], List[str], torch.LongTensor],
|
||||
feats: Dict[str, torch.Tensor],
|
||||
masks: Optional[torch.Tensor],
|
||||
**kwargs,
|
||||
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
|
||||
"""
|
||||
Project each point cloud from a batch of point clouds to corresponding
|
||||
input cameras, sample features at the 2D projection locations in a batch
|
||||
of source images, and aggregate the pointwise sampled features.
|
||||
|
||||
Args:
|
||||
pts: A tensor of shape `[pts_batch x n_pts x 3]` in world coords.
|
||||
seq_id_pts: LongTensor of shape `[pts_batch]` denoting the ids of the scenes
|
||||
from which `pts` were extracted, or a list of string names.
|
||||
camera: 'n_cameras' cameras, each coresponding to a batch element of `feats`.
|
||||
seq_id_camera: LongTensor of shape `[n_cameras]` denoting the ids of the scenes
|
||||
corresponding to cameras in `camera`, or a list of string names.
|
||||
feats: a dict of tensors of per-image features `{feat_i: T_i}`.
|
||||
Each tensor `T_i` is of shape `[n_cameras x dim_i x H_i x W_i]`.
|
||||
masks: `[n_cameras x 1 x H x W]`, define valid image regions
|
||||
for sampling `feats`.
|
||||
Returns:
|
||||
feats_aggregated: If `feature_aggregator.concatenate_output==True`, a tensor
|
||||
of shape `(pts_batch, reduce_dim, n_pts, sum(dim_1, ... dim_N))`
|
||||
containing the aggregated features. `reduce_dim` depends on
|
||||
the specific feature aggregator implementation and typically
|
||||
equals 1 or `n_cameras`.
|
||||
If `feature_aggregator.concatenate_output==False`, the aggregator
|
||||
does not concatenate the aggregated features and returns a dictionary
|
||||
of per-feature aggregations `{f_i: t_i_aggregated}` instead.
|
||||
Each `t_i_aggregated` is of shape
|
||||
`(pts_batch, reduce_dim, n_pts, aggr_dim_i)`.
|
||||
"""
|
||||
|
||||
# (1) Sample features and masks at the ray points
|
||||
sampled_feats, sampled_masks = self.view_sampler(
|
||||
pts=pts,
|
||||
seq_id_pts=seq_id_pts,
|
||||
camera=camera,
|
||||
seq_id_camera=seq_id_camera,
|
||||
feats=feats,
|
||||
masks=masks,
|
||||
)
|
||||
|
||||
# (2) Aggregate features from multiple views
|
||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||
feats_aggregated = self.feature_aggregator( # noqa: E731
|
||||
sampled_feats,
|
||||
sampled_masks,
|
||||
pts=pts,
|
||||
camera=camera,
|
||||
) # TODO: do we need to pass a callback rather than compute here?
|
||||
|
||||
return feats_aggregated
|
||||
@@ -205,7 +205,11 @@ def handle_seq_id(
|
||||
if not torch.is_tensor(seq_id):
|
||||
if isinstance(seq_id[0], str):
|
||||
seq_id = [hash(s) for s in seq_id]
|
||||
# pyre-fixme[9]: seq_id has type `Union[List[int], List[str], LongTensor]`;
|
||||
# used as `Tensor`.
|
||||
seq_id = torch.tensor(seq_id, dtype=torch.long, device=device)
|
||||
# pyre-fixme[16]: Item `List` of `Union[List[int], List[str], LongTensor]` has
|
||||
# no attribute `to`.
|
||||
return seq_id.to(device)
|
||||
|
||||
|
||||
@@ -287,5 +291,7 @@ def cameras_points_cartesian_product(
|
||||
)
|
||||
.reshape(batch_pts * n_cameras)
|
||||
)
|
||||
# pyre-fixme[6]: For 1st param expected `Union[List[int], int, LongTensor]` but
|
||||
# got `Tensor`.
|
||||
camera_rep = camera[idx_cams]
|
||||
return camera_rep, pts_rep
|
||||
@@ -215,7 +215,6 @@ class BatchLinear(nn.Module):
|
||||
def last_hyper_layer_init(m) -> None:
|
||||
if type(m) == nn.Linear:
|
||||
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity="relu", mode="fan_in")
|
||||
# pyre-fixme[41]: `data` cannot be reassigned. It is a read-only property.
|
||||
m.weight.data *= 1e-1
|
||||
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ def volumetric_camera_overlaps(
|
||||
"""
|
||||
device = cameras.device
|
||||
ba = cameras.R.shape[0]
|
||||
n_vox = int(resol ** 3)
|
||||
n_vox = int(resol**3)
|
||||
grid = pt3d.structures.Volumes(
|
||||
densities=torch.zeros([1, 1, resol, resol, resol], device=device),
|
||||
volume_translation=-torch.FloatTensor(scene_center)[None].to(device),
|
||||
|
||||
@@ -102,13 +102,14 @@ def fit_circle_in_2d(
|
||||
Circle2D object
|
||||
"""
|
||||
design = torch.cat([points2d, torch.ones_like(points2d[:, :1])], dim=1)
|
||||
rhs = (points2d ** 2).sum(1)
|
||||
rhs = (points2d**2).sum(1)
|
||||
n_provided = points2d.shape[0]
|
||||
if n_provided < 3:
|
||||
raise ValueError(f"{n_provided} points are not enough to determine a circle")
|
||||
solution = lstsq(design, rhs)
|
||||
center = solution[:2] / 2
|
||||
radius = torch.sqrt(solution[2] + (center ** 2).sum())
|
||||
solution = lstsq(design, rhs[:, None])
|
||||
center = solution[:2, 0] / 2
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
radius = torch.sqrt(solution[2, 0] + (center**2).sum())
|
||||
if n_points > 0:
|
||||
if angles is not None:
|
||||
warnings.warn("n_points ignored because angles provided")
|
||||
|
||||
@@ -11,6 +11,7 @@ import sys
|
||||
import warnings
|
||||
from collections import Counter, defaultdict
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf, open_dict
|
||||
@@ -175,6 +176,9 @@ _unprocessed_warning: str = (
|
||||
TYPE_SUFFIX: str = "_class_type"
|
||||
ARGS_SUFFIX: str = "_args"
|
||||
ENABLED_SUFFIX: str = "_enabled"
|
||||
CREATE_PREFIX: str = "create_"
|
||||
IMPL_SUFFIX: str = "_impl"
|
||||
TWEAK_SUFFIX: str = "_tweak_args"
|
||||
|
||||
|
||||
class ReplaceableBase:
|
||||
@@ -216,6 +220,7 @@ class Configurable:
|
||||
|
||||
|
||||
_X = TypeVar("X", bound=ReplaceableBase)
|
||||
_Y = TypeVar("Y", bound=Union[ReplaceableBase, Configurable])
|
||||
|
||||
|
||||
class _Registry:
|
||||
@@ -259,13 +264,9 @@ class _Registry:
|
||||
raise ValueError(
|
||||
f"Cannot register {some_class}. Cannot tell what it is."
|
||||
)
|
||||
if some_class is base_class:
|
||||
raise ValueError(f"Attempted to register the base class {some_class}")
|
||||
self._mapping[base_class][name] = some_class
|
||||
|
||||
def get(
|
||||
self, base_class_wanted: Type[ReplaceableBase], name: str
|
||||
) -> Type[ReplaceableBase]:
|
||||
def get(self, base_class_wanted: Type[_X], name: str) -> Type[_X]:
|
||||
"""
|
||||
Retrieve a class from the registry by name
|
||||
|
||||
@@ -293,6 +294,7 @@ class _Registry:
|
||||
raise ValueError(
|
||||
f"{name} resolves to {result} which does not subclass {base_class_wanted}"
|
||||
)
|
||||
# pyre-ignore[7]
|
||||
return result
|
||||
|
||||
def get_all(
|
||||
@@ -306,20 +308,23 @@ class _Registry:
|
||||
It determines the namespace.
|
||||
This will typically be a direct subclass of ReplaceableBase.
|
||||
Returns:
|
||||
list of class types
|
||||
list of class types in alphabetical order of registered name.
|
||||
"""
|
||||
if self._is_base_class(base_class_wanted):
|
||||
return list(self._mapping[base_class_wanted].values())
|
||||
source = self._mapping[base_class_wanted]
|
||||
return [source[key] for key in sorted(source)]
|
||||
|
||||
base_class = self._base_class_from_class(base_class_wanted)
|
||||
if base_class is None:
|
||||
raise ValueError(
|
||||
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
|
||||
)
|
||||
source = self._mapping[base_class]
|
||||
return [
|
||||
class_
|
||||
for class_ in self._mapping[base_class].values()
|
||||
if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted
|
||||
source[key]
|
||||
for key in sorted(source)
|
||||
if issubclass(source[key], base_class_wanted)
|
||||
and source[key] is not base_class_wanted
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
@@ -375,25 +380,68 @@ def _default_create(
|
||||
|
||||
Returns:
|
||||
Function taking one argument, the object whose member should be
|
||||
initialized.
|
||||
initialized, i.e. self.
|
||||
"""
|
||||
impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}"
|
||||
|
||||
def inner(self):
|
||||
expand_args_fields(type_)
|
||||
impl = getattr(self, impl_name)
|
||||
args = getattr(self, name + ARGS_SUFFIX)
|
||||
setattr(self, name, type_(**args))
|
||||
impl(True, args)
|
||||
|
||||
def inner_optional(self):
|
||||
expand_args_fields(type_)
|
||||
impl = getattr(self, impl_name)
|
||||
enabled = getattr(self, name + ENABLED_SUFFIX)
|
||||
args = getattr(self, name + ARGS_SUFFIX)
|
||||
impl(enabled, args)
|
||||
|
||||
def inner_pluggable(self):
|
||||
type_name = getattr(self, name + TYPE_SUFFIX)
|
||||
impl = getattr(self, impl_name)
|
||||
if type_name is None:
|
||||
args = None
|
||||
else:
|
||||
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}", None)
|
||||
impl(type_name, args)
|
||||
|
||||
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||
return inner_optional
|
||||
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
|
||||
|
||||
|
||||
def _default_create_impl(
|
||||
name: str, type_: Type, process_type: _ProcessType
|
||||
) -> Callable[[Any, Any, DictConfig], None]:
|
||||
"""
|
||||
Return the default internal function for initialising a member. This is a function
|
||||
which could be called in the create_ function to initialise the member.
|
||||
|
||||
Args:
|
||||
name: name of the member
|
||||
type_: type of the member (with any Optional removed)
|
||||
process_type: Shows whether member's declared type inherits ReplaceableBase,
|
||||
in which case the actual type to be created is decided at
|
||||
runtime.
|
||||
|
||||
Returns:
|
||||
Function taking
|
||||
- self, the object whose member should be initialized.
|
||||
- option for what to do. This is
|
||||
- for pluggables, the type to initialise or None to do nothing
|
||||
- for non pluggables, a bool indicating whether to initialise.
|
||||
- the args for initializing the member.
|
||||
"""
|
||||
|
||||
def create_configurable(self, enabled, args):
|
||||
if enabled:
|
||||
args = getattr(self, name + ARGS_SUFFIX)
|
||||
expand_args_fields(type_)
|
||||
setattr(self, name, type_(**args))
|
||||
else:
|
||||
setattr(self, name, None)
|
||||
|
||||
def inner_pluggable(self):
|
||||
type_name = getattr(self, name + TYPE_SUFFIX)
|
||||
def create_pluggable(self, type_name, args):
|
||||
if type_name is None:
|
||||
setattr(self, name, None)
|
||||
return
|
||||
@@ -408,12 +456,11 @@ def _default_create(
|
||||
# were made in the redefinition will not be reflected here.
|
||||
warnings.warn(f"New implementation of {type_name} is being chosen.")
|
||||
expand_args_fields(chosen_class)
|
||||
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
|
||||
setattr(self, name, chosen_class(**args))
|
||||
|
||||
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||
return inner_optional
|
||||
return inner if process_type == _ProcessType.CONFIGURABLE else inner_pluggable
|
||||
if process_type in (_ProcessType.CONFIGURABLE, _ProcessType.OPTIONAL_CONFIGURABLE):
|
||||
return create_configurable
|
||||
return create_pluggable
|
||||
|
||||
|
||||
def run_auto_creation(self: Any) -> None:
|
||||
@@ -567,6 +614,9 @@ def _params_iter(C):
|
||||
|
||||
|
||||
def _is_immutable_type(type_: Type, val: Any) -> bool:
|
||||
if val is None:
|
||||
return True
|
||||
|
||||
PRIMITIVE_TYPES = (int, float, bool, str, bytes, tuple)
|
||||
# sometimes type can be too relaxed (e.g. Any), so we also check values
|
||||
if isinstance(val, PRIMITIVE_TYPES):
|
||||
@@ -601,17 +651,19 @@ def _is_actually_dataclass(some_class) -> bool:
|
||||
|
||||
|
||||
def expand_args_fields(
|
||||
some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = ()
|
||||
) -> Type[_X]:
|
||||
some_class: Type[_Y], *, _do_not_process: Tuple[type, ...] = ()
|
||||
) -> Type[_Y]:
|
||||
"""
|
||||
This expands a class which inherits Configurable or ReplaceableBase classes,
|
||||
including dataclass processing. some_class is modified in place by this function.
|
||||
If expand_args_fields(some_class) has already been called, subsequent calls do
|
||||
nothing and return some_class unmodified.
|
||||
For classes of type ReplaceableBase, you can add some_class to the registry before
|
||||
or after calling this function. But potential inner classes need to be registered
|
||||
before this function is run on the outer class.
|
||||
|
||||
The transformations this function makes, before the concluding
|
||||
dataclasses.dataclass, are as follows. if X is a base class with registered
|
||||
dataclasses.dataclass, are as follows. If X is a base class with registered
|
||||
subclasses Y and Z, replace a class member
|
||||
|
||||
x: X
|
||||
@@ -626,9 +678,12 @@ def expand_args_fields(
|
||||
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Y))
|
||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||
def create_x(self):
|
||||
self.x = registry.get(X, self.x_class_type)(
|
||||
**self.getattr(f"x_{self.x_class_type}_args)
|
||||
)
|
||||
args = self.getattr(f"x_{self.x_class_type}_args")
|
||||
self.create_x_impl(self.x_class_type, args)
|
||||
def create_x_impl(self, x_type, args):
|
||||
x_type = registry.get(X, x_type)
|
||||
expand_args_fields(x_type)
|
||||
self.x = x_type(**args)
|
||||
x_class_type: str = "UNDEFAULTED"
|
||||
|
||||
without adding the optional attributes if they are already there.
|
||||
@@ -648,12 +703,19 @@ def expand_args_fields(
|
||||
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(Z))
|
||||
def create_x(self):
|
||||
if self.x_class_type is None:
|
||||
args = None
|
||||
else:
|
||||
args = self.getattr(f"x_{self.x_class_type}_args", None)
|
||||
self.create_x_impl(self.x_class_type, args)
|
||||
def create_x_impl(self, x_class_type, args):
|
||||
if x_class_type is None:
|
||||
self.x = None
|
||||
return
|
||||
|
||||
self.x = registry.get(X, self.x_class_type)(
|
||||
**self.getattr(f"x_{self.x_class_type}_args)
|
||||
)
|
||||
x_type = registry.get(X, x_class_type)
|
||||
expand_args_fields(x_type)
|
||||
assert args is not None
|
||||
self.x = x_type(**args)
|
||||
x_class_type: Optional[str] = "UNDEFAULTED"
|
||||
|
||||
without adding the optional attributes if they are already there.
|
||||
@@ -670,7 +732,14 @@ def expand_args_fields(
|
||||
|
||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||
def create_x(self):
|
||||
self.x = X(self.x_args)
|
||||
self.create_x_impl(True, self.x_args)
|
||||
|
||||
def create_x_impl(self, enabled, args):
|
||||
if enabled:
|
||||
expand_args_fields(X)
|
||||
self.x = X(**args)
|
||||
else:
|
||||
self.x = None
|
||||
|
||||
Similarly, replace,
|
||||
|
||||
@@ -686,8 +755,12 @@ def expand_args_fields(
|
||||
x_args : DictConfig = dataclasses.field(default_factory=lambda: get_default_args(X))
|
||||
x_enabled: bool = False
|
||||
def create_x(self):
|
||||
if self.x_enabled:
|
||||
self.x = X(self.x_args)
|
||||
self.create_x_impl(self.x_enabled, self.x_args)
|
||||
|
||||
def create_x_impl(self, enabled, args):
|
||||
if enabled:
|
||||
expand_args_fields(X)
|
||||
self.x = X(**args)
|
||||
else:
|
||||
self.x = None
|
||||
|
||||
@@ -695,7 +768,7 @@ def expand_args_fields(
|
||||
Also adds the following class members, unannotated so that dataclass
|
||||
ignores them.
|
||||
- _creation_functions: Tuple[str] of all the create_ functions,
|
||||
including those from base classes.
|
||||
including those from base classes (not the create_x_impl ones).
|
||||
- _known_implementations: Dict[str, Type] containing the classes which
|
||||
have been found from the registry.
|
||||
(used only to raise a warning if it one has been overwritten)
|
||||
@@ -703,6 +776,14 @@ def expand_args_fields(
|
||||
transformed, with values giving the types they were declared to have.
|
||||
(E.g. {"x": X} or {"x": Optional[X]} in the cases above.)
|
||||
|
||||
In addition, if the class has a member function
|
||||
|
||||
@classmethod
|
||||
def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
|
||||
|
||||
then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and
|
||||
the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args).
|
||||
|
||||
Args:
|
||||
some_class: the class to be processed
|
||||
_do_not_process: Internal use for get_default_args: Because get_default_args calls
|
||||
@@ -779,19 +860,29 @@ def expand_args_fields(
|
||||
return some_class
|
||||
|
||||
|
||||
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
|
||||
def get_default_args_field(
|
||||
C,
|
||||
*,
|
||||
_do_not_process: Tuple[type, ...] = (),
|
||||
_hook: Optional[Callable[[DictConfig], None]] = None,
|
||||
):
|
||||
"""
|
||||
Get a dataclass field which defaults to get_default_args(...)
|
||||
|
||||
Args:
|
||||
As for get_default_args.
|
||||
C: As for get_default_args.
|
||||
_do_not_process: As for get_default_args
|
||||
_hook: Function called on the result before returning.
|
||||
|
||||
Returns:
|
||||
function to return new DictConfig object
|
||||
"""
|
||||
|
||||
def create():
|
||||
return get_default_args(C, _do_not_process=_do_not_process)
|
||||
args = get_default_args(C, _do_not_process=_do_not_process)
|
||||
if _hook is not None:
|
||||
_hook(args)
|
||||
return args
|
||||
|
||||
return dataclasses.field(default_factory=create)
|
||||
|
||||
@@ -854,6 +945,7 @@ def _process_member(
|
||||
# sure they go at the end of __annotations__ in case
|
||||
# there are non-defaulted standard class members.
|
||||
del some_class.__annotations__[name]
|
||||
hook = getattr(some_class, name + TWEAK_SUFFIX, None)
|
||||
|
||||
if process_type in (_ProcessType.REPLACEABLE, _ProcessType.OPTIONAL_REPLACEABLE):
|
||||
type_name = name + TYPE_SUFFIX
|
||||
@@ -879,11 +971,17 @@ def _process_member(
|
||||
f"Cannot generate {args_name} because it is already present."
|
||||
)
|
||||
some_class.__annotations__[args_name] = DictConfig
|
||||
if hook is not None:
|
||||
hook_closed = partial(hook, derived_type)
|
||||
else:
|
||||
hook_closed = None
|
||||
setattr(
|
||||
some_class,
|
||||
args_name,
|
||||
get_default_args_field(
|
||||
derived_type, _do_not_process=_do_not_process + (some_class,)
|
||||
derived_type,
|
||||
_do_not_process=_do_not_process + (some_class,),
|
||||
_hook=hook_closed,
|
||||
),
|
||||
)
|
||||
else:
|
||||
@@ -896,12 +994,17 @@ def _process_member(
|
||||
raise ValueError(f"Cannot process {type_} inside {some_class}")
|
||||
|
||||
some_class.__annotations__[args_name] = DictConfig
|
||||
if hook is not None:
|
||||
hook_closed = partial(hook, type_)
|
||||
else:
|
||||
hook_closed = None
|
||||
setattr(
|
||||
some_class,
|
||||
args_name,
|
||||
get_default_args_field(
|
||||
type_,
|
||||
_do_not_process=_do_not_process + (some_class,),
|
||||
_hook=hook_closed,
|
||||
),
|
||||
)
|
||||
if process_type == _ProcessType.OPTIONAL_CONFIGURABLE:
|
||||
@@ -910,7 +1013,7 @@ def _process_member(
|
||||
some_class.__annotations__[enabled_name] = bool
|
||||
setattr(some_class, enabled_name, False)
|
||||
|
||||
creation_function_name = f"create_{name}"
|
||||
creation_function_name = f"{CREATE_PREFIX}{name}"
|
||||
if not hasattr(some_class, creation_function_name):
|
||||
setattr(
|
||||
some_class,
|
||||
@@ -919,6 +1022,14 @@ def _process_member(
|
||||
)
|
||||
creation_functions.append(creation_function_name)
|
||||
|
||||
creation_function_impl_name = f"{CREATE_PREFIX}{name}{IMPL_SUFFIX}"
|
||||
if not hasattr(some_class, creation_function_impl_name):
|
||||
setattr(
|
||||
some_class,
|
||||
creation_function_impl_name,
|
||||
_default_create_impl(name, type_, process_type),
|
||||
)
|
||||
|
||||
|
||||
def remove_unused_components(dict_: DictConfig) -> None:
|
||||
"""
|
||||
|
||||
@@ -50,6 +50,7 @@ def cleanup_eval_depth(
|
||||
# the threshold is a sigma-multiple of the standard deviation of the depth
|
||||
mu = wmean(depth.view(ba, -1, 1), mask.view(ba, -1)).view(ba, 1)
|
||||
std = (
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
||||
wmean((depth.view(ba, -1) - mu).view(ba, -1, 1) ** 2, mask.view(ba, -1))
|
||||
.clamp(1e-4)
|
||||
.sqrt()
|
||||
@@ -58,11 +59,10 @@ def cleanup_eval_depth(
|
||||
good_df_thr = std * sigma
|
||||
good_depth = (df <= good_df_thr).float() * pcl_mask
|
||||
|
||||
perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1)
|
||||
# perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1)
|
||||
# print(f'Kept {100.0 * perc_kept.mean():1.3f} % points')
|
||||
|
||||
good_depth_raster = torch.zeros_like(depth).view(ba, -1)
|
||||
# pyre-ignore[16]: scatter_add_
|
||||
good_depth_raster.scatter_add_(1, torch.round(idx_sampled[:, 0]).long(), good_depth)
|
||||
|
||||
good_depth_mask = (good_depth_raster.view(ba, 1, H, W) > 0).float()
|
||||
|
||||
@@ -21,9 +21,9 @@ def generate_eval_video_cameras(
|
||||
trajectory_scale: float = 0.2,
|
||||
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
|
||||
up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
|
||||
focal_length: Optional[torch.FloatTensor] = None,
|
||||
principal_point: Optional[torch.FloatTensor] = None,
|
||||
time: Optional[torch.FloatTensor] = None,
|
||||
focal_length: Optional[torch.Tensor] = None,
|
||||
principal_point: Optional[torch.Tensor] = None,
|
||||
time: Optional[torch.Tensor] = None,
|
||||
infer_up_as_plane_normal: bool = True,
|
||||
traj_offset: Optional[Tuple[float, float, float]] = None,
|
||||
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
|
||||
@@ -200,9 +200,6 @@ def _visdom_plot_scene(
|
||||
|
||||
viz = Visdom()
|
||||
viz.plotlyplot(p, env="cam_traj_dbg", win="cam_trajs")
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
|
||||
|
||||
def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Union
|
||||
from typing import Sequence, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -14,7 +14,7 @@ def mask_background(
|
||||
image_rgb: torch.Tensor,
|
||||
mask_fg: torch.Tensor,
|
||||
dim_color: int = 1,
|
||||
bg_color: Union[torch.Tensor, str, float] = 0.0,
|
||||
bg_color: Union[torch.Tensor, Sequence, str, float] = 0.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Mask the background input image tensor `image_rgb` with `bg_color`.
|
||||
@@ -26,9 +26,11 @@ def mask_background(
|
||||
# obtain the background color tensor
|
||||
if isinstance(bg_color, torch.Tensor):
|
||||
bg_color_t = bg_color.view(1, 3, 1, 1).clone().to(image_rgb)
|
||||
elif isinstance(bg_color, float):
|
||||
elif isinstance(bg_color, (float, tuple, list)):
|
||||
if isinstance(bg_color, float):
|
||||
bg_color = [bg_color] * 3
|
||||
bg_color_t = torch.tensor(
|
||||
[bg_color] * 3, device=image_rgb.device, dtype=image_rgb.dtype
|
||||
bg_color, device=image_rgb.device, dtype=image_rgb.dtype
|
||||
).view(*tgt_view)
|
||||
elif isinstance(bg_color, str):
|
||||
if bg_color == "white":
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user