mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-04 11:16:01 +08:00
Compare commits
1 Commits
v0.7.0
...
classner-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e7c1f026ea |
@@ -159,7 +159,7 @@ jobs:
|
|||||||
binary_macos_wheel:
|
binary_macos_wheel:
|
||||||
<<: *binary_common
|
<<: *binary_common
|
||||||
macos:
|
macos:
|
||||||
xcode: "13.4.1"
|
xcode: "12.0"
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
|
|||||||
@@ -159,7 +159,7 @@ jobs:
|
|||||||
binary_macos_wheel:
|
binary_macos_wheel:
|
||||||
<<: *binary_common
|
<<: *binary_common
|
||||||
macos:
|
macos:
|
||||||
xcode: "13.4.1"
|
xcode: "12.0"
|
||||||
steps:
|
steps:
|
||||||
- checkout
|
- checkout
|
||||||
- run:
|
- run:
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ Key features include:
|
|||||||
- Data structure for storing and manipulating triangle meshes
|
- Data structure for storing and manipulating triangle meshes
|
||||||
- Efficient operations on triangle meshes (projective transformations, graph convolution, sampling, loss functions)
|
- Efficient operations on triangle meshes (projective transformations, graph convolution, sampling, loss functions)
|
||||||
- A differentiable mesh renderer
|
- A differentiable mesh renderer
|
||||||
- Implicitron, see [its README](projects/implicitron_trainer), a framework for new-view synthesis via implicit representations.
|
|
||||||
|
|
||||||
PyTorch3D is designed to integrate smoothly with deep learning methods for predicting and manipulating 3D data.
|
PyTorch3D is designed to integrate smoothly with deep learning methods for predicting and manipulating 3D data.
|
||||||
For this reason, all operators in PyTorch3D:
|
For this reason, all operators in PyTorch3D:
|
||||||
@@ -94,7 +93,6 @@ In alphabetical order:
|
|||||||
|
|
||||||
* Amitav Baruah
|
* Amitav Baruah
|
||||||
* Steve Branson
|
* Steve Branson
|
||||||
* Krzysztof Chalupka
|
|
||||||
* Luya Gao
|
* Luya Gao
|
||||||
* Georgia Gkioxari
|
* Georgia Gkioxari
|
||||||
* Taylor Gordon
|
* Taylor Gordon
|
||||||
|
|||||||
@@ -89,7 +89,7 @@
|
|||||||
"except ModuleNotFoundError:\n",
|
"except ModuleNotFoundError:\n",
|
||||||
" need_pytorch3d=True\n",
|
" need_pytorch3d=True\n",
|
||||||
"if need_pytorch3d:\n",
|
"if need_pytorch3d:\n",
|
||||||
" if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n",
|
" if torch.__version__.startswith(\"1.11.\") and sys.platform.startswith(\"linux\"):\n",
|
||||||
" # We try to install PyTorch3D via a released wheel.\n",
|
" # We try to install PyTorch3D via a released wheel.\n",
|
||||||
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
||||||
" version_str=\"\".join([\n",
|
" version_str=\"\".join([\n",
|
||||||
|
|||||||
@@ -76,7 +76,7 @@
|
|||||||
"except ModuleNotFoundError:\n",
|
"except ModuleNotFoundError:\n",
|
||||||
" need_pytorch3d=True\n",
|
" need_pytorch3d=True\n",
|
||||||
"if need_pytorch3d:\n",
|
"if need_pytorch3d:\n",
|
||||||
" if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n",
|
" if torch.__version__.startswith(\"1.11.\") and sys.platform.startswith(\"linux\"):\n",
|
||||||
" # We try to install PyTorch3D via a released wheel.\n",
|
" # We try to install PyTorch3D via a released wheel.\n",
|
||||||
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
||||||
" version_str=\"\".join([\n",
|
" version_str=\"\".join([\n",
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
# Acknowledgements
|
# Acknowledgements
|
||||||
|
|
||||||
Thank you to Keenan Crane for allowing the cow mesh model to be used freely in the public domain.
|
Thank you to Keenen Crane for allowing the cow mesh model to be used freely in the public domain.
|
||||||
|
|
||||||
###### Source: http://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/
|
###### Source: http://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/
|
||||||
|
|||||||
@@ -51,7 +51,7 @@
|
|||||||
"except ModuleNotFoundError:\n",
|
"except ModuleNotFoundError:\n",
|
||||||
" need_pytorch3d=True\n",
|
" need_pytorch3d=True\n",
|
||||||
"if need_pytorch3d:\n",
|
"if need_pytorch3d:\n",
|
||||||
" if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n",
|
" if torch.__version__.startswith(\"1.11.\") and sys.platform.startswith(\"linux\"):\n",
|
||||||
" # We try to install PyTorch3D via a released wheel.\n",
|
" # We try to install PyTorch3D via a released wheel.\n",
|
||||||
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
||||||
" version_str=\"\".join([\n",
|
" version_str=\"\".join([\n",
|
||||||
|
|||||||
@@ -90,7 +90,7 @@
|
|||||||
"except ModuleNotFoundError:\n",
|
"except ModuleNotFoundError:\n",
|
||||||
" need_pytorch3d=True\n",
|
" need_pytorch3d=True\n",
|
||||||
"if need_pytorch3d:\n",
|
"if need_pytorch3d:\n",
|
||||||
" if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n",
|
" if torch.__version__.startswith(\"1.11.\") and sys.platform.startswith(\"linux\"):\n",
|
||||||
" # We try to install PyTorch3D via a released wheel.\n",
|
" # We try to install PyTorch3D via a released wheel.\n",
|
||||||
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
||||||
" version_str=\"\".join([\n",
|
" version_str=\"\".join([\n",
|
||||||
|
|||||||
@@ -56,7 +56,7 @@
|
|||||||
"except ModuleNotFoundError:\n",
|
"except ModuleNotFoundError:\n",
|
||||||
" need_pytorch3d=True\n",
|
" need_pytorch3d=True\n",
|
||||||
"if need_pytorch3d:\n",
|
"if need_pytorch3d:\n",
|
||||||
" if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n",
|
" if torch.__version__.startswith(\"1.11.\") and sys.platform.startswith(\"linux\"):\n",
|
||||||
" # We try to install PyTorch3D via a released wheel.\n",
|
" # We try to install PyTorch3D via a released wheel.\n",
|
||||||
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
||||||
" version_str=\"\".join([\n",
|
" version_str=\"\".join([\n",
|
||||||
|
|||||||
@@ -68,7 +68,7 @@
|
|||||||
"except ModuleNotFoundError:\n",
|
"except ModuleNotFoundError:\n",
|
||||||
" need_pytorch3d=True\n",
|
" need_pytorch3d=True\n",
|
||||||
"if need_pytorch3d:\n",
|
"if need_pytorch3d:\n",
|
||||||
" if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n",
|
" if torch.__version__.startswith(\"1.11.\") and sys.platform.startswith(\"linux\"):\n",
|
||||||
" # We try to install PyTorch3D via a released wheel.\n",
|
" # We try to install PyTorch3D via a released wheel.\n",
|
||||||
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
||||||
" version_str=\"\".join([\n",
|
" version_str=\"\".join([\n",
|
||||||
|
|||||||
@@ -47,7 +47,7 @@
|
|||||||
"except ModuleNotFoundError:\n",
|
"except ModuleNotFoundError:\n",
|
||||||
" need_pytorch3d=True\n",
|
" need_pytorch3d=True\n",
|
||||||
"if need_pytorch3d:\n",
|
"if need_pytorch3d:\n",
|
||||||
" if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n",
|
" if torch.__version__.startswith(\"1.11.\") and sys.platform.startswith(\"linux\"):\n",
|
||||||
" # We try to install PyTorch3D via a released wheel.\n",
|
" # We try to install PyTorch3D via a released wheel.\n",
|
||||||
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
||||||
" version_str=\"\".join([\n",
|
" version_str=\"\".join([\n",
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,913 +0,0 @@
|
|||||||
{
|
|
||||||
"cells": [
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659619824914,
|
|
||||||
"executionStopTime": 1659619825485,
|
|
||||||
"originalKey": "d38652e8-200a-413c-a36a-f4d349b78a9d",
|
|
||||||
"requestMsgId": "641de8aa-0e42-4446-9304-c160a2d226bf",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"customInput": null,
|
|
||||||
"originalKey": "a48a9dcf-e80f-474b-a0c4-2c9a765b15c5",
|
|
||||||
"showInput": false
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"# A simple model using Implicitron\n",
|
|
||||||
"\n",
|
|
||||||
"In this demo, we use the VolumeRenderer from PyTorch3D as a custom implicit function in Implicitron. We will see\n",
|
|
||||||
"* some of the main objects in Implicitron\n",
|
|
||||||
"* how to plug in a custom part of a model"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"customInput": null,
|
|
||||||
"originalKey": "51337c0e-ad27-4b75-ad6a-737dca5d7b95",
|
|
||||||
"showInput": false
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## 0. Install and import modules\n",
|
|
||||||
"\n",
|
|
||||||
"Ensure `torch` and `torchvision` are installed. If `pytorch3d` is not installed, install it using the following cell:\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659619898147,
|
|
||||||
"executionStopTime": 1659619898274,
|
|
||||||
"originalKey": "76f1ecd4-6b73-4214-81b0-118ef8d86872",
|
|
||||||
"requestMsgId": "deb6a860-6923-4227-abef-d31388b5142d",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import os\n",
|
|
||||||
"import sys\n",
|
|
||||||
"import torch\n",
|
|
||||||
"need_pytorch3d=False\n",
|
|
||||||
"try:\n",
|
|
||||||
" import pytorch3d\n",
|
|
||||||
"except ModuleNotFoundError:\n",
|
|
||||||
" need_pytorch3d=True\n",
|
|
||||||
"if need_pytorch3d:\n",
|
|
||||||
" if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n",
|
|
||||||
" # We try to install PyTorch3D via a released wheel.\n",
|
|
||||||
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
|
||||||
" version_str=\"\".join([\n",
|
|
||||||
" f\"py3{sys.version_info.minor}_cu\",\n",
|
|
||||||
" torch.version.cuda.replace(\".\",\"\"),\n",
|
|
||||||
" f\"_pyt{pyt_version_str}\"\n",
|
|
||||||
" ])\n",
|
|
||||||
" !pip install fvcore iopath\n",
|
|
||||||
" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
|
||||||
" else:\n",
|
|
||||||
" # We try to install PyTorch3D from source.\n",
|
|
||||||
" !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz\n",
|
|
||||||
" !tar xzf 1.10.0.tar.gz\n",
|
|
||||||
" os.environ[\"CUB_HOME\"] = os.getcwd() + \"/cub-1.10.0\"\n",
|
|
||||||
" !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"customInput": null,
|
|
||||||
"originalKey": "2c1020e6-eb4a-4644-9719-9147500d8e4f",
|
|
||||||
"showInput": false
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"Ensure omegaconf and visdom are installed. If not, run this cell. (It should not be necessary to restart the runtime.)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"originalKey": "9e751931-a38d-44c9-9ff1-ac2f7d3a3f99",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"!pip install omegaconf visdom"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659612480556,
|
|
||||||
"executionStopTime": 1659612480644,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "86807e4a-1675-4520-a033-c7af85b233ec",
|
|
||||||
"requestMsgId": "880a7e20-4a90-4b37-a5eb-bccc0b23cac6"
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"import logging\n",
|
|
||||||
"from typing import Tuple\n",
|
|
||||||
"\n",
|
|
||||||
"import matplotlib.animation as animation\n",
|
|
||||||
"import matplotlib.pyplot as plt\n",
|
|
||||||
"import numpy as np\n",
|
|
||||||
"import torch\n",
|
|
||||||
"import tqdm\n",
|
|
||||||
"from IPython.display import HTML\n",
|
|
||||||
"from omegaconf import OmegaConf\n",
|
|
||||||
"from PIL import Image\n",
|
|
||||||
"from pytorch3d.implicitron.dataset.dataset_base import FrameData\n",
|
|
||||||
"from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider\n",
|
|
||||||
"from pytorch3d.implicitron.models.generic_model import GenericModel\n",
|
|
||||||
"from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase\n",
|
|
||||||
"from pytorch3d.implicitron.models.renderer.base import EvaluationMode\n",
|
|
||||||
"from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args, registry, remove_unused_components\n",
|
|
||||||
"from pytorch3d.renderer import RayBundle\n",
|
|
||||||
"from pytorch3d.renderer.implicit.renderer import VolumeSampler\n",
|
|
||||||
"from pytorch3d.structures import Volumes\n",
|
|
||||||
"from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659610929375,
|
|
||||||
"executionStopTime": 1659610929383,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "b2d9f5bd-a9d4-4f78-b21e-92f2658e0fe9",
|
|
||||||
"requestMsgId": "7e43e623-4030-438b-af4e-b96170c9a052",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"output_resolution = 80"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659610930042,
|
|
||||||
"executionStopTime": 1659610930050,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "0b0c2087-4c86-4c57-b0ee-6f48a70a9c78",
|
|
||||||
"requestMsgId": "46883aad-f00b-4fd4-ac17-eec0b2ac272a",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"torch.set_printoptions(sci_mode=False)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"customInput": null,
|
|
||||||
"originalKey": "37809d0d-b02e-42df-85b6-cdd038373653",
|
|
||||||
"showInput": false
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## 1. Load renders of a mesh (the cow mesh) as a dataset\n",
|
|
||||||
"\n",
|
|
||||||
"A dataset's train, val and test parts in Implicitron are represented as a `dataset_map`, and provided by an implementation of `DatasetMapProvider`. \n",
|
|
||||||
"`RenderedMeshDatasetMapProvider` is one which generates a single-scene dataset with only a train component by taking a mesh and rendering it.\n",
|
|
||||||
"We use it with the cow mesh."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659620739780,
|
|
||||||
"executionStopTime": 1659620739914,
|
|
||||||
"originalKey": "cc68cb9c-b8bf-4e9e-bef1-2cfafdf6caa2",
|
|
||||||
"requestMsgId": "398cfcae-5d43-4b6f-9c75-db3d297364d4",
|
|
||||||
"showInput": false
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"If running this notebook using **Google Colab**, run the following cell to fetch the mesh obj and texture files and save them at the path data/cow_mesh.\n",
|
|
||||||
"If running locally, the data is already available at the correct path."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"originalKey": "2c55e002-a885-4169-8fdc-af9078b05968",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"!mkdir -p data/cow_mesh\n",
|
|
||||||
"!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.obj\n",
|
|
||||||
"!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow.mtl\n",
|
|
||||||
"!wget -P data/cow_mesh https://dl.fbaipublicfiles.com/pytorch3d/data/cow_mesh/cow_texture.png"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"customInput": null,
|
|
||||||
"originalKey": "2a976be8-01bf-4a1c-a6e7-61d5d08c3dbd",
|
|
||||||
"showInput": false
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"If we want to instantiate one of Implicitron's configurable objects, such as `RenderedMeshDatasetMapProvider`, without using the OmegaConf initialisation (get_default_args), we need to call `expand_args_fields` on the class first."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659621652237,
|
|
||||||
"executionStopTime": 1659621652903,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "eb77aaec-048c-40bd-bd69-0e66b6ab60b1",
|
|
||||||
"requestMsgId": "09b9975c-ff86-41c9-b4a9-975d23afc562",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"expand_args_fields(RenderedMeshDatasetMapProvider)\n",
|
|
||||||
"cow_provider = RenderedMeshDatasetMapProvider(\n",
|
|
||||||
" data_file=\"data/cow_mesh/cow.obj\",\n",
|
|
||||||
" use_point_light=False,\n",
|
|
||||||
" resolution=output_resolution,\n",
|
|
||||||
")"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659610966145,
|
|
||||||
"executionStopTime": 1659610966255,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "8210e15b-da48-4306-a49a-41c4e7e7d42f",
|
|
||||||
"requestMsgId": "c243edd2-a106-4fba-8471-dfa4f99a2088",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"dataset_map = cow_provider.get_dataset_map()\n",
|
|
||||||
"tr_cameras = [training_frame.camera for training_frame in dataset_map.train]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659610967703,
|
|
||||||
"executionStopTime": 1659610967848,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "458d72ad-d9a7-4f13-b5b7-90d2aec61c16",
|
|
||||||
"requestMsgId": "7f9431f3-8717-4d89-a7fe-1420dd0e00c4",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# The cameras are all in the XZ plane, in a circle about 2.7 from the origin\n",
|
|
||||||
"centers = torch.cat([i.get_camera_center() for i in tr_cameras])\n",
|
|
||||||
"print(centers.min(0).values)\n",
|
|
||||||
"print(centers.max(0).values)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659552920194,
|
|
||||||
"executionStopTime": 1659552923122,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "931e712b-b141-437a-97fb-dc2a07ce3458",
|
|
||||||
"requestMsgId": "931e712b-b141-437a-97fb-dc2a07ce3458",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# visualization of the cameras\n",
|
|
||||||
"plot = plot_scene({\"k\": {i: camera for i, camera in enumerate(tr_cameras)}}, camera_scale=0.25)\n",
|
|
||||||
"plot.layout.scene.aspectmode = \"data\"\n",
|
|
||||||
"plot"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"customInput": null,
|
|
||||||
"originalKey": "afa9c02d-f76b-4f68-83e9-9733c615406b",
|
|
||||||
"showInput": false
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## 2. Custom implicit function 🧊\n",
|
|
||||||
"\n",
|
|
||||||
"At the core of neural rendering methods are functions of spatial coordinates called implicit functions, which are used in some kind of rendering process.\n",
|
|
||||||
"(Often those functions can additionally take other data as well, such as view direction.)\n",
|
|
||||||
"A common rendering process is ray marching over densities and colors provided by an implicit function.\n",
|
|
||||||
"In our case, taking samples from a 3D volume grid is a very simple function of spatial coordinates. \n",
|
|
||||||
"\n",
|
|
||||||
"Here we define our own implicit function, which uses PyTorch3D's existing functionality for sampling from a volume grid.\n",
|
|
||||||
"We do this by subclassing `ImplicitFunctionBase`.\n",
|
|
||||||
"We need to register our subclass with a special decorator.\n",
|
|
||||||
"We use Python's dataclass annotations for configuring the module."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659613575850,
|
|
||||||
"executionStopTime": 1659613575940,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "61b55043-dc52-4de7-992e-e2195edd2123",
|
|
||||||
"requestMsgId": "dfaace3c-098c-4ffe-9240-6a7ae0ff271e",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"@registry.register\n",
|
|
||||||
"class MyVolumes(ImplicitFunctionBase, torch.nn.Module):\n",
|
|
||||||
" grid_resolution: int = 50 # common HWD of volumes, the number of voxels in each direction\n",
|
|
||||||
" extent: float = 1.0 # In world coordinates, the volume occupies [-extent, extent] along each axis\n",
|
|
||||||
"\n",
|
|
||||||
" def __post_init__(self):\n",
|
|
||||||
" # We have to call this explicitly if there are other base classes like Module\n",
|
|
||||||
" super().__init__()\n",
|
|
||||||
"\n",
|
|
||||||
" # We define parameters like other torch.nn.Module objects.\n",
|
|
||||||
" # In this case, both our parameter tensors are trainable; they govern the contents of the volume grid.\n",
|
|
||||||
" density = torch.full((self.grid_resolution, self.grid_resolution, self.grid_resolution), -2.0)\n",
|
|
||||||
" self.density = torch.nn.Parameter(density)\n",
|
|
||||||
" color = torch.full((3, self.grid_resolution, self.grid_resolution, self.grid_resolution), 0.0)\n",
|
|
||||||
" self.color = torch.nn.Parameter(color)\n",
|
|
||||||
" self.density_activation = torch.nn.Softplus()\n",
|
|
||||||
"\n",
|
|
||||||
" def forward(\n",
|
|
||||||
" self,\n",
|
|
||||||
" ray_bundle: RayBundle,\n",
|
|
||||||
" fun_viewpool=None,\n",
|
|
||||||
" global_code=None,\n",
|
|
||||||
" ):\n",
|
|
||||||
" densities = self.density_activation(self.density[None, None])\n",
|
|
||||||
" voxel_size = 2.0 * float(self.extent) / self.grid_resolution\n",
|
|
||||||
" features = self.color.sigmoid()[None]\n",
|
|
||||||
"\n",
|
|
||||||
" # Like other PyTorch3D structures, the actual Volumes object should only exist as long\n",
|
|
||||||
" # as one iteration of training. It is local to this function.\n",
|
|
||||||
"\n",
|
|
||||||
" volume = Volumes(densities=densities, features=features, voxel_size=voxel_size)\n",
|
|
||||||
" sampler = VolumeSampler(volumes=volume)\n",
|
|
||||||
" densities, features = sampler(ray_bundle)\n",
|
|
||||||
"\n",
|
|
||||||
" # When an implicit function is used for raymarching, i.e. for MultiPassEmissionAbsorptionRenderer,\n",
|
|
||||||
" # it must return (densities, features, an auxiliary dict)\n",
|
|
||||||
" return densities, features, {}\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"customInput": null,
|
|
||||||
"originalKey": "abaf2cd6-1b68-400e-a142-8fb9f49953f3",
|
|
||||||
"showInput": false
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## 3. Construct the model object.\n",
|
|
||||||
"\n",
|
|
||||||
"The main model object in PyTorch3D is `GenericModel`, which has pluggable components for the major steps, including the renderer and the implicit function(s).\n",
|
|
||||||
"There are two ways to construct it which are equivalent here."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659621267561,
|
|
||||||
"executionStopTime": 1659621267938,
|
|
||||||
"originalKey": "f26c3dce-fbae-4592-bd0e-e4a8abc57c2c",
|
|
||||||
"requestMsgId": "9213687e-1caf-46a8-a4e5-a9c531530092",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"CONSTRUCT_MODEL_FROM_CONFIG = True\n",
|
|
||||||
"if CONSTRUCT_MODEL_FROM_CONFIG:\n",
|
|
||||||
" # Via a DictConfig - this is how our training loop with hydra works\n",
|
|
||||||
" cfg = get_default_args(GenericModel)\n",
|
|
||||||
" cfg.implicit_function_class_type = \"MyVolumes\"\n",
|
|
||||||
" cfg.render_image_height=output_resolution\n",
|
|
||||||
" cfg.render_image_width=output_resolution\n",
|
|
||||||
" cfg.loss_weights={\"loss_rgb_huber\": 1.0}\n",
|
|
||||||
" cfg.tqdm_trigger_threshold=19000\n",
|
|
||||||
" cfg.raysampler_AdaptiveRaySampler_args.scene_extent= 4.0\n",
|
|
||||||
" gm = GenericModel(**cfg)\n",
|
|
||||||
"else:\n",
|
|
||||||
" # constructing GenericModel directly\n",
|
|
||||||
" expand_args_fields(GenericModel)\n",
|
|
||||||
" gm = GenericModel(\n",
|
|
||||||
" implicit_function_class_type=\"MyVolumes\",\n",
|
|
||||||
" render_image_height=output_resolution,\n",
|
|
||||||
" render_image_width=output_resolution,\n",
|
|
||||||
" loss_weights={\"loss_rgb_huber\": 1.0},\n",
|
|
||||||
" tqdm_trigger_threshold=19000,\n",
|
|
||||||
" raysampler_AdaptiveRaySampler_args = {\"scene_extent\": 4.0}\n",
|
|
||||||
" )\n",
|
|
||||||
"\n",
|
|
||||||
" # In this case we can get the equivalent DictConfig cfg object to the way gm is configured as follows\n",
|
|
||||||
" cfg = OmegaConf.structured(gm)\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659611214689,
|
|
||||||
"executionStopTime": 1659611214748,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "4e659f7d-ce66-4999-83de-005eb09d7705",
|
|
||||||
"requestMsgId": "7b815b2b-cf19-44d0-ae89-76fde6df35ec",
|
|
||||||
"showInput": false
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
" The default renderer is an emission-absorption raymarcher. We keep that default."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659621268007,
|
|
||||||
"executionStopTime": 1659621268190,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "d37ae488-c57c-44d3-9def-825dc1a6495b",
|
|
||||||
"requestMsgId": "71143ec1-730f-4876-8a14-e46eea9d6dd1",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# We can display the configuration in use as follows.\n",
|
|
||||||
"remove_unused_components(cfg)\n",
|
|
||||||
"yaml = OmegaConf.to_yaml(cfg, sort_keys=False)\n",
|
|
||||||
"%page -r yaml"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659621268727,
|
|
||||||
"executionStopTime": 1659621268776,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "52e53179-3c6e-4c1f-a38a-3a6d803687bb",
|
|
||||||
"requestMsgId": "05de9bc3-3f74-4a6f-851c-9ec919b59506",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"device = torch.device(\"cuda:0\")\n",
|
|
||||||
"gm.to(device)\n",
|
|
||||||
"assert next(gm.parameters()).is_cuda"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"customInput": null,
|
|
||||||
"originalKey": "528a7d53-c645-49c2-9021-09adbb18cd23",
|
|
||||||
"showInput": false
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## 4. Train the model "
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659621270236,
|
|
||||||
"executionStopTime": 1659621270446,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "953280bd-3161-42ba-8dcb-0c8ef2d5cc25",
|
|
||||||
"requestMsgId": "9bba424b-7bfd-4e5a-9d79-ae316e20bab0",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"train_data_collated = [FrameData.collate([frame.to(device)]) for frame in dataset_map.train]"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659621270815,
|
|
||||||
"executionStopTime": 1659621270948,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "2fcf07f0-0c28-49c7-8c76-1c9a9d810167",
|
|
||||||
"requestMsgId": "821deb43-6084-4ece-83c3-dee214562c47",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"gm.train()\n",
|
|
||||||
"optimizer = torch.optim.Adam(gm.parameters(), lr=0.1)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659621271875,
|
|
||||||
"executionStopTime": 1659621298146,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "105099f7-ed0c-4e7f-a976-61a93fd0a8fe",
|
|
||||||
"requestMsgId": "0c87c108-83e3-4129-ad02-85e0140f1368",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"iterator = tqdm.tqdm(range(2000))\n",
|
|
||||||
"for n_batch in iterator:\n",
|
|
||||||
" optimizer.zero_grad()\n",
|
|
||||||
"\n",
|
|
||||||
" frame = train_data_collated[n_batch % len(dataset_map.train)]\n",
|
|
||||||
" out = gm(**frame, evaluation_mode=EvaluationMode.TRAINING)\n",
|
|
||||||
" out[\"objective\"].backward()\n",
|
|
||||||
" if n_batch % 100 == 0:\n",
|
|
||||||
" iterator.set_postfix_str(f\"loss: {float(out['objective']):.5f}\")\n",
|
|
||||||
" optimizer.step()"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659535024768,
|
|
||||||
"executionStopTime": 1659535024906,
|
|
||||||
"originalKey": "e3cd494a-536b-48bc-8290-c048118c82eb",
|
|
||||||
"requestMsgId": "e3cd494a-536b-48bc-8290-c048118c82eb",
|
|
||||||
"showInput": false
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"## 5. Evaluate the module\n",
|
|
||||||
"\n",
|
|
||||||
"We generate complete images from all the viewpoints to see how they look."
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659621299859,
|
|
||||||
"executionStopTime": 1659621311133,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "fbe1b2ea-cc24-4b20-a2d7-0249185e34a5",
|
|
||||||
"requestMsgId": "771ef1f8-5eee-4932-9e81-33604bf0512a",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"def to_numpy_image(image):\n",
|
|
||||||
" # Takes an image of shape (C, H, W) in [0,1], where C=3 or 1\n",
|
|
||||||
" # to a numpy uint image of shape (H, W, 3)\n",
|
|
||||||
" return (image * 255).to(torch.uint8).permute(1, 2, 0).detach().cpu().expand(-1, -1, 3).numpy()\n",
|
|
||||||
"def resize_image(image):\n",
|
|
||||||
" # Takes images of shape (B, C, H, W) to (B, C, output_resolution, output_resolution)\n",
|
|
||||||
" return torch.nn.functional.interpolate(image, size=(output_resolution, output_resolution))\n",
|
|
||||||
"\n",
|
|
||||||
"gm.eval()\n",
|
|
||||||
"images = []\n",
|
|
||||||
"expected = []\n",
|
|
||||||
"masks = []\n",
|
|
||||||
"masks_expected = []\n",
|
|
||||||
"for frame in tqdm.tqdm(train_data_collated):\n",
|
|
||||||
" with torch.no_grad():\n",
|
|
||||||
" out = gm(**frame, evaluation_mode=EvaluationMode.EVALUATION)\n",
|
|
||||||
"\n",
|
|
||||||
" image_rgb = to_numpy_image(out[\"images_render\"][0])\n",
|
|
||||||
" mask = to_numpy_image(out[\"masks_render\"][0])\n",
|
|
||||||
" expd = to_numpy_image(resize_image(frame.image_rgb)[0])\n",
|
|
||||||
" mask_expected = to_numpy_image(resize_image(frame.fg_probability)[0])\n",
|
|
||||||
"\n",
|
|
||||||
" images.append(image_rgb)\n",
|
|
||||||
" masks.append(mask)\n",
|
|
||||||
" expected.append(expd)\n",
|
|
||||||
" masks_expected.append(mask_expected)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"attachments": {},
|
|
||||||
"cell_type": "markdown",
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659614622542,
|
|
||||||
"executionStopTime": 1659614622757,
|
|
||||||
"originalKey": "24953039-9780-40fd-bd81-5d63e9f40069",
|
|
||||||
"requestMsgId": "7af895a3-dfe4-4c28-ac3b-4ff0fbb40c7f",
|
|
||||||
"showInput": false
|
|
||||||
},
|
|
||||||
"source": [
|
|
||||||
"We draw a grid showing predicted image and expected image, followed by predicted mask and expected mask, from each viewpoint. \n",
|
|
||||||
"This is a grid of four rows of images, wrapped in to several large rows, i.e..\n",
|
|
||||||
"<small><center>\n",
|
|
||||||
"```\n",
|
|
||||||
"┌────────┬────────┐ ┌────────┐\n",
|
|
||||||
"│pred │pred │ │pred │\n",
|
|
||||||
"│image │image │ │image │\n",
|
|
||||||
"│1 │2 │ │n │\n",
|
|
||||||
"├────────┼────────┤ ├────────┤\n",
|
|
||||||
"│expected│expected│ │expected│\n",
|
|
||||||
"│image │image │ ... │image │\n",
|
|
||||||
"│1 │2 │ │n │\n",
|
|
||||||
"├────────┼────────┤ ├────────┤\n",
|
|
||||||
"│pred │pred │ │pred │\n",
|
|
||||||
"│mask │mask │ │mask │\n",
|
|
||||||
"│1 │2 │ │n │\n",
|
|
||||||
"├────────┼────────┤ ├────────┤\n",
|
|
||||||
"│expected│expected│ │expected│\n",
|
|
||||||
"│mask │mask │ │mask │\n",
|
|
||||||
"│1 │2 │ │n │\n",
|
|
||||||
"├────────┼────────┤ ├────────┤\n",
|
|
||||||
"│pred │pred │ │pred │\n",
|
|
||||||
"│image │image │ │image │\n",
|
|
||||||
"│n+1 │n+1 │ │2n │\n",
|
|
||||||
"├────────┼────────┤ ├────────┤\n",
|
|
||||||
"│expected│expected│ │expected│\n",
|
|
||||||
"│image │image │ ... │image │\n",
|
|
||||||
"│n+1 │n+2 │ │2n │\n",
|
|
||||||
"├────────┼────────┤ ├────────┤\n",
|
|
||||||
"│pred │pred │ │pred │\n",
|
|
||||||
"│mask │mask │ │mask │\n",
|
|
||||||
"│n+1 │n+2 │ │2n │\n",
|
|
||||||
"├────────┼────────┤ ├────────┤\n",
|
|
||||||
"│expected│expected│ │expected│\n",
|
|
||||||
"│mask │mask │ │mask │\n",
|
|
||||||
"│n+1 │n+2 │ │2n │\n",
|
|
||||||
"└────────┴────────┘ └────────┘\n",
|
|
||||||
" ...\n",
|
|
||||||
"```\n",
|
|
||||||
"</center></small>"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659621313894,
|
|
||||||
"executionStopTime": 1659621314042,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "c488a34a-e46d-4649-93fb-4b1bb5a0e439",
|
|
||||||
"requestMsgId": "4221e632-fca1-4fe5-b2e3-f92c37aa40e4",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"images_to_display = [images.copy(), expected.copy(), masks.copy(), masks_expected.copy()]\n",
|
|
||||||
"n_rows = 4\n",
|
|
||||||
"n_images = len(images)\n",
|
|
||||||
"blank_image = images[0] * 0\n",
|
|
||||||
"n_per_row = 1+(n_images-1)//n_rows\n",
|
|
||||||
"for _ in range(n_per_row*n_rows - n_images):\n",
|
|
||||||
" for group in images_to_display:\n",
|
|
||||||
" group.append(blank_image)\n",
|
|
||||||
"\n",
|
|
||||||
"images_to_display_listed = [[[i] for i in j] for j in images_to_display]\n",
|
|
||||||
"split = []\n",
|
|
||||||
"for row in range(n_rows):\n",
|
|
||||||
" for group in images_to_display_listed:\n",
|
|
||||||
" split.append(group[row*n_per_row:(row+1)*n_per_row]) \n",
|
|
||||||
"\n",
|
|
||||||
"Image.fromarray(np.block(split))\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659621323795,
|
|
||||||
"executionStopTime": 1659621323820,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "49eab9e1-4fe2-4fbe-b4f3-7b6953340170",
|
|
||||||
"requestMsgId": "85b402ad-f903-431f-a13e-c2d697e869bb",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# Print the maximum channel intensity in the first image.\n",
|
|
||||||
"print(images[1].max()/255)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"code_folding": [],
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659621408642,
|
|
||||||
"executionStopTime": 1659621409559,
|
|
||||||
"hidden_ranges": [],
|
|
||||||
"originalKey": "137d2c43-d39d-4266-ac5e-2b714da5e0ee",
|
|
||||||
"requestMsgId": "8e27ec57-c2d6-4ae0-be69-b63b6af929ff",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"plt.ioff()\n",
|
|
||||||
"fig, ax = plt.subplots(figsize=(3,3))\n",
|
|
||||||
"\n",
|
|
||||||
"ax.grid(None)\n",
|
|
||||||
"ims = [[ax.imshow(im, animated=True)] for im in images]\n",
|
|
||||||
"ani = animation.ArtistAnimation(fig, ims, interval=80, blit=True)\n",
|
|
||||||
"ani_html = ani.to_jshtml()\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659621409620,
|
|
||||||
"executionStopTime": 1659621409725,
|
|
||||||
"originalKey": "783e70d6-7cf1-4d76-a126-ba11ffc2f5be",
|
|
||||||
"requestMsgId": "b6843506-c5fa-4508-80fc-8ecae51a934a",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"HTML(ani_html)"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"customInput": null,
|
|
||||||
"customOutput": null,
|
|
||||||
"executionStartTime": 1659614670081,
|
|
||||||
"executionStopTime": 1659614670168,
|
|
||||||
"originalKey": "0286c350-2362-4f47-8181-2fc2ba51cfcf",
|
|
||||||
"requestMsgId": "976f4db9-d4c7-466c-bcfd-218234400226",
|
|
||||||
"showInput": true
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"# If you want to see the output of the model with the volume forced to opaque white, run this and re-evaluate\n",
|
|
||||||
"# with torch.no_grad():\n",
|
|
||||||
"# gm._implicit_functions[0]._fn.density.fill_(9.0)\n",
|
|
||||||
"# gm._implicit_functions[0]._fn.color.fill_(9.0)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"metadata": {
|
|
||||||
"bento_stylesheets": {
|
|
||||||
"bento/extensions/flow/main.css": true,
|
|
||||||
"bento/extensions/kernel_selector/main.css": true,
|
|
||||||
"bento/extensions/kernel_ui/main.css": true,
|
|
||||||
"bento/extensions/new_kernel/main.css": true,
|
|
||||||
"bento/extensions/system_usage/main.css": true,
|
|
||||||
"bento/extensions/theme/main.css": true
|
|
||||||
},
|
|
||||||
"captumWidgetMessage": {},
|
|
||||||
"dataExplorerConfig": {},
|
|
||||||
"kernelspec": {
|
|
||||||
"display_name": "pytorch3d",
|
|
||||||
"language": "python",
|
|
||||||
"metadata": {
|
|
||||||
"cinder_runtime": false,
|
|
||||||
"fbpkg_supported": true,
|
|
||||||
"is_prebuilt": true,
|
|
||||||
"kernel_name": "bento_kernel_pytorch3d",
|
|
||||||
"nightly_builds": true
|
|
||||||
},
|
|
||||||
"name": "bento_kernel_pytorch3d"
|
|
||||||
},
|
|
||||||
"language_info": {
|
|
||||||
"codemirror_mode": {
|
|
||||||
"name": "ipython",
|
|
||||||
"version": 3
|
|
||||||
},
|
|
||||||
"file_extension": ".py",
|
|
||||||
"mimetype": "text/x-python",
|
|
||||||
"name": "python",
|
|
||||||
"nbconvert_exporter": "python",
|
|
||||||
"pygments_lexer": "ipython3"
|
|
||||||
},
|
|
||||||
"last_base_url": "https://9177.od.fbinfra.net:443/",
|
|
||||||
"last_kernel_id": "bb33cd83-7924-489a-8bd8-2d9d62eb0126",
|
|
||||||
"last_msg_id": "99f7088e-d22b355b859660479ef0574e_5743",
|
|
||||||
"last_server_session_id": "2944b203-9ea8-4c0e-9634-645dfea5f26b",
|
|
||||||
"outputWidgetContext": {}
|
|
||||||
},
|
|
||||||
"nbformat": 4,
|
|
||||||
"nbformat_minor": 2
|
|
||||||
}
|
|
||||||
@@ -50,7 +50,7 @@
|
|||||||
"except ModuleNotFoundError:\n",
|
"except ModuleNotFoundError:\n",
|
||||||
" need_pytorch3d=True\n",
|
" need_pytorch3d=True\n",
|
||||||
"if need_pytorch3d:\n",
|
"if need_pytorch3d:\n",
|
||||||
" if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n",
|
" if torch.__version__.startswith(\"1.11.\") and sys.platform.startswith(\"linux\"):\n",
|
||||||
" # We try to install PyTorch3D via a released wheel.\n",
|
" # We try to install PyTorch3D via a released wheel.\n",
|
||||||
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
||||||
" version_str=\"\".join([\n",
|
" version_str=\"\".join([\n",
|
||||||
|
|||||||
@@ -57,7 +57,7 @@
|
|||||||
"except ModuleNotFoundError:\n",
|
"except ModuleNotFoundError:\n",
|
||||||
" need_pytorch3d=True\n",
|
" need_pytorch3d=True\n",
|
||||||
"if need_pytorch3d:\n",
|
"if need_pytorch3d:\n",
|
||||||
" if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n",
|
" if torch.__version__.startswith(\"1.11.\") and sys.platform.startswith(\"linux\"):\n",
|
||||||
" # We try to install PyTorch3D via a released wheel.\n",
|
" # We try to install PyTorch3D via a released wheel.\n",
|
||||||
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
||||||
" version_str=\"\".join([\n",
|
" version_str=\"\".join([\n",
|
||||||
|
|||||||
@@ -73,7 +73,7 @@
|
|||||||
"except ModuleNotFoundError:\n",
|
"except ModuleNotFoundError:\n",
|
||||||
" need_pytorch3d=True\n",
|
" need_pytorch3d=True\n",
|
||||||
"if need_pytorch3d:\n",
|
"if need_pytorch3d:\n",
|
||||||
" if torch.__version__.startswith(\"1.12.\") and sys.platform.startswith(\"linux\"):\n",
|
" if torch.__version__.startswith(\"1.11.\") and sys.platform.startswith(\"linux\"):\n",
|
||||||
" # We try to install PyTorch3D via a released wheel.\n",
|
" # We try to install PyTorch3D via a released wheel.\n",
|
||||||
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
||||||
" version_str=\"\".join([\n",
|
" version_str=\"\".join([\n",
|
||||||
|
|||||||
@@ -44,8 +44,6 @@ def generate_cow_renders(
|
|||||||
data_dir: The folder that contains the cow mesh files. If the cow mesh
|
data_dir: The folder that contains the cow mesh files. If the cow mesh
|
||||||
files do not exist in the folder, this function will automatically
|
files do not exist in the folder, this function will automatically
|
||||||
download them.
|
download them.
|
||||||
azimuth_range: number of degrees on each side of the start position to
|
|
||||||
take samples
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the
|
cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ test:
|
|||||||
- imageio
|
- imageio
|
||||||
- hydra-core
|
- hydra-core
|
||||||
- accelerate
|
- accelerate
|
||||||
|
- lpips
|
||||||
commands:
|
commands:
|
||||||
#pytest .
|
#pytest .
|
||||||
python -m unittest discover -v -s tests -t .
|
python -m unittest discover -v -s tests -t .
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ See [Running](#running) section below for examples of training and evaluation co
|
|||||||
|
|
||||||
To plug in custom implementations, for example, of renderer or implicit-function protocols, you need to create your own runner script and import the plug-in implementations there.
|
To plug in custom implementations, for example, of renderer or implicit-function protocols, you need to create your own runner script and import the plug-in implementations there.
|
||||||
First, install PyTorch3D and Implicitron dependencies as described in the previous section.
|
First, install PyTorch3D and Implicitron dependencies as described in the previous section.
|
||||||
Then, implement the custom script; copying `pytorch3d/projects/implicitron_trainer` is a good place to start.
|
Then, implement the custom script; copying `pytorch3d/projects/implicitron_trainer/experiment.py` is a good place to start.
|
||||||
See [Custom plugins](#custom-plugins) for more information on how to import implementations and enable them in the configs.
|
See [Custom plugins](#custom-plugins) for more information on how to import implementations and enable them in the configs.
|
||||||
|
|
||||||
|
|
||||||
@@ -203,29 +203,14 @@ to replace the implementation and potentially override the parameters.
|
|||||||
|
|
||||||
# Code and config structure
|
# Code and config structure
|
||||||
|
|
||||||
The main object for this trainer loop is `Experiment`. It has four top-level replaceable components.
|
|
||||||
|
|
||||||
* `data_source`: This is a `DataSourceBase` which defaults to `ImplicitronDataSource`.
|
|
||||||
It constructs the data sets and dataloaders.
|
|
||||||
* `model_factory`: This is a `ModelFactoryBase` which defaults to `ImplicitronModelFactory`.
|
|
||||||
It constructs the model, which is usually an instance of implicitron's main `GenericModel` class, and can load its weights from a checkpoint.
|
|
||||||
* `optimizer_factory`: This is an `OptimizerFactoryBase` which defaults to `ImplicitronOptimizerFactory`.
|
|
||||||
It constructs the optimizer and can load its weights from a checkpoint.
|
|
||||||
* `training_loop`: This is a `TrainingLoopBase` which defaults to `ImplicitronTrainingLoop` and defines the main training loop.
|
|
||||||
|
|
||||||
As per above, the config structure is parsed automatically from the module hierarchy.
|
As per above, the config structure is parsed automatically from the module hierarchy.
|
||||||
In particular, for ImplicitronModelFactory with generic model, model parameters are contained in the `model_factory_ImplicitronModelFactory_args.model_GenericModel_args` node, and dataset parameters in `data_source_ImplicitronDataSource_args` node.
|
In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `data_source_args` node.
|
||||||
|
|
||||||
Here is the class structure of GenericModel (single-line edges show aggregation, while double lines show available implementations):
|
Here is the class structure (single-line edges show aggregation, while double lines show available implementations):
|
||||||
```
|
```
|
||||||
model_GenericModel_args: GenericModel
|
generic_model_args: GenericModel
|
||||||
└-- global_encoder_*_args: GlobalEncoderBase
|
└-- sequence_autodecoder_args: Autodecoder
|
||||||
╘== SequenceAutodecoder
|
└-- raysampler_args: RaySampler
|
||||||
└-- autodecoder_args: Autodecoder
|
|
||||||
╘== HarmonicTimeEncoder
|
|
||||||
└-- raysampler_*_args: RaySampler
|
|
||||||
╘== AdaptiveRaysampler
|
|
||||||
╘== NearFarRaysampler
|
|
||||||
└-- renderer_*_args: BaseRenderer
|
└-- renderer_*_args: BaseRenderer
|
||||||
╘== MultiPassEmissionAbsorptionRenderer
|
╘== MultiPassEmissionAbsorptionRenderer
|
||||||
╘== LSTMRenderer
|
╘== LSTMRenderer
|
||||||
@@ -243,16 +228,19 @@ model_GenericModel_args: GenericModel
|
|||||||
╘== IdrFeatureField
|
╘== IdrFeatureField
|
||||||
└-- image_feature_extractor_*_args: FeatureExtractorBase
|
└-- image_feature_extractor_*_args: FeatureExtractorBase
|
||||||
╘== ResNetFeatureExtractor
|
╘== ResNetFeatureExtractor
|
||||||
└-- view_pooler_args: ViewPooler
|
└-- view_sampler_args: ViewSampler
|
||||||
└-- view_sampler_args: ViewSampler
|
└-- feature_aggregator_*_args: FeatureAggregatorBase
|
||||||
└-- feature_aggregator_*_args: FeatureAggregatorBase
|
|
||||||
╘== IdentityFeatureAggregator
|
╘== IdentityFeatureAggregator
|
||||||
╘== AngleWeightedIdentityFeatureAggregator
|
╘== AngleWeightedIdentityFeatureAggregator
|
||||||
╘== AngleWeightedReductionFeatureAggregator
|
╘== AngleWeightedReductionFeatureAggregator
|
||||||
╘== ReductionFeatureAggregator
|
╘== ReductionFeatureAggregator
|
||||||
|
solver_args: init_optimizer
|
||||||
|
data_source_args: ImplicitronDataSource
|
||||||
|
└-- dataset_map_provider_*_args
|
||||||
|
└-- data_loader_map_provider_*_args
|
||||||
```
|
```
|
||||||
|
|
||||||
Please look at the annotations of the respective classes or functions for the lists of hyperparameters. `tests/experiment.yaml` shows every possible option if you have no user-defined classes.
|
Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
|
||||||
|
|
||||||
# Reproducing CO3D experiments
|
# Reproducing CO3D experiments
|
||||||
|
|
||||||
|
|||||||
@@ -2,11 +2,10 @@ defaults:
|
|||||||
- default_config
|
- default_config
|
||||||
- _self_
|
- _self_
|
||||||
exp_dir: ./data/exps/base/
|
exp_dir: ./data/exps/base/
|
||||||
training_loop_ImplicitronTrainingLoop_args:
|
architecture: generic
|
||||||
visdom_port: 8097
|
visualize_interval: 0
|
||||||
visualize_interval: 0
|
visdom_port: 8097
|
||||||
max_epochs: 1000
|
data_source_args:
|
||||||
data_source_ImplicitronDataSource_args:
|
|
||||||
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
|
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
|
||||||
dataset_map_provider_class_type: JsonIndexDatasetMapProvider
|
dataset_map_provider_class_type: JsonIndexDatasetMapProvider
|
||||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||||
@@ -22,8 +21,7 @@ data_source_ImplicitronDataSource_args:
|
|||||||
load_point_clouds: false
|
load_point_clouds: false
|
||||||
mask_depths: false
|
mask_depths: false
|
||||||
mask_images: false
|
mask_images: false
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
loss_weights:
|
loss_weights:
|
||||||
loss_mask_bce: 1.0
|
loss_mask_bce: 1.0
|
||||||
loss_prev_stage_mask_bce: 1.0
|
loss_prev_stage_mask_bce: 1.0
|
||||||
@@ -43,6 +41,7 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
n_layers_xyz: 8
|
n_layers_xyz: 8
|
||||||
append_xyz:
|
append_xyz:
|
||||||
- 5
|
- 5
|
||||||
|
latent_dim: 0
|
||||||
raysampler_AdaptiveRaySampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 1024
|
n_rays_per_image_sampled_from_mask: 1024
|
||||||
scene_extent: 8.0
|
scene_extent: 8.0
|
||||||
@@ -67,14 +66,10 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
proj_dim: 16
|
proj_dim: 16
|
||||||
image_rescale: 0.32
|
image_rescale: 0.32
|
||||||
first_max_pool: false
|
first_max_pool: false
|
||||||
optimizer_factory_ImplicitronOptimizerFactory_args:
|
solver_args:
|
||||||
breed: Adam
|
breed: adam
|
||||||
weight_decay: 0.0
|
|
||||||
lr_policy: MultiStepLR
|
|
||||||
multistep_lr_milestones: []
|
|
||||||
lr: 0.0005
|
lr: 0.0005
|
||||||
gamma: 0.1
|
lr_policy: multistep
|
||||||
|
max_epochs: 2000
|
||||||
momentum: 0.9
|
momentum: 0.9
|
||||||
betas:
|
weight_decay: 0.0
|
||||||
- 0.9
|
|
||||||
- 0.999
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
add_images: true
|
add_images: true
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
add_images: true
|
add_images: true
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
stages:
|
stages:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_base.yaml
|
- repro_base.yaml
|
||||||
- _self_
|
- _self_
|
||||||
data_source_ImplicitronDataSource_args:
|
data_source_args:
|
||||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||||
batch_size: 10
|
batch_size: 10
|
||||||
dataset_length_train: 1000
|
dataset_length_train: 1000
|
||||||
@@ -26,13 +26,10 @@ data_source_ImplicitronDataSource_args:
|
|||||||
n_frames_per_sequence: -1
|
n_frames_per_sequence: -1
|
||||||
test_on_train: true
|
test_on_train: true
|
||||||
test_restrict_sequence_id: 0
|
test_restrict_sequence_id: 0
|
||||||
optimizer_factory_ImplicitronOptimizerFactory_args:
|
solver_args:
|
||||||
multistep_lr_milestones:
|
|
||||||
- 1000
|
|
||||||
training_loop_ImplicitronTrainingLoop_args:
|
|
||||||
max_epochs: 3000
|
max_epochs: 3000
|
||||||
evaluator_ImplicitronEvaluator_args:
|
milestones:
|
||||||
camera_difficulty_bin_breaks:
|
- 1000
|
||||||
|
camera_difficulty_bin_breaks:
|
||||||
- 0.666667
|
- 0.666667
|
||||||
- 0.833334
|
- 0.833334
|
||||||
is_multisequence: true
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_multiseq_base.yaml
|
- repro_multiseq_base.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
loss_weights:
|
loss_weights:
|
||||||
loss_mask_bce: 100.0
|
loss_mask_bce: 100.0
|
||||||
loss_kl: 0.0
|
loss_kl: 0.0
|
||||||
@@ -43,6 +42,7 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
line_step_iters: 3
|
line_step_iters: 3
|
||||||
n_secant_steps: 8
|
n_secant_steps: 8
|
||||||
n_steps: 100
|
n_steps: 100
|
||||||
|
object_bounding_sphere: 8.0
|
||||||
sdf_threshold: 5.0e-05
|
sdf_threshold: 5.0e-05
|
||||||
ray_normal_coloring_network_args:
|
ray_normal_coloring_network_args:
|
||||||
d_in: 9
|
d_in: 9
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_multiseq_base.yaml
|
- repro_multiseq_base.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pooler_enabled: false
|
view_pooler_enabled: false
|
||||||
global_encoder_class_type: SequenceAutodecoder
|
global_encoder_class_type: SequenceAutodecoder
|
||||||
|
|||||||
@@ -2,11 +2,9 @@ defaults:
|
|||||||
- repro_multiseq_base.yaml
|
- repro_multiseq_base.yaml
|
||||||
- repro_feat_extractor_unnormed.yaml
|
- repro_feat_extractor_unnormed.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
clip_grad: 1.0
|
||||||
model_GenericModel_args:
|
generic_model_args:
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pooler_enabled: true
|
view_pooler_enabled: true
|
||||||
raysampler_AdaptiveRaySampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 850
|
n_rays_per_image_sampled_from_mask: 850
|
||||||
training_loop_ImplicitronTrainingLoop_args:
|
|
||||||
clip_grad: 1.0
|
|
||||||
|
|||||||
@@ -2,8 +2,7 @@ defaults:
|
|||||||
- repro_multiseq_base.yaml
|
- repro_multiseq_base.yaml
|
||||||
- repro_feat_extractor_transformer.yaml
|
- repro_feat_extractor_transformer.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
raysampler_AdaptiveRaySampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 800
|
n_rays_per_image_sampled_from_mask: 800
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_multiseq_nerformer.yaml
|
- repro_multiseq_nerformer.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
view_pooler_args:
|
view_pooler_args:
|
||||||
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
|
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_multiseq_base.yaml
|
- repro_multiseq_base.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pooler_enabled: false
|
view_pooler_enabled: false
|
||||||
n_train_target_views: -1
|
n_train_target_views: -1
|
||||||
@@ -30,6 +29,6 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
stratified_point_sampling_evaluation: false
|
stratified_point_sampling_evaluation: false
|
||||||
renderer_class_type: LSTMRenderer
|
renderer_class_type: LSTMRenderer
|
||||||
implicit_function_class_type: SRNHyperNetImplicitFunction
|
implicit_function_class_type: SRNHyperNetImplicitFunction
|
||||||
optimizer_factory_ImplicitronOptimizerFactory_args:
|
solver_args:
|
||||||
breed: Adam
|
breed: adam
|
||||||
lr: 5.0e-05
|
lr: 5.0e-05
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_multiseq_srn_ad_hypernet.yaml
|
- repro_multiseq_srn_ad_hypernet.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
implicit_function_SRNHyperNetImplicitFunction_args:
|
implicit_function_SRNHyperNetImplicitFunction_args:
|
||||||
pixel_generator_args:
|
pixel_generator_args:
|
||||||
|
|||||||
@@ -2,8 +2,7 @@ defaults:
|
|||||||
- repro_multiseq_base.yaml
|
- repro_multiseq_base.yaml
|
||||||
- repro_feat_extractor_normed.yaml
|
- repro_feat_extractor_normed.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
chunk_size_grid: 32000
|
chunk_size_grid: 32000
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
n_train_target_views: -1
|
n_train_target_views: -1
|
||||||
@@ -26,6 +25,6 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
renderer_class_type: LSTMRenderer
|
renderer_class_type: LSTMRenderer
|
||||||
implicit_function_class_type: SRNImplicitFunction
|
implicit_function_class_type: SRNImplicitFunction
|
||||||
view_pooler_enabled: true
|
view_pooler_enabled: true
|
||||||
optimizer_factory_ImplicitronOptimizerFactory_args:
|
solver_args:
|
||||||
breed: Adam
|
breed: adam
|
||||||
lr: 5.0e-05
|
lr: 5.0e-05
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_multiseq_srn_wce.yaml
|
- repro_multiseq_srn_wce.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
implicit_function_SRNImplicitFunction_args:
|
implicit_function_SRNImplicitFunction_args:
|
||||||
pixel_generator_args:
|
pixel_generator_args:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_base
|
- repro_base
|
||||||
- _self_
|
- _self_
|
||||||
data_source_ImplicitronDataSource_args:
|
data_source_args:
|
||||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||||
batch_size: 1
|
batch_size: 1
|
||||||
dataset_length_train: 1000
|
dataset_length_train: 1000
|
||||||
@@ -12,8 +12,7 @@ data_source_ImplicitronDataSource_args:
|
|||||||
n_frames_per_sequence: -1
|
n_frames_per_sequence: -1
|
||||||
test_restrict_sequence_id: 0
|
test_restrict_sequence_id: 0
|
||||||
test_on_train: false
|
test_on_train: false
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
render_image_height: 800
|
render_image_height: 800
|
||||||
render_image_width: 800
|
render_image_width: 800
|
||||||
log_vars:
|
log_vars:
|
||||||
@@ -32,10 +31,9 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
- objective
|
- objective
|
||||||
- epoch
|
- epoch
|
||||||
- sec/it
|
- sec/it
|
||||||
optimizer_factory_ImplicitronOptimizerFactory_args:
|
solver_args:
|
||||||
lr: 0.0005
|
lr: 0.0005
|
||||||
multistep_lr_milestones:
|
max_epochs: 400
|
||||||
|
milestones:
|
||||||
- 200
|
- 200
|
||||||
- 300
|
- 300
|
||||||
training_loop_ImplicitronTrainingLoop_args:
|
|
||||||
max_epochs: 400
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_singleseq_base
|
- repro_singleseq_base
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
loss_weights:
|
loss_weights:
|
||||||
loss_mask_bce: 100.0
|
loss_mask_bce: 100.0
|
||||||
loss_kl: 0.0
|
loss_kl: 0.0
|
||||||
@@ -36,6 +35,7 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
line_step_iters: 3
|
line_step_iters: 3
|
||||||
n_secant_steps: 8
|
n_secant_steps: 8
|
||||||
n_steps: 100
|
n_steps: 100
|
||||||
|
object_bounding_sphere: 8.0
|
||||||
sdf_threshold: 5.0e-05
|
sdf_threshold: 5.0e-05
|
||||||
ray_normal_coloring_network_args:
|
ray_normal_coloring_network_args:
|
||||||
d_in: 9
|
d_in: 9
|
||||||
|
|||||||
@@ -1,48 +0,0 @@
|
|||||||
defaults:
|
|
||||||
- repro_singleseq_base
|
|
||||||
- _self_
|
|
||||||
exp_dir: "./data/nerf_blender_repro/${oc.env:BLENDER_SINGLESEQ_CLASS}"
|
|
||||||
data_source_ImplicitronDataSource_args:
|
|
||||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
|
||||||
dataset_length_train: 100
|
|
||||||
dataset_map_provider_class_type: BlenderDatasetMapProvider
|
|
||||||
dataset_map_provider_BlenderDatasetMapProvider_args:
|
|
||||||
base_dir: ${oc.env:BLENDER_DATASET_ROOT}
|
|
||||||
n_known_frames_for_test: null
|
|
||||||
object_name: ${oc.env:BLENDER_SINGLESEQ_CLASS}
|
|
||||||
path_manager_factory_class_type: PathManagerFactory
|
|
||||||
path_manager_factory_PathManagerFactory_args:
|
|
||||||
silence_logs: true
|
|
||||||
|
|
||||||
model_factory_ImplicitronModelFactory_args:
|
|
||||||
model_GenericModel_args:
|
|
||||||
mask_images: false
|
|
||||||
raysampler_class_type: NearFarRaySampler
|
|
||||||
raysampler_NearFarRaySampler_args:
|
|
||||||
n_rays_per_image_sampled_from_mask: 4096
|
|
||||||
min_depth: 2
|
|
||||||
max_depth: 6
|
|
||||||
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
|
||||||
density_noise_std_train: 1.0
|
|
||||||
n_pts_per_ray_fine_training: 128
|
|
||||||
n_pts_per_ray_fine_evaluation: 128
|
|
||||||
raymarcher_EmissionAbsorptionRaymarcher_args:
|
|
||||||
blend_output: false
|
|
||||||
loss_weights:
|
|
||||||
loss_rgb_mse: 1.0
|
|
||||||
loss_prev_stage_rgb_mse: 1.0
|
|
||||||
loss_mask_bce: 0.0
|
|
||||||
loss_prev_stage_mask_bce: 0.0
|
|
||||||
loss_autodecoder_norm: 0.00
|
|
||||||
|
|
||||||
optimizer_factory_ImplicitronOptimizerFactory_args:
|
|
||||||
exponential_lr_step_size: 3001
|
|
||||||
lr_policy: LinearExponential
|
|
||||||
linear_exponential_lr_milestone: 200
|
|
||||||
|
|
||||||
training_loop_ImplicitronTrainingLoop_args:
|
|
||||||
max_epochs: 3201
|
|
||||||
metric_print_interval: 10
|
|
||||||
store_checkpoints_purge: 3
|
|
||||||
test_when_finished: true
|
|
||||||
validation_interval: 100
|
|
||||||
@@ -2,8 +2,7 @@ defaults:
|
|||||||
- repro_singleseq_wce_base.yaml
|
- repro_singleseq_wce_base.yaml
|
||||||
- repro_feat_extractor_unnormed.yaml
|
- repro_feat_extractor_unnormed.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pooler_enabled: true
|
view_pooler_enabled: true
|
||||||
raysampler_AdaptiveRaySampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
|
|||||||
@@ -2,8 +2,7 @@ defaults:
|
|||||||
- repro_singleseq_wce_base.yaml
|
- repro_singleseq_wce_base.yaml
|
||||||
- repro_feat_extractor_transformer.yaml
|
- repro_feat_extractor_transformer.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
chunk_size_grid: 16000
|
chunk_size_grid: 16000
|
||||||
view_pooler_enabled: true
|
view_pooler_enabled: true
|
||||||
implicit_function_class_type: NeRFormerImplicitFunction
|
implicit_function_class_type: NeRFormerImplicitFunction
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_singleseq_base.yaml
|
- repro_singleseq_base.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
chunk_size_grid: 32000
|
chunk_size_grid: 32000
|
||||||
view_pooler_enabled: false
|
view_pooler_enabled: false
|
||||||
@@ -24,6 +23,6 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
stratified_point_sampling_evaluation: false
|
stratified_point_sampling_evaluation: false
|
||||||
renderer_class_type: LSTMRenderer
|
renderer_class_type: LSTMRenderer
|
||||||
implicit_function_class_type: SRNImplicitFunction
|
implicit_function_class_type: SRNImplicitFunction
|
||||||
optimizer_factory_ImplicitronOptimizerFactory_args:
|
solver_args:
|
||||||
breed: Adam
|
breed: adam
|
||||||
lr: 5.0e-05
|
lr: 5.0e-05
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_singleseq_srn.yaml
|
- repro_singleseq_srn.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
implicit_function_SRNImplicitFunction_args:
|
implicit_function_SRNImplicitFunction_args:
|
||||||
pixel_generator_args:
|
pixel_generator_args:
|
||||||
|
|||||||
@@ -2,8 +2,7 @@ defaults:
|
|||||||
- repro_singleseq_wce_base
|
- repro_singleseq_wce_base
|
||||||
- repro_feat_extractor_normed.yaml
|
- repro_feat_extractor_normed.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
chunk_size_grid: 32000
|
chunk_size_grid: 32000
|
||||||
view_pooler_enabled: true
|
view_pooler_enabled: true
|
||||||
@@ -25,6 +24,6 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
stratified_point_sampling_evaluation: false
|
stratified_point_sampling_evaluation: false
|
||||||
renderer_class_type: LSTMRenderer
|
renderer_class_type: LSTMRenderer
|
||||||
implicit_function_class_type: SRNImplicitFunction
|
implicit_function_class_type: SRNImplicitFunction
|
||||||
optimizer_factory_ImplicitronOptimizerFactory_args:
|
solver_args:
|
||||||
breed: Adam
|
breed: adam
|
||||||
lr: 5.0e-05
|
lr: 5.0e-05
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_singleseq_srn_wce.yaml
|
- repro_singleseq_srn_wce.yaml
|
||||||
- _self_
|
- _self_
|
||||||
model_factory_ImplicitronModelFactory_args:
|
generic_model_args:
|
||||||
model_GenericModel_args:
|
|
||||||
num_passes: 1
|
num_passes: 1
|
||||||
implicit_function_SRNImplicitFunction_args:
|
implicit_function_SRNImplicitFunction_args:
|
||||||
pixel_generator_args:
|
pixel_generator_args:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
defaults:
|
defaults:
|
||||||
- repro_singleseq_base
|
- repro_singleseq_base
|
||||||
- _self_
|
- _self_
|
||||||
data_source_ImplicitronDataSource_args:
|
data_source_args:
|
||||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||||
batch_size: 10
|
batch_size: 10
|
||||||
dataset_length_train: 1000
|
dataset_length_train: 1000
|
||||||
|
|||||||
@@ -8,28 +8,27 @@
|
|||||||
""""
|
""""
|
||||||
This file is the entry point for launching experiments with Implicitron.
|
This file is the entry point for launching experiments with Implicitron.
|
||||||
|
|
||||||
|
Main functions
|
||||||
|
---------------
|
||||||
|
- `run_training` is the wrapper for the train, val, test loops
|
||||||
|
and checkpointing
|
||||||
|
- `trainvalidate` is the inner loop which runs the model forward/backward
|
||||||
|
pass, visualizations and metric printing
|
||||||
|
|
||||||
Launch Training
|
Launch Training
|
||||||
---------------
|
---------------
|
||||||
Experiment config .yaml files are located in the
|
Experiment config .yaml files are located in the
|
||||||
`projects/implicitron_trainer/configs` folder. To launch an experiment,
|
`projects/implicitron_trainer/configs` folder. To launch
|
||||||
specify the name of the file. Specific config values can also be overridden
|
an experiment, specify the name of the file. Specific config values can
|
||||||
from the command line, for example:
|
also be overridden from the command line, for example:
|
||||||
|
|
||||||
```
|
```
|
||||||
./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84
|
./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84
|
||||||
```
|
```
|
||||||
|
|
||||||
To run an experiment on a specific GPU, specify the `gpu_idx` key in the
|
To run an experiment on a specific GPU, specify the `gpu_idx` key
|
||||||
config file / CLI. To run on a different device, specify the device in
|
in the config file / CLI. To run on a different device, specify the
|
||||||
`run_training`.
|
device in `run_training`.
|
||||||
|
|
||||||
Main functions
|
|
||||||
---------------
|
|
||||||
- The Experiment class defines `run` which creates the model, optimizer, and other
|
|
||||||
objects used in training, then starts TrainingLoop's `run` function.
|
|
||||||
- TrainingLoop takes care of the actual training logic: forward and backward passes,
|
|
||||||
evaluation and testing, as well as model checkpointing, visualization, and metric
|
|
||||||
printing.
|
|
||||||
|
|
||||||
Outputs
|
Outputs
|
||||||
--------
|
--------
|
||||||
@@ -46,40 +45,43 @@ The outputs of the experiment are saved and logged in multiple ways:
|
|||||||
config file.
|
config file.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
import warnings
|
import warnings
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
from dataclasses import field
|
|
||||||
|
|
||||||
import hydra
|
import hydra
|
||||||
|
import lpips
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import tqdm
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from omegaconf import DictConfig, OmegaConf
|
from omegaconf import DictConfig, OmegaConf
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||||
from pytorch3d.implicitron.dataset.data_source import (
|
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
|
||||||
DataSourceBase,
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||||
ImplicitronDataSource,
|
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
|
||||||
)
|
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||||
from pytorch3d.implicitron.models.generic_model import ImplicitronModelBase
|
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.renderer.multipass_ea import (
|
from pytorch3d.implicitron.models.renderer.multipass_ea import (
|
||||||
MultiPassEmissionAbsorptionRenderer,
|
MultiPassEmissionAbsorptionRenderer,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.models.renderer.ray_sampler import AdaptiveRaySampler
|
from pytorch3d.implicitron.models.renderer.ray_sampler import AdaptiveRaySampler
|
||||||
|
from pytorch3d.implicitron.tools import model_io, vis_utils
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
Configurable,
|
|
||||||
expand_args_fields,
|
expand_args_fields,
|
||||||
remove_unused_components,
|
remove_unused_components,
|
||||||
run_auto_creation,
|
|
||||||
)
|
)
|
||||||
|
from pytorch3d.implicitron.tools.stats import Stats
|
||||||
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
from .impl.model_factory import ModelFactoryBase
|
from .impl.experiment_config import ExperimentConfig
|
||||||
from .impl.optimizer_factory import OptimizerFactoryBase
|
from .impl.optimization import init_optimizer
|
||||||
from .impl.training_loop import TrainingLoopBase
|
|
||||||
from .impl.utils import seed_all_random_engines
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -98,56 +100,266 @@ except ModuleNotFoundError:
|
|||||||
no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None
|
no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None
|
||||||
|
|
||||||
|
|
||||||
class Experiment(Configurable): # pyre-ignore: 13
|
def init_model(
|
||||||
|
*,
|
||||||
|
cfg: DictConfig,
|
||||||
|
accelerator: Optional[Accelerator] = None,
|
||||||
|
force_load: bool = False,
|
||||||
|
clear_stats: bool = False,
|
||||||
|
load_model_only: bool = False,
|
||||||
|
) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
|
||||||
"""
|
"""
|
||||||
This class is at the top level of Implicitron's config hierarchy. Its
|
Returns an instance of `GenericModel`.
|
||||||
members are high-level components necessary for training an implicit rende-
|
|
||||||
ring network.
|
|
||||||
|
|
||||||
Members:
|
If `cfg.resume` is set or `force_load` is true,
|
||||||
data_source: An object that produces datasets and dataloaders.
|
attempts to load the last checkpoint from `cfg.exp_dir`. Failure to do so
|
||||||
model_factory: An object that produces an implicit rendering model as
|
will return the model with initial weights, unless `force_load` is passed,
|
||||||
well as its corresponding Stats object.
|
in which case a FileNotFoundError is raised.
|
||||||
optimizer_factory: An object that produces the optimizer and lr
|
|
||||||
scheduler.
|
Args:
|
||||||
training_loop: An object that runs training given the outputs produced
|
force_load: If true, force load model from checkpoint even if
|
||||||
by the data_source, model_factory and optimizer_factory.
|
cfg.resume is false.
|
||||||
seed: A random seed to ensure reproducibility.
|
clear_stats: If true, clear the stats object loaded from checkpoint
|
||||||
detect_anomaly: Whether torch.autograd should detect anomalies. Useful
|
load_model_only: If true, load only the model weights from checkpoint
|
||||||
for debugging, but might slow down the training.
|
and do not load the state of the optimizer and stats.
|
||||||
exp_dir: Root experimentation directory. Checkpoints and training stats
|
|
||||||
will be saved here.
|
Returns:
|
||||||
|
model: The model with optionally loaded weights from checkpoint
|
||||||
|
stats: The stats structure (optionally loaded from checkpoint)
|
||||||
|
optimizer_state: The optimizer state dict containing
|
||||||
|
`state` and `param_groups` keys (optionally loaded from checkpoint)
|
||||||
|
|
||||||
|
Raise:
|
||||||
|
FileNotFoundError if `force_load` is passed but checkpoint is not found.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data_source: DataSourceBase
|
# Initialize the model
|
||||||
data_source_class_type: str = "ImplicitronDataSource"
|
if cfg.architecture == "generic":
|
||||||
model_factory: ModelFactoryBase
|
model = GenericModel(**cfg.generic_model_args)
|
||||||
model_factory_class_type: str = "ImplicitronModelFactory"
|
else:
|
||||||
optimizer_factory: OptimizerFactoryBase
|
raise ValueError(f"No such arch {cfg.architecture}.")
|
||||||
optimizer_factory_class_type: str = "ImplicitronOptimizerFactory"
|
|
||||||
training_loop: TrainingLoopBase
|
|
||||||
training_loop_class_type: str = "ImplicitronTrainingLoop"
|
|
||||||
|
|
||||||
seed: int = 42
|
# Determine the network outputs that should be logged
|
||||||
detect_anomaly: bool = False
|
if hasattr(model, "log_vars"):
|
||||||
exp_dir: str = "./data/default_experiment/"
|
log_vars = copy.deepcopy(list(model.log_vars))
|
||||||
|
else:
|
||||||
|
log_vars = ["objective"]
|
||||||
|
|
||||||
hydra: dict = field(
|
visdom_env_charts = vis_utils.get_visdom_env(cfg) + "_charts"
|
||||||
default_factory=lambda: {
|
|
||||||
"run": {"dir": "."}, # Make hydra not change the working dir.
|
# Init the stats struct
|
||||||
"output_subdir": None, # disable storing the .hydra logs
|
stats = Stats(
|
||||||
}
|
log_vars,
|
||||||
|
visdom_env=visdom_env_charts,
|
||||||
|
verbose=False,
|
||||||
|
visdom_server=cfg.visdom_server,
|
||||||
|
visdom_port=cfg.visdom_port,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
# Retrieve the last checkpoint
|
||||||
seed_all_random_engines(
|
if cfg.resume_epoch > 0:
|
||||||
self.seed
|
model_path = model_io.get_checkpoint(cfg.exp_dir, cfg.resume_epoch)
|
||||||
) # Set all random engine seeds for reproducibility
|
else:
|
||||||
|
model_path = model_io.find_last_checkpoint(cfg.exp_dir)
|
||||||
|
|
||||||
run_auto_creation(self)
|
optimizer_state = None
|
||||||
|
if model_path is not None:
|
||||||
|
logger.info("found previous model %s" % model_path)
|
||||||
|
if force_load or cfg.resume:
|
||||||
|
logger.info(" -> resuming")
|
||||||
|
|
||||||
def run(self) -> None:
|
map_location = None
|
||||||
# Initialize the accelerator if desired.
|
if accelerator is not None and not accelerator.is_local_main_process:
|
||||||
|
map_location = {
|
||||||
|
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
||||||
|
}
|
||||||
|
if load_model_only:
|
||||||
|
model_state_dict = torch.load(
|
||||||
|
model_io.get_model_path(model_path), map_location=map_location
|
||||||
|
)
|
||||||
|
stats_load, optimizer_state = None, None
|
||||||
|
else:
|
||||||
|
model_state_dict, stats_load, optimizer_state = model_io.load_model(
|
||||||
|
model_path, map_location=map_location
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine if stats should be reset
|
||||||
|
if not clear_stats:
|
||||||
|
if stats_load is None:
|
||||||
|
logger.info("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
|
||||||
|
last_epoch = model_io.parse_epoch_from_model_path(model_path)
|
||||||
|
logger.info(f"Estimated resume epoch = {last_epoch}")
|
||||||
|
|
||||||
|
# Reset the stats struct
|
||||||
|
for _ in range(last_epoch + 1):
|
||||||
|
stats.new_epoch()
|
||||||
|
assert last_epoch == stats.epoch
|
||||||
|
else:
|
||||||
|
stats = stats_load
|
||||||
|
|
||||||
|
# Update stats properties incase it was reset on load
|
||||||
|
stats.visdom_env = visdom_env_charts
|
||||||
|
stats.visdom_server = cfg.visdom_server
|
||||||
|
stats.visdom_port = cfg.visdom_port
|
||||||
|
stats.plot_file = os.path.join(cfg.exp_dir, "train_stats.pdf")
|
||||||
|
stats.synchronize_logged_vars(log_vars)
|
||||||
|
else:
|
||||||
|
logger.info(" -> clearing stats")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# TODO: fix on creation of the buffers
|
||||||
|
# after the hack above, this will not pass in most cases
|
||||||
|
# ... but this is fine for now
|
||||||
|
model.load_state_dict(model_state_dict, strict=True)
|
||||||
|
except RuntimeError as e:
|
||||||
|
logger.error(e)
|
||||||
|
logger.info("Cant load state dict in strict mode! -> trying non-strict")
|
||||||
|
model.load_state_dict(model_state_dict, strict=False)
|
||||||
|
model.log_vars = log_vars
|
||||||
|
else:
|
||||||
|
logger.info(" -> but not resuming -> starting from scratch")
|
||||||
|
elif force_load:
|
||||||
|
raise FileNotFoundError(f"Cannot find a checkpoint in {cfg.exp_dir}!")
|
||||||
|
|
||||||
|
return model, stats, optimizer_state
|
||||||
|
|
||||||
|
|
||||||
|
def trainvalidate(
|
||||||
|
model,
|
||||||
|
stats,
|
||||||
|
epoch,
|
||||||
|
loader,
|
||||||
|
optimizer,
|
||||||
|
validation: bool,
|
||||||
|
*,
|
||||||
|
accelerator: Optional[Accelerator],
|
||||||
|
device: torch.device,
|
||||||
|
bp_var: str = "objective",
|
||||||
|
metric_print_interval: int = 5,
|
||||||
|
visualize_interval: int = 100,
|
||||||
|
visdom_env_root: str = "trainvalidate",
|
||||||
|
clip_grad: float = 0.0,
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
This is the main loop for training and evaluation including:
|
||||||
|
model forward pass, loss computation, backward pass and visualization.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: The model module optionally loaded from checkpoint
|
||||||
|
stats: The stats struct, also optionally loaded from checkpoint
|
||||||
|
epoch: The index of the current epoch
|
||||||
|
loader: The dataloader to use for the loop
|
||||||
|
optimizer: The optimizer module optionally loaded from checkpoint
|
||||||
|
validation: If true, run the loop with the model in eval mode
|
||||||
|
and skip the backward pass
|
||||||
|
bp_var: The name of the key in the model output `preds` dict which
|
||||||
|
should be used as the loss for the backward pass.
|
||||||
|
metric_print_interval: The batch interval at which the stats should be
|
||||||
|
logged.
|
||||||
|
visualize_interval: The batch interval at which the visualizations
|
||||||
|
should be plotted
|
||||||
|
visdom_env_root: The name of the visdom environment to use for plotting
|
||||||
|
clip_grad: Optionally clip the gradient norms.
|
||||||
|
If set to a value <=0.0, no clipping
|
||||||
|
device: The device on which to run the model.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
"""
|
||||||
|
|
||||||
|
if validation:
|
||||||
|
model.eval()
|
||||||
|
trainmode = "val"
|
||||||
|
else:
|
||||||
|
model.train()
|
||||||
|
trainmode = "train"
|
||||||
|
|
||||||
|
t_start = time.time()
|
||||||
|
|
||||||
|
# get the visdom env name
|
||||||
|
visdom_env_imgs = visdom_env_root + "_images_" + trainmode
|
||||||
|
viz = vis_utils.get_visdom_connection(
|
||||||
|
server=stats.visdom_server,
|
||||||
|
port=stats.visdom_port,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Iterate through the batches
|
||||||
|
n_batches = len(loader)
|
||||||
|
for it, net_input in enumerate(loader):
|
||||||
|
last_iter = it == n_batches - 1
|
||||||
|
|
||||||
|
# move to gpu where possible (in place)
|
||||||
|
net_input = net_input.to(device)
|
||||||
|
|
||||||
|
# run the forward pass
|
||||||
|
if not validation:
|
||||||
|
optimizer.zero_grad()
|
||||||
|
preds = model(**{**net_input, "evaluation_mode": EvaluationMode.TRAINING})
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
preds = model(
|
||||||
|
**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
|
||||||
|
)
|
||||||
|
|
||||||
|
# make sure we dont overwrite something
|
||||||
|
assert all(k not in preds for k in net_input.keys())
|
||||||
|
# merge everything into one big dict
|
||||||
|
preds.update(net_input)
|
||||||
|
|
||||||
|
# update the stats logger
|
||||||
|
stats.update(preds, time_start=t_start, stat_set=trainmode)
|
||||||
|
assert stats.it[trainmode] == it, "inconsistent stat iteration number!"
|
||||||
|
|
||||||
|
# print textual status update
|
||||||
|
if it % metric_print_interval == 0 or last_iter:
|
||||||
|
stats.print(stat_set=trainmode, max_it=n_batches)
|
||||||
|
|
||||||
|
# visualize results
|
||||||
|
if (
|
||||||
|
(accelerator is None or accelerator.is_local_main_process)
|
||||||
|
and visualize_interval > 0
|
||||||
|
and it % visualize_interval == 0
|
||||||
|
):
|
||||||
|
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
|
||||||
|
|
||||||
|
model.visualize(
|
||||||
|
viz,
|
||||||
|
visdom_env_imgs,
|
||||||
|
preds,
|
||||||
|
prefix,
|
||||||
|
)
|
||||||
|
|
||||||
|
# optimizer step
|
||||||
|
if not validation:
|
||||||
|
loss = preds[bp_var]
|
||||||
|
assert torch.isfinite(loss).all(), "Non-finite loss!"
|
||||||
|
# backprop
|
||||||
|
if accelerator is None:
|
||||||
|
loss.backward()
|
||||||
|
else:
|
||||||
|
accelerator.backward(loss)
|
||||||
|
if clip_grad > 0.0:
|
||||||
|
# Optionally clip the gradient norms.
|
||||||
|
total_norm = torch.nn.utils.clip_grad_norm(
|
||||||
|
model.parameters(), clip_grad
|
||||||
|
)
|
||||||
|
if total_norm > clip_grad:
|
||||||
|
logger.info(
|
||||||
|
f"Clipping gradient: {total_norm}"
|
||||||
|
+ f" with coef {clip_grad / float(total_norm)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
|
def run_training(cfg: DictConfig) -> None:
|
||||||
|
"""
|
||||||
|
Entry point to run the training and validation loops
|
||||||
|
based on the specified config file.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Initialize the accelerator
|
||||||
if no_accelerate:
|
if no_accelerate:
|
||||||
accelerator = None
|
accelerator = None
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
@@ -157,48 +369,67 @@ class Experiment(Configurable): # pyre-ignore: 13
|
|||||||
device = accelerator.device
|
device = accelerator.device
|
||||||
|
|
||||||
logger.info(f"Running experiment on device: {device}")
|
logger.info(f"Running experiment on device: {device}")
|
||||||
os.makedirs(self.exp_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# set the debug mode
|
# set the debug mode
|
||||||
if self.detect_anomaly:
|
if cfg.detect_anomaly:
|
||||||
logger.info("Anomaly detection!")
|
logger.info("Anomaly detection!")
|
||||||
torch.autograd.set_detect_anomaly(self.detect_anomaly)
|
torch.autograd.set_detect_anomaly(cfg.detect_anomaly)
|
||||||
|
|
||||||
# Initialize the datasets and dataloaders.
|
# create the output folder
|
||||||
datasets, dataloaders = self.data_source.get_datasets_and_dataloaders()
|
os.makedirs(cfg.exp_dir, exist_ok=True)
|
||||||
|
_seed_all_random_engines(cfg.seed)
|
||||||
|
remove_unused_components(cfg)
|
||||||
|
|
||||||
# Init the model and the corresponding Stats object.
|
# dump the exp config to the exp dir
|
||||||
model = self.model_factory(
|
try:
|
||||||
accelerator=accelerator,
|
cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml")
|
||||||
exp_dir=self.exp_dir,
|
OmegaConf.save(config=cfg, f=cfg_filename)
|
||||||
)
|
except PermissionError:
|
||||||
|
warnings.warn("Cant dump config due to insufficient permissions!")
|
||||||
|
|
||||||
stats = self.training_loop.load_stats(
|
# setup datasets
|
||||||
log_vars=model.log_vars,
|
datasource = ImplicitronDataSource(**cfg.data_source_args)
|
||||||
exp_dir=self.exp_dir,
|
datasets, dataloaders = datasource.get_datasets_and_dataloaders()
|
||||||
resume=self.model_factory.resume,
|
task = datasource.get_task()
|
||||||
resume_epoch=self.model_factory.resume_epoch, # pyre-ignore [16]
|
|
||||||
)
|
# init the model
|
||||||
|
model, stats, optimizer_state = init_model(cfg=cfg, accelerator=accelerator)
|
||||||
start_epoch = stats.epoch + 1
|
start_epoch = stats.epoch + 1
|
||||||
|
|
||||||
|
# move model to gpu
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
# Init the optimizer and LR scheduler.
|
# only run evaluation on the test dataloader
|
||||||
optimizer, scheduler = self.optimizer_factory(
|
if cfg.eval_only:
|
||||||
accelerator=accelerator,
|
_eval_and_dump(
|
||||||
exp_dir=self.exp_dir,
|
cfg,
|
||||||
last_epoch=start_epoch,
|
task,
|
||||||
model=model,
|
datasource.all_train_cameras,
|
||||||
resume=self.model_factory.resume,
|
datasets,
|
||||||
resume_epoch=self.model_factory.resume_epoch,
|
dataloaders,
|
||||||
|
model,
|
||||||
|
stats,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# init the optimizer
|
||||||
|
optimizer, scheduler = init_optimizer(
|
||||||
|
model,
|
||||||
|
optimizer_state=optimizer_state,
|
||||||
|
last_epoch=start_epoch,
|
||||||
|
**cfg.solver_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
# check the scheduler and stats have been initialized correctly
|
||||||
|
assert scheduler.last_epoch == stats.epoch + 1
|
||||||
|
assert scheduler.last_epoch == start_epoch
|
||||||
|
|
||||||
# Wrap all modules in the distributed library
|
# Wrap all modules in the distributed library
|
||||||
# Note: we don't pass the scheduler to prepare as it
|
# Note: we don't pass the scheduler to prepare as it
|
||||||
# doesn't need to be stepped at each optimizer step
|
# doesn't need to be stepped at each optimizer step
|
||||||
train_loader = dataloaders.train
|
train_loader = dataloaders.train
|
||||||
val_loader = dataloaders.val
|
val_loader = dataloaders.val
|
||||||
test_loader = dataloaders.test
|
|
||||||
if accelerator is not None:
|
if accelerator is not None:
|
||||||
(
|
(
|
||||||
model,
|
model,
|
||||||
@@ -207,24 +438,214 @@ class Experiment(Configurable): # pyre-ignore: 13
|
|||||||
val_loader,
|
val_loader,
|
||||||
) = accelerator.prepare(model, optimizer, train_loader, val_loader)
|
) = accelerator.prepare(model, optimizer, train_loader, val_loader)
|
||||||
|
|
||||||
all_train_cameras = self.data_source.all_train_cameras
|
past_scheduler_lrs = []
|
||||||
|
# loop through epochs
|
||||||
|
for epoch in range(start_epoch, cfg.solver_args.max_epochs):
|
||||||
|
# automatic new_epoch and plotting of stats at every epoch start
|
||||||
|
with stats:
|
||||||
|
|
||||||
# Enter the main training loop.
|
# Make sure to re-seed random generators to ensure reproducibility
|
||||||
self.training_loop.run(
|
# even after restart.
|
||||||
train_loader=train_loader,
|
_seed_all_random_engines(cfg.seed + epoch)
|
||||||
val_loader=val_loader,
|
|
||||||
test_loader=test_loader,
|
cur_lr = float(scheduler.get_last_lr()[-1])
|
||||||
model=model,
|
logger.info(f"scheduler lr = {cur_lr:1.2e}")
|
||||||
optimizer=optimizer,
|
past_scheduler_lrs.append(cur_lr)
|
||||||
scheduler=scheduler,
|
|
||||||
all_train_cameras=all_train_cameras,
|
# train loop
|
||||||
accelerator=accelerator,
|
trainvalidate(
|
||||||
|
model,
|
||||||
|
stats,
|
||||||
|
epoch,
|
||||||
|
train_loader,
|
||||||
|
optimizer,
|
||||||
|
False,
|
||||||
|
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
||||||
device=device,
|
device=device,
|
||||||
exp_dir=self.exp_dir,
|
accelerator=accelerator,
|
||||||
stats=stats,
|
**cfg,
|
||||||
seed=self.seed,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# val loop (optional)
|
||||||
|
if val_loader is not None and epoch % cfg.validation_interval == 0:
|
||||||
|
trainvalidate(
|
||||||
|
model,
|
||||||
|
stats,
|
||||||
|
epoch,
|
||||||
|
val_loader,
|
||||||
|
optimizer,
|
||||||
|
True,
|
||||||
|
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
||||||
|
device=device,
|
||||||
|
accelerator=accelerator,
|
||||||
|
**cfg,
|
||||||
|
)
|
||||||
|
|
||||||
|
# eval loop (optional)
|
||||||
|
if (
|
||||||
|
dataloaders.test is not None
|
||||||
|
and cfg.test_interval > 0
|
||||||
|
and epoch % cfg.test_interval == 0
|
||||||
|
):
|
||||||
|
_run_eval(
|
||||||
|
model,
|
||||||
|
datasource.all_train_cameras,
|
||||||
|
dataloaders.test,
|
||||||
|
task,
|
||||||
|
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert stats.epoch == epoch, "inconsistent stats!"
|
||||||
|
|
||||||
|
# delete previous models if required
|
||||||
|
# save model only on the main process
|
||||||
|
if cfg.store_checkpoints and (
|
||||||
|
accelerator is None or accelerator.is_local_main_process
|
||||||
|
):
|
||||||
|
if cfg.store_checkpoints_purge > 0:
|
||||||
|
for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
|
||||||
|
model_io.purge_epoch(cfg.exp_dir, prev_epoch)
|
||||||
|
outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
|
||||||
|
unwrapped_model = (
|
||||||
|
model if accelerator is None else accelerator.unwrap_model(model)
|
||||||
|
)
|
||||||
|
model_io.safe_save_model(
|
||||||
|
unwrapped_model, stats, outfile, optimizer=optimizer
|
||||||
|
)
|
||||||
|
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
|
new_lr = float(scheduler.get_last_lr()[-1])
|
||||||
|
if new_lr != cur_lr:
|
||||||
|
logger.info(f"LR change! {cur_lr} -> {new_lr}")
|
||||||
|
|
||||||
|
if cfg.test_when_finished:
|
||||||
|
_eval_and_dump(
|
||||||
|
cfg,
|
||||||
|
task,
|
||||||
|
datasource.all_train_cameras,
|
||||||
|
datasets,
|
||||||
|
dataloaders,
|
||||||
|
model,
|
||||||
|
stats,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _eval_and_dump(
    cfg,
    task: Task,
    all_train_cameras: Optional[CamerasBase],
    datasets: DatasetMap,
    dataloaders: DataLoaderMap,
    model,
    stats,
    device,
) -> None:
    """
    Evaluate `model` on the test dataloader and dump the resulting metrics
    to `results_test.json` inside `cfg.exp_dir`.

    Args:
        cfg: The experiment config; `camera_difficulty_bin_breaks` and
            `exp_dir` are read from it.
        task: The task, forwarded to `_run_eval`.
        all_train_cameras: Optional training cameras, forwarded to `_run_eval`.
        datasets: Not used inside this function.
        dataloaders: The dataloaders; only the `test` entry is used.
        model: The model to evaluate.
        stats: The training stats; `stats.epoch` is stored with each result.
        device: The device on which the evaluation runs.

    Raises:
        ValueError: If `dataloaders.test` is None.
    """
    test_loader = dataloaders.test
    if test_loader is None:
        raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')

    results = _run_eval(
        model,
        all_train_cameras,
        test_loader,
        task,
        camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
        device=device,
    )

    # Tag each result entry with the epoch of the evaluated model.
    for result in results:
        result["eval_epoch"] = int(stats.epoch)

    logger.info("Evaluation results")
    evaluate.pretty_print_nvs_metrics(results)

    results_path = os.path.join(cfg.exp_dir, "results_test.json")
    with open(results_path, "w") as results_file:
        json.dump(results, results_file)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_eval_frame_data(frame_data):
    """
    Masks the unknown image data to make sure we cannot use it at model
    evaluation time.

    Args:
        frame_data: A batch of frame data. `frame_type` and `image_rgb` are
            read from it; the object itself is left unmodified.

    Returns:
        A deep copy of `frame_data` whose `image_rgb`, `depth_map`,
        `fg_probability` and `mask_crop` fields are multiplied by the
        known-frame mask. Fields that are None are kept as None.
    """
    frame_data_for_eval = copy.deepcopy(frame_data)
    # One mask value per batch element, cast to the dtype/device of the images.
    # Indexing with [:, None, None, None] appends three singleton dimensions so
    # that the mask broadcasts over the remaining dimensions of each field.
    is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
        frame_data.image_rgb
    )[:, None, None, None]
    for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
        value = getattr(frame_data_for_eval, k)
        if value is None:
            # Assumes optional fields may be absent (e.g. a dataset without
            # depth maps); there is nothing to mask then, and calling
            # `.clone()` on None would raise an AttributeError.
            continue
        setattr(frame_data_for_eval, k, value.clone() * is_known)
    return frame_data_for_eval
|
||||||
|
|
||||||
|
|
||||||
|
def _run_eval(
    model,
    all_train_cameras,
    loader,
    task: Task,
    camera_difficulty_bin_breaks: Tuple[float, float],
    device,
):
    """
    Evaluate `model` on all batches of `loader` and summarize the metrics.

    Args:
        model: The model to evaluate; it is switched to eval mode.
        all_train_cameras: Passed to `evaluate.eval_batch` as `source_cameras`.
        loader: The dataloader yielding the batches to evaluate.
        task: Forwarded to `evaluate.summarize_nvs_eval_results`.
        camera_difficulty_bin_breaks: Forwarded to
            `evaluate.summarize_nvs_eval_results`.
        device: The device the LPIPS model and the batches are moved to.

    Returns:
        The "results" entry of the per-category summary returned by
        `evaluate.summarize_nvs_eval_results`.
    """
    lpips_model = lpips.LPIPS(net="vgg").to(device)

    model.eval()

    logger.info("Evaluating model ...")
    batch_results = []
    for frame_data in tqdm.tqdm(loader):
        frame_data = frame_data.to(device)

        # The model is only given a copy in which the unknown images are masked.
        masked_frame_data = _get_eval_frame_data(frame_data)

        with torch.no_grad():
            model_kwargs = {
                **masked_frame_data,
                "evaluation_mode": EvaluationMode.EVALUATION,
            }
            preds = model(**model_kwargs)

        # TODO: Cannot use accelerate gather for two reasons:
        # (1) TypeError: Can't apply _gpu_gather_one on object of type
        # <class 'pytorch3d.implicitron.models.base_model.ImplicitronRender'>,
        # only of nested list/tuple/dicts of objects that satisfy is_torch_tensor.
        # (2) Same error above but for frame_data which contains Cameras.

        render = copy.deepcopy(preds["implicitron_render"])

        # The metrics are computed against the original (unmasked) batch.
        batch_result = evaluate.eval_batch(
            frame_data,
            render,
            bg_color="black",
            lpips_model=lpips_model,
            source_cameras=all_train_cameras,
        )
        batch_results.append(batch_result)

    _, category_result = evaluate.summarize_nvs_eval_results(
        batch_results, task, camera_difficulty_bin_breaks
    )

    return category_result["results"]
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_all_random_engines(seed: int) -> None:
|
||||||
|
np.random.seed(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
|
||||||
|
|
||||||
def _setup_envvars_for_cluster() -> bool:
|
def _setup_envvars_for_cluster() -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -257,20 +678,9 @@ def _setup_envvars_for_cluster() -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def dump_cfg(cfg: DictConfig) -> None:
|
expand_args_fields(ExperimentConfig)
|
||||||
remove_unused_components(cfg)
|
|
||||||
# dump the exp config to the exp dir
|
|
||||||
os.makedirs(cfg.exp_dir, exist_ok=True)
|
|
||||||
try:
|
|
||||||
cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml")
|
|
||||||
OmegaConf.save(config=cfg, f=cfg_filename)
|
|
||||||
except PermissionError:
|
|
||||||
warnings.warn("Can't dump config due to insufficient permissions!")
|
|
||||||
|
|
||||||
|
|
||||||
expand_args_fields(Experiment)
|
|
||||||
cs = hydra.core.config_store.ConfigStore.instance()
|
cs = hydra.core.config_store.ConfigStore.instance()
|
||||||
cs.store(name="default_config", node=Experiment)
|
cs.store(name="default_config", node=ExperimentConfig)
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(config_path="./configs/", config_name="default_config")
|
@hydra.main(config_path="./configs/", config_name="default_config")
|
||||||
@@ -284,14 +694,12 @@ def experiment(cfg: DictConfig) -> None:
|
|||||||
logger.info("Running locally")
|
logger.info("Running locally")
|
||||||
|
|
||||||
# TODO: The following may be needed for hydra/submitit it to work
|
# TODO: The following may be needed for hydra/submitit it to work
|
||||||
expand_args_fields(ImplicitronModelBase)
|
expand_args_fields(GenericModel)
|
||||||
expand_args_fields(AdaptiveRaySampler)
|
expand_args_fields(AdaptiveRaySampler)
|
||||||
expand_args_fields(MultiPassEmissionAbsorptionRenderer)
|
expand_args_fields(MultiPassEmissionAbsorptionRenderer)
|
||||||
expand_args_fields(ImplicitronDataSource)
|
expand_args_fields(ImplicitronDataSource)
|
||||||
|
|
||||||
experiment = Experiment(**cfg)
|
run_training(cfg)
|
||||||
dump_cfg(cfg)
|
|
||||||
experiment.run()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
49
projects/implicitron_trainer/impl/experiment_config.py
Normal file
49
projects/implicitron_trainer/impl/experiment_config.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from dataclasses import field
|
||||||
|
from typing import Any, Dict, Tuple
|
||||||
|
|
||||||
|
from omegaconf import DictConfig
|
||||||
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||||
|
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||||
|
from pytorch3d.implicitron.tools.config import Configurable, get_default_args_field
|
||||||
|
|
||||||
|
from .optimization import init_optimizer
|
||||||
|
|
||||||
|
|
||||||
|
class ExperimentConfig(Configurable):
    """
    The root configuration of an experiment.

    It bundles the default arguments of the model (`GenericModel`), of the
    optimizer initialization (`init_optimizer`) and of the data source
    (`ImplicitronDataSource`) with the scalar settings of the experiment.
    """

    # Default arguments of the configurable components.
    generic_model_args: DictConfig = get_default_args_field(GenericModel)
    solver_args: DictConfig = get_default_args_field(init_optimizer)
    data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)

    # Scalar experiment settings.
    architecture: str = "generic"
    detect_anomaly: bool = False
    eval_only: bool = False
    exp_dir: str = "./data/default_experiment/"
    exp_idx: int = 0
    gpu_idx: int = 0
    metric_print_interval: int = 5
    resume: bool = True
    resume_epoch: int = -1
    seed: int = 0
    store_checkpoints: bool = True
    store_checkpoints_purge: int = 1
    test_interval: int = -1
    test_when_finished: bool = False
    validation_interval: int = 1
    visdom_env: str = ""
    visdom_port: int = 8097
    visdom_server: str = "http://127.0.0.1"
    visualize_interval: int = 1000
    clip_grad: float = 0.0
    camera_difficulty_bin_breaks: Tuple[float, ...] = (0.97, 0.98)

    # Settings of hydra itself.
    hydra: Dict[str, Any] = field(
        default_factory=lambda: {
            "run": {"dir": "."},  # Make hydra not change the working dir.
            "output_subdir": None,  # disable storing the .hydra logs
        }
    )
|
||||||
@@ -1,133 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the BSD-style license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import torch.optim
|
|
||||||
|
|
||||||
from accelerate import Accelerator
|
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
|
||||||
from pytorch3d.implicitron.tools import model_io
|
|
||||||
from pytorch3d.implicitron.tools.config import (
|
|
||||||
registry,
|
|
||||||
ReplaceableBase,
|
|
||||||
run_auto_creation,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.tools.stats import Stats
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelFactoryBase(ReplaceableBase):
    """
    Interface of the factories which create a model, possibly restoring it
    from a previously saved state.
    """

    resume: bool = True  # resume from the last checkpoint

    def __call__(self, **kwargs) -> ImplicitronModelBase:
        """
        Create the model (possibly from a previously saved state).

        Returns:
            An instance of ImplicitronModelBase.

        Raises:
            NotImplementedError: Always; subclasses have to override this.
        """
        raise NotImplementedError()

    def load_stats(self, **kwargs) -> Stats:
        """
        Create or load a Stats object.

        Raises:
            NotImplementedError: Always; subclasses have to override this.
        """
        raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
class ImplicitronModelFactory(ModelFactoryBase):  # pyre-ignore [13]
    """
    A factory class that initializes an implicit rendering model.

    Members:
        model: An ImplicitronModelBase object.
        resume: If True, attempt to load the last checkpoint from `exp_dir`
            passed to __call__. Failure to do so will return a model with
            initial weights unless `force_resume` is True.
        resume_epoch: If > 0, resume the model from the checkpoint of this
            epoch; otherwise the latest checkpoint is used.
        force_resume: If True, a found checkpoint is loaded regardless of
            `resume`, and a FileNotFoundError is raised if no checkpoint
            can be found.
    """

    model: ImplicitronModelBase
    model_class_type: str = "GenericModel"
    resume: bool = True
    resume_epoch: int = -1
    force_resume: bool = False

    def __post_init__(self):
        run_auto_creation(self)

    def __call__(
        self,
        exp_dir: str,
        accelerator: Optional[Accelerator] = None,
    ) -> ImplicitronModelBase:
        """
        Returns an instance of `ImplicitronModelBase`, possibly loaded from a
        checkpoint (if self.resume, self.resume_epoch specify so).

        Args:
            exp_dir: Root experiment directory.
            accelerator: An Accelerator object.

        Returns:
            model: The model with optionally loaded weights from checkpoint

        Raise:
            ValueError if `resume_epoch` > 0 but there is no checkpoint file
                for that epoch.
            FileNotFoundError if `force_resume` is True but checkpoint not found.
        """
        # Remember the network outputs that should be logged.
        log_vars = list(getattr(self.model, "log_vars", ["objective"]))

        if self.resume_epoch > 0:
            # A specific epoch was requested; its checkpoint has to exist.
            model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
            if not os.path.isfile(model_path):
                raise ValueError(f"Cannot find model from epoch {self.resume_epoch}.")
        else:
            # Fall back to the most recent checkpoint (if any).
            model_path = model_io.find_last_checkpoint(exp_dir)

        if model_path is None:
            if self.force_resume:
                raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")
            return self.model

        logger.info(f"Found previous model {model_path}")
        if not (self.force_resume or self.resume):
            logger.info("Not resuming -> starting from scratch.")
            return self.model

        logger.info("Resuming.")

        # Processes other than the local main one remap the tensors stored
        # for "cuda:0" to the device with their own local process index.
        map_location = None
        if accelerator is not None and not accelerator.is_local_main_process:
            map_location = {
                "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
            }
        weights_path = model_io.get_model_path(model_path)
        model_state_dict = torch.load(weights_path, map_location=map_location)

        try:
            self.model.load_state_dict(model_state_dict, strict=True)
        except RuntimeError as e:
            # Strict loading failed; report the error and retry non-strictly.
            logger.error(e)
            logger.info("Cannot load state dict in strict mode! -> trying non-strict")
            self.model.load_state_dict(model_state_dict, strict=False)
        self.model.log_vars = log_vars

        return self.model
|
|
||||||
109
projects/implicitron_trainer/impl/optimization.py
Normal file
109
projects/implicitron_trainer/impl/optimization.py
Normal file
@@ -0,0 +1,109 @@
|
|||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the BSD-style license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||||
|
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def init_optimizer(
    model: GenericModel,
    optimizer_state: Optional[Dict[str, Any]],
    last_epoch: int,
    breed: str = "adam",
    weight_decay: float = 0.0,
    lr_policy: str = "multistep",
    lr: float = 0.0005,
    gamma: float = 0.1,
    momentum: float = 0.9,
    betas: Tuple[float, ...] = (0.9, 0.999),
    milestones: Tuple[int, ...] = (),
    max_epochs: int = 1000,
):
    """
    Create the optimizer (optionally restoring its state from a checkpoint)
    together with the learning rate scheduler.

    Args:
        model: The model with optionally loaded weights
        optimizer_state: The state dict for the optimizer. If None
            it has not been loaded from checkpoint
        last_epoch: If the model was loaded from checkpoint this will be the
            number of the last epoch that was saved; the scheduler is
            stepped this many times before it is returned
        breed: The type of optimizer to use: "sgd", "adagrad" or "adam"
        weight_decay: The optimizer weight_decay (L2 penalty on model weights)
        lr_policy: The policy to use for learning rate. Currently, only
            "multistep" is supported.
        lr: The value for the initial learning rate
        gamma: Multiplicative factor of learning rate decay
        momentum: Momentum factor for SGD optimizer
        betas: Coefficients used for computing running averages of gradient
            and its square in the Adam optimizer
        milestones: List of increasing epoch indices at which the learning
            rate is modified
        max_epochs: The maximum number of epochs to run the optimizer for
            (not used inside this function)

    Returns:
        optimizer: Optimizer module, optionally loaded from checkpoint
        scheduler: Learning rate scheduler module

    Raise:
        ValueError if `breed` or `lr_policy` are not supported.
    """

    # Collect the parameter groups to optimize.
    if hasattr(model, "_get_param_groups"):  # the model provides its own groups
        # pyre-ignore[29]
        p_groups = model._get_param_groups(lr, wd=weight_decay)
    else:
        trainable = [param for param in model.parameters() if param.requires_grad]
        p_groups = [{"params": trainable, "lr": lr}]

    # Initialize the optimizer; each entry builds one supported solver type.
    optimizer_builders = {
        "sgd": lambda: torch.optim.SGD(
            p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
        ),
        "adagrad": lambda: torch.optim.Adagrad(
            p_groups, lr=lr, weight_decay=weight_decay
        ),
        "adam": lambda: torch.optim.Adam(
            p_groups, lr=lr, betas=betas, weight_decay=weight_decay
        ),
    }
    if breed not in optimizer_builders:
        raise ValueError("no such solver type %s" % breed)
    optimizer = optimizer_builders[breed]()
    logger.info("  -> solver type = %s" % breed)

    # Restore the optimizer state from the checkpoint, if there is one.
    if optimizer_state is not None:
        logger.info("  -> setting loaded optimizer state")
        optimizer.load_state_dict(optimizer_state)

    # Initialize the learning rate scheduler.
    if lr_policy != "multistep":
        raise ValueError("no such lr policy %s" % lr_policy)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=milestones,
        gamma=gamma,
    )

    # When loading from checkpoint, advancing the scheduler makes sure that
    # the lr is correctly set even after returning.
    for _ in range(last_epoch):
        scheduler.step()

    optimizer.zero_grad()
    return optimizer, scheduler
|
||||||
|
|
||||||
|
|
||||||
|
enable_get_default_args(init_optimizer)
|
||||||
@@ -1,230 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the BSD-style license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Any, Dict, Optional, Tuple
|
|
||||||
|
|
||||||
import torch.optim
|
|
||||||
|
|
||||||
from accelerate import Accelerator
|
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
|
||||||
from pytorch3d.implicitron.tools import model_io
|
|
||||||
from pytorch3d.implicitron.tools.config import (
|
|
||||||
registry,
|
|
||||||
ReplaceableBase,
|
|
||||||
run_auto_creation,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizerFactoryBase(ReplaceableBase):
    """
    Interface of the factories which create an optimizer and its learning
    rate scheduler.
    """

    def __call__(
        self, model: ImplicitronModelBase, **kwargs
    ) -> Tuple[torch.optim.Optimizer, Any]:
        """
        Create the optimizer and the lr scheduler.

        Args:
            model: The model with optionally loaded weights.

        Returns:
            A tuple of an optimizer module (optionally loaded from a
            checkpoint) and a learning rate scheduler module (should be
            a subclass of torch.optim's lr_scheduler._LRScheduler).

        Raises:
            NotImplementedError: Always; subclasses have to override this.
        """
        raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
class ImplicitronOptimizerFactory(OptimizerFactoryBase):
    """
    A factory that initializes the optimizer and lr scheduler.

    Members:
        betas: Beta parameters for the Adam optimizer.
        breed: The type of optimizer to use. We currently support SGD, Adagrad
            and Adam.
        exponential_lr_step_size: With the Exponential and LinearExponential
            policies only, lr = lr * gamma ** (epoch/step_size)
        gamma: Multiplicative factor of learning rate decay.
        lr: The value for the initial learning rate.
        lr_policy: The policy to use for learning rate. We currently support
            MultiStepLR, Exponential and LinearExponential policies (the name
            is matched case-insensitively).
        momentum: Momentum factor for SGD optimizer.
        multistep_lr_milestones: With MultiStepLR policy only: list of
            increasing epoch indices at which the learning rate is modified.
        weight_decay: The optimizer weight_decay (L2 penalty on model weights).
        linear_exponential_lr_milestone: With LinearExponential policy only:
            the epoch at which the linear progression of the learning rate
            ends and the exponential decay starts.
        linear_exponential_start_gamma: With LinearExponential policy only:
            the learning rate multiplier at epoch 0; it changes linearly to 1
            at `linear_exponential_lr_milestone`.
    """

    betas: Tuple[float, ...] = (0.9, 0.999)
    breed: str = "Adam"
    exponential_lr_step_size: int = 250
    gamma: float = 0.1
    lr: float = 0.0005
    lr_policy: str = "MultiStepLR"
    momentum: float = 0.9
    multistep_lr_milestones: tuple = ()
    weight_decay: float = 0.0
    linear_exponential_lr_milestone: int = 200
    linear_exponential_start_gamma: float = 0.1

    def __post_init__(self):
        run_auto_creation(self)

    def __call__(
        self,
        last_epoch: int,
        model: ImplicitronModelBase,
        accelerator: Optional[Accelerator] = None,
        exp_dir: Optional[str] = None,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Tuple[torch.optim.Optimizer, Any]:
        """
        Initialize the optimizer (optionally from a checkpoint) and the lr scheduler.

        Args:
            last_epoch: If the model was loaded from checkpoint this will be the
                number of the last epoch that was saved.
            model: The model with optionally loaded weights.
            accelerator: An optional Accelerator instance.
            exp_dir: Root experiment directory.
            resume: If True, attempt to load optimizer checkpoint from exp_dir.
                If no checkpoint is found, a newly initialized optimizer is
                returned.
            resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
                `resume_epoch` <= 0, then resume from the latest checkpoint.

        Returns:
            An optimizer module (optionally loaded from a checkpoint) and
            a learning rate scheduler module (should be a subclass of torch.optim's
            lr_scheduler._LRScheduler).

        Raises:
            ValueError: If `breed` or `lr_policy` is not supported.
            FileNotFoundError: If a checkpoint has to be loaded but its file
                is missing (see `_get_optimizer_state`).
        """
        # Get the parameters to optimize
        if hasattr(model, "_get_param_groups"):  # use the model function
            # pyre-ignore[29]
            p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
        else:
            allprm = [prm for prm in model.parameters() if prm.requires_grad]
            p_groups = [{"params": allprm, "lr": self.lr}]

        # Initialize the optimizer
        if self.breed == "SGD":
            optimizer = torch.optim.SGD(
                p_groups,
                lr=self.lr,
                momentum=self.momentum,
                weight_decay=self.weight_decay,
            )
        elif self.breed == "Adagrad":
            optimizer = torch.optim.Adagrad(
                p_groups, lr=self.lr, weight_decay=self.weight_decay
            )
        elif self.breed == "Adam":
            optimizer = torch.optim.Adam(
                p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
            )
        else:
            raise ValueError(f"No such solver type {self.breed}")
        logger.info(f"Solver type = {self.breed}")

        # Load state from checkpoint
        optimizer_state = self._get_optimizer_state(
            exp_dir,
            accelerator,
            resume_epoch=resume_epoch,
            resume=resume,
        )
        if optimizer_state is not None:
            logger.info("Setting loaded optimizer state.")
            optimizer.load_state_dict(optimizer_state)

        # Initialize the learning rate scheduler.
        # NOTE: `verbose` is deliberately not passed to LambdaLR: it is a
        # deprecated keyword in recent PyTorch releases, and leaving it out
        # is equivalent to the previously passed `verbose=False`.
        if self.lr_policy.casefold() == "MultiStepLR".casefold():
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=self.multistep_lr_milestones,
                gamma=self.gamma,
            )
        elif self.lr_policy.casefold() == "Exponential".casefold():
            scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                lambda epoch: self.gamma ** (epoch / self.exponential_lr_step_size),
            )
        elif self.lr_policy.casefold() == "LinearExponential".casefold():
            # linear learning rate progression between epochs 0 to
            # self.linear_exponential_lr_milestone, followed by exponential
            # lr decay for the rest of the epochs
            def _get_lr(epoch: int):
                m = self.linear_exponential_lr_milestone
                if epoch < m:
                    # w goes from 1 (epoch 0) to 0 (epoch m), blending the
                    # start multiplier with 1.
                    w = (m - epoch) / m
                    gamma = w * self.linear_exponential_start_gamma + (1 - w)
                else:
                    epoch_rest = epoch - m
                    gamma = self.gamma ** (epoch_rest / self.exponential_lr_step_size)
                return gamma

            scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, _get_lr)
        else:
            raise ValueError("no such lr policy %s" % self.lr_policy)

        # When loading from checkpoint, this will make sure that the
        # lr is correctly set even after returning.
        for _ in range(last_epoch):
            scheduler.step()

        optimizer.zero_grad()

        return optimizer, scheduler

    def _get_optimizer_state(
        self,
        exp_dir: Optional[str],
        accelerator: Optional[Accelerator] = None,
        resume: bool = True,
        resume_epoch: int = -1,
    ) -> Optional[Dict[str, Any]]:
        """
        Load an optimizer state from a checkpoint.

        Args:
            exp_dir: Root experiment directory. If None, nothing is loaded.
            accelerator: An optional Accelerator instance. On processes other
                than the local main one, the tensors stored for "cuda:0" are
                mapped to the device with the local process index.
            resume: If True, attempt to load the optimizer state from `exp_dir`;
                if False, nothing is loaded.
            resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
                `resume_epoch` <= 0, then resume from the latest checkpoint.

        Returns:
            The loaded optimizer state, or None if nothing is to be loaded
            or no checkpoint has been found.

        Raises:
            FileNotFoundError: If `resume_epoch` > 0 and there is no
                checkpoint for that epoch, or if a checkpoint was found but
                its optimizer state file does not exist.
        """
        if exp_dir is None or not resume:
            return None
        if resume_epoch > 0:
            save_path = model_io.get_checkpoint(exp_dir, resume_epoch)
            if not os.path.isfile(save_path):
                raise FileNotFoundError(
                    f"Cannot find optimizer from epoch {resume_epoch}."
                )
        else:
            save_path = model_io.find_last_checkpoint(exp_dir)
        optimizer_state = None
        if save_path is not None:
            logger.info(f"Found previous optimizer state {save_path} -> resuming.")
            opt_path = model_io.get_optimizer_path(save_path)

            if os.path.isfile(opt_path):
                map_location = None
                if accelerator is not None and not accelerator.is_local_main_process:
                    map_location = {
                        "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
                    }
                optimizer_state = torch.load(opt_path, map_location)
            else:
                raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
        return optimizer_state
|
|
||||||
@@ -1,447 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the BSD-style license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from typing import Any, List, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from accelerate import Accelerator
|
|
||||||
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
|
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
|
||||||
from pytorch3d.implicitron.models.generic_model import EvaluationMode
|
|
||||||
from pytorch3d.implicitron.tools import model_io, vis_utils
|
|
||||||
from pytorch3d.implicitron.tools.config import (
|
|
||||||
registry,
|
|
||||||
ReplaceableBase,
|
|
||||||
run_auto_creation,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.tools.stats import Stats
|
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from .utils import seed_all_random_engines
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class TrainingLoopBase(ReplaceableBase):
    """
    Interface of the training loops.
    """

    def run(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        test_loader: Optional[DataLoader],
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        **kwargs,
    ) -> None:
        """
        Run the training loop.

        Args:
            train_loader: The dataloader of the training data.
            val_loader: An optional dataloader of the validation data.
            test_loader: An optional dataloader of the test data.
            model: The model to train.
            optimizer: The optimizer of the model.
            scheduler: The learning rate scheduler.

        Raises:
            NotImplementedError: Always; subclasses have to override this.
        """
        raise NotImplementedError()

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        """
        Return the Stats object of the training run; judging by the argument
        names, presumably either a new one or one restored from `exp_dir`
        (to be confirmed against the implementations).

        Raises:
            NotImplementedError: Always; subclasses have to override this.
        """
        raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
|
|
||||||
"""
|
|
||||||
Members:
|
|
||||||
eval_only: If True, only run evaluation using the test dataloader.
|
|
||||||
evaluator: An EvaluatorBase instance, used to evaluate training results.
|
|
||||||
max_epochs: Train for this many epochs. Note that if the model was
|
|
||||||
loaded from a checkpoint, we will restart training at the appropriate
|
|
||||||
epoch and run for (max_epochs - checkpoint_epoch) epochs.
|
|
||||||
store_checkpoints: If True, store model and optimizer state checkpoints.
|
|
||||||
store_checkpoints_purge: If >= 0, remove any checkpoints older or equal
|
|
||||||
to this many epochs.
|
|
||||||
test_interval: Evaluate on a test dataloader each `test_interval` epochs.
|
|
||||||
test_when_finished: If True, evaluate on a test dataloader when training
|
|
||||||
completes.
|
|
||||||
validation_interval: Validate each `validation_interval` epochs.
|
|
||||||
clip_grad: Optionally clip the gradient norms.
|
|
||||||
If set to a value <=0.0, no clipping
|
|
||||||
metric_print_interval: The batch interval at which the stats should be
|
|
||||||
logged.
|
|
||||||
visualize_interval: The batch interval at which the visualizations
|
|
||||||
should be plotted
|
|
||||||
visdom_env: The name of the Visdom environment to use for plotting.
|
|
||||||
visdom_port: The Visdom port.
|
|
||||||
visdom_server: Address of the Visdom server.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Parameters of the outer training loop.
|
|
||||||
eval_only: bool = False
|
|
||||||
evaluator: EvaluatorBase
|
|
||||||
evaluator_class_type: str = "ImplicitronEvaluator"
|
|
||||||
max_epochs: int = 1000
|
|
||||||
store_checkpoints: bool = True
|
|
||||||
store_checkpoints_purge: int = 1
|
|
||||||
test_interval: int = -1
|
|
||||||
test_when_finished: bool = False
|
|
||||||
validation_interval: int = 1
|
|
||||||
|
|
||||||
# Gradient clipping.
|
|
||||||
clip_grad: float = 0.0
|
|
||||||
|
|
||||||
# Visualization/logging parameters.
|
|
||||||
metric_print_interval: int = 5
|
|
||||||
visualize_interval: int = 1000
|
|
||||||
visdom_env: str = ""
|
|
||||||
visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
|
|
||||||
visdom_server: str = "http://127.0.0.1"
|
|
||||||
|
|
||||||
def __post_init__(self):
    """
    Post-initialisation hook: pass this instance to `run_auto_creation`.

    NOTE(review): presumably this instantiates the pluggable members declared
    on the class (e.g. `evaluator` from `evaluator_class_type`) -- confirm
    against the definition of `run_auto_creation`.
    """
    run_auto_creation(self)
|
|
||||||
|
|
||||||
def run(
    self,
    *,
    train_loader: DataLoader,
    val_loader: Optional[DataLoader],
    test_loader: Optional[DataLoader],
    model: ImplicitronModelBase,
    optimizer: torch.optim.Optimizer,
    scheduler: Any,
    accelerator: Optional[Accelerator],
    all_train_cameras: Optional[CamerasBase],
    device: torch.device,
    exp_dir: str,
    stats: Stats,
    seed: int,
    **kwargs,
):
    """
    Entry point to run the training and validation loops
    based on the specified config file.

    Args:
        train_loader: Dataloader iterated once per epoch for training.
        val_loader: Optional dataloader for validation. It is used on epochs
            divisible by `self.validation_interval`; a non-positive interval
            disables validation.
        test_loader: Optional dataloader handed to `self.evaluator.run`.
            Required when `self.eval_only` or `self.test_when_finished`
            is set.
        model: The model to train / evaluate.
        optimizer: The optimizer used in the training epochs and stored in
            the checkpoints.
        scheduler: Learning-rate scheduler exposing `last_epoch`,
            `get_last_lr()` and `step()`. It is stepped once per epoch and
            has to be in sync with `stats` on entry.
        accelerator: Optional Accelerator instance forwarded to the epoch
            and checkpointing helpers.
        all_train_cameras: Optional cameras forwarded to the evaluator.
        device: The device on which to run the model.
        exp_dir: Root experiment directory (checkpoints, evaluation dumps).
        stats: The stats object; `stats.epoch + 1` is the first epoch run.
        seed: Base random seed; the random engines are re-seeded with
            `seed + epoch` at the start of every epoch.
        **kwargs: Ignored.

    Raises:
        ValueError: If a final / eval-only evaluation is requested but
            `test_loader` is None.
    """
    start_epoch = stats.epoch + 1
    # The scheduler has to be positioned at the epoch we are about to run.
    assert scheduler.last_epoch == start_epoch

    # only run evaluation on the test dataloader
    if self.eval_only:
        if test_loader is None:
            raise ValueError(
                "Cannot evaluate and dump results to json, no test data provided."
            )
        self.evaluator.run(
            all_train_cameras=all_train_cameras,
            dataloader=test_loader,
            device=device,
            dump_to_json=True,
            epoch=stats.epoch,
            exp_dir=exp_dir,
            model=model,
        )
        return

    # loop through epochs
    for epoch in range(start_epoch, self.max_epochs):
        # automatic new_epoch and plotting of stats at every epoch start
        with stats:

            # Make sure to re-seed random generators to ensure reproducibility
            # even after restart.
            seed_all_random_engines(seed + epoch)

            cur_lr = float(scheduler.get_last_lr()[-1])
            logger.debug(f"scheduler lr = {cur_lr:1.2e}")

            # train loop
            self._training_or_validation_epoch(
                accelerator=accelerator,
                device=device,
                epoch=epoch,
                loader=train_loader,
                model=model,
                optimizer=optimizer,
                stats=stats,
                validation=False,
            )

            # val loop (optional); the `> 0` guard mirrors `test_interval`
            # and avoids a ZeroDivisionError when the interval is 0.
            if (
                val_loader is not None
                and self.validation_interval > 0
                and epoch % self.validation_interval == 0
            ):
                self._training_or_validation_epoch(
                    accelerator=accelerator,
                    device=device,
                    epoch=epoch,
                    loader=val_loader,
                    model=model,
                    optimizer=optimizer,
                    stats=stats,
                    validation=True,
                )

            # eval loop (optional)
            if (
                test_loader is not None
                and self.test_interval > 0
                and epoch % self.test_interval == 0
            ):
                self.evaluator.run(
                    all_train_cameras=all_train_cameras,
                    device=device,
                    dataloader=test_loader,
                    model=model,
                )

            assert stats.epoch == epoch, "inconsistent stats!"
            self._checkpoint(accelerator, epoch, exp_dir, model, optimizer, stats)

            scheduler.step()
            new_lr = float(scheduler.get_last_lr()[-1])
            if new_lr != cur_lr:
                logger.info(f"LR change! {cur_lr} -> {new_lr}")

    if self.test_when_finished:
        if test_loader is None:
            raise ValueError(
                "Cannot evaluate and dump results to json, no test data provided."
            )
        self.evaluator.run(
            all_train_cameras=all_train_cameras,
            device=device,
            dump_to_json=True,
            epoch=stats.epoch,
            exp_dir=exp_dir,
            dataloader=test_loader,
            model=model,
        )
|
|
||||||
|
|
||||||
def load_stats(
    self,
    log_vars: List[str],
    exp_dir: str,
    resume: bool = True,
    resume_epoch: int = -1,
    **kwargs,
) -> Stats:
    """
    Load Stats that correspond to the model's log_vars and resume_epoch.

    Args:
        log_vars: A list of variable names to log. Should be a subset of the
            `preds` returned by the forward function of the corresponding
            ImplicitronModelBase instance.
        exp_dir: Root experiment directory.
        resume: If False, do not load stats from a checkpoint; instead,
            create a fresh stats object.
        resume_epoch: If > 0, resume from the checkpoint of exactly this
            epoch; otherwise resume from the last checkpoint found in
            `exp_dir`. Only used when `resume` is True.
        **kwargs: Ignored.

    Returns:
        stats: The stats structure (optionally loaded from checkpoint).

    Raises:
        FileNotFoundError: If `resume_epoch > 0` and no checkpoint file
            exists for that epoch.
    """
    plot_file = os.path.join(exp_dir, "train_stats.pdf")
    visdom_env_charts = (
        vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
    )

    # Init the stats struct
    stats = Stats(
        # log_vars should be a list, but OmegaConf might load them as ListConfig
        list(log_vars),
        plot_file=plot_file,
        visdom_env=visdom_env_charts,
        verbose=False,
        visdom_server=self.visdom_server,
        visdom_port=self.visdom_port,
    )

    if not resume:
        logger.info("Clearing stats")
        return stats

    # Locate the checkpoint whose stats should be restored.
    if resume_epoch > 0:
        model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
        if not os.path.isfile(model_path):
            raise FileNotFoundError(
                f"Cannot find stats from epoch {resume_epoch}."
            )
    else:
        # NOTE(review): assumed to return None when no checkpoint exists.
        model_path = model_io.find_last_checkpoint(exp_dir)

    if model_path is None:
        # Nothing to resume from: keep the fresh stats object.
        return stats

    stats_path = model_io.get_stats_path(model_path)
    stats_load = model_io.load_stats(stats_path)

    if stats_load is None:
        logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
        last_epoch = model_io.parse_epoch_from_model_path(model_path)
        logger.info(f"Estimated resume epoch = {last_epoch}")

        # Fast-forward the fresh stats struct to the checkpoint's epoch.
        for _ in range(last_epoch + 1):
            stats.new_epoch()
        assert last_epoch == stats.epoch
    else:
        logger.info(f"Found previous stats in {stats_path} -> resuming.")
        stats = stats_load

    # Update stats properties in case they were reset on load
    stats.visdom_env = visdom_env_charts
    stats.visdom_server = self.visdom_server
    stats.visdom_port = self.visdom_port
    stats.plot_file = plot_file
    stats.synchronize_logged_vars(log_vars)

    return stats
|
|
||||||
|
|
||||||
def _training_or_validation_epoch(
    self,
    epoch: int,
    loader: DataLoader,
    model: ImplicitronModelBase,
    optimizer: torch.optim.Optimizer,
    stats: Stats,
    validation: bool,
    *,
    accelerator: Optional[Accelerator],
    bp_var: str = "objective",
    device: torch.device,
    **kwargs,
) -> None:
    """
    This is the main loop for training and evaluation including:
    model forward pass, loss computation, backward pass and visualization.

    Args:
        epoch: The index of the current epoch (not read by this method; the
            visualization prefix uses `stats.epoch`).
        loader: The dataloader to use for the loop
        model: The model module optionally loaded from checkpoint
        optimizer: The optimizer module optionally loaded from checkpoint
        stats: The stats struct, also optionally loaded from checkpoint
        validation: If true, run the loop with the model in eval mode
            and skip the backward pass
        accelerator: An optional Accelerator instance.
        bp_var: The name of the key in the model output `preds` dict which
            should be used as the loss for the backward pass.
        device: The device on which to run the model.
        **kwargs: Ignored.
    """

    if validation:
        model.eval()
        trainmode = "val"
    else:
        model.train()
        trainmode = "train"

    t_start = time.time()

    # get the visdom env name
    visdom_env_imgs = stats.visdom_env + "_images_" + trainmode
    viz = vis_utils.get_visdom_connection(
        server=stats.visdom_server,
        port=stats.visdom_port,
    )

    # Iterate through the batches
    n_batches = len(loader)
    for it, net_input in enumerate(loader):
        last_iter = it == n_batches - 1

        # move to gpu where possible (in place)
        net_input = net_input.to(device)

        # run the forward pass
        if not validation:
            optimizer.zero_grad()
            preds = model(
                **{**net_input, "evaluation_mode": EvaluationMode.TRAINING}
            )
        else:
            with torch.no_grad():
                preds = model(
                    **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
                )

        # make sure we dont overwrite something
        assert all(k not in preds for k in net_input.keys())
        # merge everything into one big dict
        preds.update(net_input)

        # update the stats logger
        stats.update(preds, time_start=t_start, stat_set=trainmode)
        # pyre-ignore [16]
        assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

        # print textual status update
        if it % self.metric_print_interval == 0 or last_iter:
            stats.print(stat_set=trainmode, max_it=n_batches)

        # visualize results (main process only when running distributed)
        if (
            (accelerator is None or accelerator.is_local_main_process)
            and self.visualize_interval > 0
            and it % self.visualize_interval == 0
        ):
            prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
            if hasattr(model, "visualize"):
                # pyre-ignore [29]
                model.visualize(
                    viz,
                    visdom_env_imgs,
                    preds,
                    prefix,
                )

        # optimizer step
        if not validation:
            loss = preds[bp_var]
            assert torch.isfinite(loss).all(), "Non-finite loss!"
            # backprop
            if accelerator is None:
                loss.backward()
            else:
                accelerator.backward(loss)
            if self.clip_grad > 0.0:
                # Optionally clip the gradient norms. `clip_grad_norm_` is the
                # non-deprecated in-place variant; it returns the total norm
                # measured before clipping.
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), self.clip_grad
                )
                if total_norm > self.clip_grad:
                    logger.debug(
                        f"Clipping gradient: {total_norm}"
                        + f" with coef {self.clip_grad / float(total_norm)}."
                    )

            optimizer.step()
|
|
||||||
|
|
||||||
def _checkpoint(
|
|
||||||
self,
|
|
||||||
accelerator: Optional[Accelerator],
|
|
||||||
epoch: int,
|
|
||||||
exp_dir: str,
|
|
||||||
model: ImplicitronModelBase,
|
|
||||||
optimizer: torch.optim.Optimizer,
|
|
||||||
stats: Stats,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Save a model and its corresponding Stats object to a file, if
|
|
||||||
`self.store_checkpoints` is True. In addition, if
|
|
||||||
`self.store_checkpoints_purge` is True, remove any checkpoints older
|
|
||||||
than `self.store_checkpoints_purge` epochs old.
|
|
||||||
"""
|
|
||||||
if self.store_checkpoints and (
|
|
||||||
accelerator is None or accelerator.is_local_main_process
|
|
||||||
):
|
|
||||||
if self.store_checkpoints_purge > 0:
|
|
||||||
for prev_epoch in range(epoch - self.store_checkpoints_purge):
|
|
||||||
model_io.purge_epoch(exp_dir, prev_epoch)
|
|
||||||
outfile = model_io.get_checkpoint(exp_dir, epoch)
|
|
||||||
unwrapped_model = (
|
|
||||||
model if accelerator is None else accelerator.unwrap_model(model)
|
|
||||||
)
|
|
||||||
model_io.safe_save_model(
|
|
||||||
unwrapped_model, stats, outfile, optimizer=optimizer
|
|
||||||
)
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the BSD-style license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
|
|
||||||
import random
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def seed_all_random_engines(seed: int) -> None:
    """
    Seed Python's `random` module, PyTorch and NumPy's global generator
    with the same `seed`.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
|
|
||||||
@@ -1,15 +1,296 @@
|
|||||||
data_source_class_type: ImplicitronDataSource
|
generic_model_args:
|
||||||
model_factory_class_type: ImplicitronModelFactory
|
mask_images: true
|
||||||
optimizer_factory_class_type: ImplicitronOptimizerFactory
|
mask_depths: true
|
||||||
training_loop_class_type: ImplicitronTrainingLoop
|
render_image_width: 400
|
||||||
seed: 42
|
render_image_height: 400
|
||||||
detect_anomaly: false
|
mask_threshold: 0.5
|
||||||
exp_dir: ./data/default_experiment/
|
output_rasterized_mc: false
|
||||||
hydra:
|
bg_color:
|
||||||
run:
|
- 0.0
|
||||||
dir: .
|
- 0.0
|
||||||
output_subdir: null
|
- 0.0
|
||||||
data_source_ImplicitronDataSource_args:
|
num_passes: 1
|
||||||
|
chunk_size_grid: 4096
|
||||||
|
render_features_dimensions: 3
|
||||||
|
tqdm_trigger_threshold: 16
|
||||||
|
n_train_target_views: 1
|
||||||
|
sampling_mode_training: mask_sample
|
||||||
|
sampling_mode_evaluation: full_grid
|
||||||
|
global_encoder_class_type: null
|
||||||
|
raysampler_class_type: AdaptiveRaySampler
|
||||||
|
renderer_class_type: MultiPassEmissionAbsorptionRenderer
|
||||||
|
image_feature_extractor_class_type: null
|
||||||
|
view_pooler_enabled: false
|
||||||
|
implicit_function_class_type: NeuralRadianceFieldImplicitFunction
|
||||||
|
view_metrics_class_type: ViewMetrics
|
||||||
|
regularization_metrics_class_type: RegularizationMetrics
|
||||||
|
loss_weights:
|
||||||
|
loss_rgb_mse: 1.0
|
||||||
|
loss_prev_stage_rgb_mse: 1.0
|
||||||
|
loss_mask_bce: 0.0
|
||||||
|
loss_prev_stage_mask_bce: 0.0
|
||||||
|
log_vars:
|
||||||
|
- loss_rgb_psnr_fg
|
||||||
|
- loss_rgb_psnr
|
||||||
|
- loss_rgb_mse
|
||||||
|
- loss_rgb_huber
|
||||||
|
- loss_depth_abs
|
||||||
|
- loss_depth_abs_fg
|
||||||
|
- loss_mask_neg_iou
|
||||||
|
- loss_mask_bce
|
||||||
|
- loss_mask_beta_prior
|
||||||
|
- loss_eikonal
|
||||||
|
- loss_density_tv
|
||||||
|
- loss_depth_neg_penalty
|
||||||
|
- loss_autodecoder_norm
|
||||||
|
- loss_prev_stage_rgb_mse
|
||||||
|
- loss_prev_stage_rgb_psnr_fg
|
||||||
|
- loss_prev_stage_rgb_psnr
|
||||||
|
- loss_prev_stage_mask_bce
|
||||||
|
- objective
|
||||||
|
- epoch
|
||||||
|
- sec/it
|
||||||
|
global_encoder_HarmonicTimeEncoder_args:
|
||||||
|
n_harmonic_functions: 10
|
||||||
|
append_input: true
|
||||||
|
time_divisor: 1.0
|
||||||
|
global_encoder_SequenceAutodecoder_args:
|
||||||
|
autodecoder_args:
|
||||||
|
encoding_dim: 0
|
||||||
|
n_instances: 0
|
||||||
|
init_scale: 1.0
|
||||||
|
ignore_input: false
|
||||||
|
raysampler_AdaptiveRaySampler_args:
|
||||||
|
image_width: 400
|
||||||
|
image_height: 400
|
||||||
|
sampling_mode_training: mask_sample
|
||||||
|
sampling_mode_evaluation: full_grid
|
||||||
|
n_pts_per_ray_training: 64
|
||||||
|
n_pts_per_ray_evaluation: 64
|
||||||
|
n_rays_per_image_sampled_from_mask: 1024
|
||||||
|
stratified_point_sampling_training: true
|
||||||
|
stratified_point_sampling_evaluation: false
|
||||||
|
scene_extent: 8.0
|
||||||
|
scene_center:
|
||||||
|
- 0.0
|
||||||
|
- 0.0
|
||||||
|
- 0.0
|
||||||
|
raysampler_NearFarRaySampler_args:
|
||||||
|
image_width: 400
|
||||||
|
image_height: 400
|
||||||
|
sampling_mode_training: mask_sample
|
||||||
|
sampling_mode_evaluation: full_grid
|
||||||
|
n_pts_per_ray_training: 64
|
||||||
|
n_pts_per_ray_evaluation: 64
|
||||||
|
n_rays_per_image_sampled_from_mask: 1024
|
||||||
|
stratified_point_sampling_training: true
|
||||||
|
stratified_point_sampling_evaluation: false
|
||||||
|
min_depth: 0.1
|
||||||
|
max_depth: 8.0
|
||||||
|
renderer_LSTMRenderer_args:
|
||||||
|
num_raymarch_steps: 10
|
||||||
|
init_depth: 17.0
|
||||||
|
init_depth_noise_std: 0.0005
|
||||||
|
hidden_size: 16
|
||||||
|
n_feature_channels: 256
|
||||||
|
bg_color: null
|
||||||
|
verbose: false
|
||||||
|
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
||||||
|
raymarcher_class_type: EmissionAbsorptionRaymarcher
|
||||||
|
n_pts_per_ray_fine_training: 64
|
||||||
|
n_pts_per_ray_fine_evaluation: 64
|
||||||
|
stratified_sampling_coarse_training: true
|
||||||
|
stratified_sampling_coarse_evaluation: false
|
||||||
|
append_coarse_samples_to_fine: true
|
||||||
|
density_noise_std_train: 0.0
|
||||||
|
return_weights: false
|
||||||
|
raymarcher_CumsumRaymarcher_args:
|
||||||
|
surface_thickness: 1
|
||||||
|
bg_color:
|
||||||
|
- 0.0
|
||||||
|
background_opacity: 0.0
|
||||||
|
density_relu: true
|
||||||
|
blend_output: false
|
||||||
|
raymarcher_EmissionAbsorptionRaymarcher_args:
|
||||||
|
surface_thickness: 1
|
||||||
|
bg_color:
|
||||||
|
- 0.0
|
||||||
|
background_opacity: 10000000000.0
|
||||||
|
density_relu: true
|
||||||
|
blend_output: false
|
||||||
|
renderer_SignedDistanceFunctionRenderer_args:
|
||||||
|
render_features_dimensions: 3
|
||||||
|
ray_tracer_args:
|
||||||
|
object_bounding_sphere: 1.0
|
||||||
|
sdf_threshold: 5.0e-05
|
||||||
|
line_search_step: 0.5
|
||||||
|
line_step_iters: 1
|
||||||
|
sphere_tracing_iters: 10
|
||||||
|
n_steps: 100
|
||||||
|
n_secant_steps: 8
|
||||||
|
ray_normal_coloring_network_args:
|
||||||
|
feature_vector_size: 3
|
||||||
|
mode: idr
|
||||||
|
d_in: 9
|
||||||
|
d_out: 3
|
||||||
|
dims:
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
weight_norm: true
|
||||||
|
n_harmonic_functions_dir: 0
|
||||||
|
pooled_feature_dim: 0
|
||||||
|
bg_color:
|
||||||
|
- 0.0
|
||||||
|
soft_mask_alpha: 50.0
|
||||||
|
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||||
|
name: resnet34
|
||||||
|
pretrained: true
|
||||||
|
stages:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 3
|
||||||
|
- 4
|
||||||
|
normalize_image: true
|
||||||
|
image_rescale: 0.16
|
||||||
|
first_max_pool: true
|
||||||
|
proj_dim: 32
|
||||||
|
l2_norm: true
|
||||||
|
add_masks: true
|
||||||
|
add_images: true
|
||||||
|
global_average_pool: false
|
||||||
|
feature_rescale: 1.0
|
||||||
|
view_pooler_args:
|
||||||
|
feature_aggregator_class_type: AngleWeightedReductionFeatureAggregator
|
||||||
|
view_sampler_args:
|
||||||
|
masked_sampling: false
|
||||||
|
sampling_mode: bilinear
|
||||||
|
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
|
||||||
|
exclude_target_view: true
|
||||||
|
exclude_target_view_mask_features: true
|
||||||
|
concatenate_output: true
|
||||||
|
weight_by_ray_angle_gamma: 1.0
|
||||||
|
min_ray_angle_weight: 0.1
|
||||||
|
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
|
||||||
|
exclude_target_view: true
|
||||||
|
exclude_target_view_mask_features: true
|
||||||
|
concatenate_output: true
|
||||||
|
reduction_functions:
|
||||||
|
- AVG
|
||||||
|
- STD
|
||||||
|
weight_by_ray_angle_gamma: 1.0
|
||||||
|
min_ray_angle_weight: 0.1
|
||||||
|
feature_aggregator_IdentityFeatureAggregator_args:
|
||||||
|
exclude_target_view: true
|
||||||
|
exclude_target_view_mask_features: true
|
||||||
|
concatenate_output: true
|
||||||
|
feature_aggregator_ReductionFeatureAggregator_args:
|
||||||
|
exclude_target_view: true
|
||||||
|
exclude_target_view_mask_features: true
|
||||||
|
concatenate_output: true
|
||||||
|
reduction_functions:
|
||||||
|
- AVG
|
||||||
|
- STD
|
||||||
|
implicit_function_IdrFeatureField_args:
|
||||||
|
feature_vector_size: 3
|
||||||
|
d_in: 3
|
||||||
|
d_out: 1
|
||||||
|
dims:
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
- 512
|
||||||
|
geometric_init: true
|
||||||
|
bias: 1.0
|
||||||
|
skip_in: []
|
||||||
|
weight_norm: true
|
||||||
|
n_harmonic_functions_xyz: 0
|
||||||
|
pooled_feature_dim: 0
|
||||||
|
encoding_dim: 0
|
||||||
|
implicit_function_NeRFormerImplicitFunction_args:
|
||||||
|
n_harmonic_functions_xyz: 10
|
||||||
|
n_harmonic_functions_dir: 4
|
||||||
|
n_hidden_neurons_dir: 128
|
||||||
|
latent_dim: 0
|
||||||
|
input_xyz: true
|
||||||
|
xyz_ray_dir_in_camera_coords: false
|
||||||
|
color_dim: 3
|
||||||
|
transformer_dim_down_factor: 2.0
|
||||||
|
n_hidden_neurons_xyz: 80
|
||||||
|
n_layers_xyz: 2
|
||||||
|
append_xyz:
|
||||||
|
- 1
|
||||||
|
implicit_function_NeuralRadianceFieldImplicitFunction_args:
|
||||||
|
n_harmonic_functions_xyz: 10
|
||||||
|
n_harmonic_functions_dir: 4
|
||||||
|
n_hidden_neurons_dir: 128
|
||||||
|
latent_dim: 0
|
||||||
|
input_xyz: true
|
||||||
|
xyz_ray_dir_in_camera_coords: false
|
||||||
|
color_dim: 3
|
||||||
|
transformer_dim_down_factor: 1.0
|
||||||
|
n_hidden_neurons_xyz: 256
|
||||||
|
n_layers_xyz: 8
|
||||||
|
append_xyz:
|
||||||
|
- 5
|
||||||
|
implicit_function_SRNHyperNetImplicitFunction_args:
|
||||||
|
hypernet_args:
|
||||||
|
n_harmonic_functions: 3
|
||||||
|
n_hidden_units: 256
|
||||||
|
n_layers: 2
|
||||||
|
n_hidden_units_hypernet: 256
|
||||||
|
n_layers_hypernet: 1
|
||||||
|
in_features: 3
|
||||||
|
out_features: 256
|
||||||
|
latent_dim_hypernet: 0
|
||||||
|
latent_dim: 0
|
||||||
|
xyz_in_camera_coords: false
|
||||||
|
pixel_generator_args:
|
||||||
|
n_harmonic_functions: 4
|
||||||
|
n_hidden_units: 256
|
||||||
|
n_hidden_units_color: 128
|
||||||
|
n_layers: 2
|
||||||
|
in_features: 256
|
||||||
|
out_features: 3
|
||||||
|
ray_dir_in_camera_coords: false
|
||||||
|
implicit_function_SRNImplicitFunction_args:
|
||||||
|
raymarch_function_args:
|
||||||
|
n_harmonic_functions: 3
|
||||||
|
n_hidden_units: 256
|
||||||
|
n_layers: 2
|
||||||
|
in_features: 3
|
||||||
|
out_features: 256
|
||||||
|
latent_dim: 0
|
||||||
|
xyz_in_camera_coords: false
|
||||||
|
raymarch_function: null
|
||||||
|
pixel_generator_args:
|
||||||
|
n_harmonic_functions: 4
|
||||||
|
n_hidden_units: 256
|
||||||
|
n_hidden_units_color: 128
|
||||||
|
n_layers: 2
|
||||||
|
in_features: 256
|
||||||
|
out_features: 3
|
||||||
|
ray_dir_in_camera_coords: false
|
||||||
|
view_metrics_ViewMetrics_args: {}
|
||||||
|
regularization_metrics_RegularizationMetrics_args: {}
|
||||||
|
solver_args:
|
||||||
|
breed: adam
|
||||||
|
weight_decay: 0.0
|
||||||
|
lr_policy: multistep
|
||||||
|
lr: 0.0005
|
||||||
|
gamma: 0.1
|
||||||
|
momentum: 0.9
|
||||||
|
betas:
|
||||||
|
- 0.9
|
||||||
|
- 0.999
|
||||||
|
milestones: []
|
||||||
|
max_epochs: 1000
|
||||||
|
data_source_args:
|
||||||
dataset_map_provider_class_type: ???
|
dataset_map_provider_class_type: ???
|
||||||
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
|
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
|
||||||
dataset_map_provider_BlenderDatasetMapProvider_args:
|
dataset_map_provider_BlenderDatasetMapProvider_args:
|
||||||
@@ -64,11 +345,17 @@ data_source_ImplicitronDataSource_args:
|
|||||||
dataset_class_type: JsonIndexDataset
|
dataset_class_type: JsonIndexDataset
|
||||||
path_manager_factory_class_type: PathManagerFactory
|
path_manager_factory_class_type: PathManagerFactory
|
||||||
dataset_JsonIndexDataset_args:
|
dataset_JsonIndexDataset_args:
|
||||||
|
path_manager: null
|
||||||
|
frame_annotations_file: ''
|
||||||
|
sequence_annotations_file: ''
|
||||||
|
subset_lists_file: ''
|
||||||
|
subsets: null
|
||||||
limit_to: 0
|
limit_to: 0
|
||||||
limit_sequences_to: 0
|
limit_sequences_to: 0
|
||||||
pick_sequence: []
|
pick_sequence: []
|
||||||
exclude_sequence: []
|
exclude_sequence: []
|
||||||
limit_category_to: []
|
limit_category_to: []
|
||||||
|
dataset_root: ''
|
||||||
load_images: true
|
load_images: true
|
||||||
load_depths: true
|
load_depths: true
|
||||||
load_depth_masks: true
|
load_depth_masks: true
|
||||||
@@ -86,6 +373,7 @@ data_source_ImplicitronDataSource_args:
|
|||||||
n_frames_per_sequence: -1
|
n_frames_per_sequence: -1
|
||||||
seed: 0
|
seed: 0
|
||||||
sort_frames: false
|
sort_frames: false
|
||||||
|
eval_batches: null
|
||||||
path_manager_factory_PathManagerFactory_args:
|
path_manager_factory_PathManagerFactory_args:
|
||||||
silence_logs: true
|
silence_logs: true
|
||||||
dataset_map_provider_LlffDatasetMapProvider_args:
|
dataset_map_provider_LlffDatasetMapProvider_args:
|
||||||
@@ -95,16 +383,6 @@ data_source_ImplicitronDataSource_args:
|
|||||||
n_known_frames_for_test: null
|
n_known_frames_for_test: null
|
||||||
path_manager_factory_PathManagerFactory_args:
|
path_manager_factory_PathManagerFactory_args:
|
||||||
silence_logs: true
|
silence_logs: true
|
||||||
downscale_factor: 4
|
|
||||||
dataset_map_provider_RenderedMeshDatasetMapProvider_args:
|
|
||||||
num_views: 40
|
|
||||||
data_file: null
|
|
||||||
azimuth_range: 180.0
|
|
||||||
resolution: 128
|
|
||||||
use_point_light: true
|
|
||||||
path_manager_factory_class_type: PathManagerFactory
|
|
||||||
path_manager_factory_PathManagerFactory_args:
|
|
||||||
silence_logs: true
|
|
||||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||||
batch_size: 1
|
batch_size: 1
|
||||||
num_workers: 0
|
num_workers: 0
|
||||||
@@ -118,309 +396,30 @@ data_source_ImplicitronDataSource_args:
|
|||||||
sample_consecutive_frames: false
|
sample_consecutive_frames: false
|
||||||
consecutive_frames_max_gap: 0
|
consecutive_frames_max_gap: 0
|
||||||
consecutive_frames_max_gap_seconds: 0.1
|
consecutive_frames_max_gap_seconds: 0.1
|
||||||
data_loader_map_provider_SimpleDataLoaderMapProvider_args:
|
architecture: generic
|
||||||
batch_size: 1
|
detect_anomaly: false
|
||||||
num_workers: 0
|
eval_only: false
|
||||||
dataset_length_train: 0
|
exp_dir: ./data/default_experiment/
|
||||||
dataset_length_val: 0
|
exp_idx: 0
|
||||||
dataset_length_test: 0
|
gpu_idx: 0
|
||||||
model_factory_ImplicitronModelFactory_args:
|
metric_print_interval: 5
|
||||||
resume: true
|
resume: true
|
||||||
model_class_type: GenericModel
|
resume_epoch: -1
|
||||||
resume_epoch: -1
|
seed: 0
|
||||||
force_resume: false
|
store_checkpoints: true
|
||||||
model_GenericModel_args:
|
store_checkpoints_purge: 1
|
||||||
log_vars:
|
test_interval: -1
|
||||||
- loss_rgb_psnr_fg
|
test_when_finished: false
|
||||||
- loss_rgb_psnr
|
validation_interval: 1
|
||||||
- loss_rgb_mse
|
visdom_env: ''
|
||||||
- loss_rgb_huber
|
visdom_port: 8097
|
||||||
- loss_depth_abs
|
visdom_server: http://127.0.0.1
|
||||||
- loss_depth_abs_fg
|
visualize_interval: 1000
|
||||||
- loss_mask_neg_iou
|
clip_grad: 0.0
|
||||||
- loss_mask_bce
|
camera_difficulty_bin_breaks:
|
||||||
- loss_mask_beta_prior
|
- 0.97
|
||||||
- loss_eikonal
|
- 0.98
|
||||||
- loss_density_tv
|
hydra:
|
||||||
- loss_depth_neg_penalty
|
run:
|
||||||
- loss_autodecoder_norm
|
dir: .
|
||||||
- loss_prev_stage_rgb_mse
|
output_subdir: null
|
||||||
- loss_prev_stage_rgb_psnr_fg
|
|
||||||
- loss_prev_stage_rgb_psnr
|
|
||||||
- loss_prev_stage_mask_bce
|
|
||||||
- objective
|
|
||||||
- epoch
|
|
||||||
- sec/it
|
|
||||||
mask_images: true
|
|
||||||
mask_depths: true
|
|
||||||
render_image_width: 400
|
|
||||||
render_image_height: 400
|
|
||||||
mask_threshold: 0.5
|
|
||||||
output_rasterized_mc: false
|
|
||||||
bg_color:
|
|
||||||
- 0.0
|
|
||||||
- 0.0
|
|
||||||
- 0.0
|
|
||||||
num_passes: 1
|
|
||||||
chunk_size_grid: 4096
|
|
||||||
render_features_dimensions: 3
|
|
||||||
tqdm_trigger_threshold: 16
|
|
||||||
n_train_target_views: 1
|
|
||||||
sampling_mode_training: mask_sample
|
|
||||||
sampling_mode_evaluation: full_grid
|
|
||||||
global_encoder_class_type: null
|
|
||||||
raysampler_class_type: AdaptiveRaySampler
|
|
||||||
renderer_class_type: MultiPassEmissionAbsorptionRenderer
|
|
||||||
image_feature_extractor_class_type: null
|
|
||||||
view_pooler_enabled: false
|
|
||||||
implicit_function_class_type: NeuralRadianceFieldImplicitFunction
|
|
||||||
view_metrics_class_type: ViewMetrics
|
|
||||||
regularization_metrics_class_type: RegularizationMetrics
|
|
||||||
loss_weights:
|
|
||||||
loss_rgb_mse: 1.0
|
|
||||||
loss_prev_stage_rgb_mse: 1.0
|
|
||||||
loss_mask_bce: 0.0
|
|
||||||
loss_prev_stage_mask_bce: 0.0
|
|
||||||
global_encoder_HarmonicTimeEncoder_args:
|
|
||||||
n_harmonic_functions: 10
|
|
||||||
append_input: true
|
|
||||||
time_divisor: 1.0
|
|
||||||
global_encoder_SequenceAutodecoder_args:
|
|
||||||
autodecoder_args:
|
|
||||||
encoding_dim: 0
|
|
||||||
n_instances: 1
|
|
||||||
init_scale: 1.0
|
|
||||||
ignore_input: false
|
|
||||||
raysampler_AdaptiveRaySampler_args:
|
|
||||||
n_pts_per_ray_training: 64
|
|
||||||
n_pts_per_ray_evaluation: 64
|
|
||||||
n_rays_per_image_sampled_from_mask: 1024
|
|
||||||
stratified_point_sampling_training: true
|
|
||||||
stratified_point_sampling_evaluation: false
|
|
||||||
scene_extent: 8.0
|
|
||||||
scene_center:
|
|
||||||
- 0.0
|
|
||||||
- 0.0
|
|
||||||
- 0.0
|
|
||||||
raysampler_NearFarRaySampler_args:
|
|
||||||
n_pts_per_ray_training: 64
|
|
||||||
n_pts_per_ray_evaluation: 64
|
|
||||||
n_rays_per_image_sampled_from_mask: 1024
|
|
||||||
stratified_point_sampling_training: true
|
|
||||||
stratified_point_sampling_evaluation: false
|
|
||||||
min_depth: 0.1
|
|
||||||
max_depth: 8.0
|
|
||||||
renderer_LSTMRenderer_args:
|
|
||||||
num_raymarch_steps: 10
|
|
||||||
init_depth: 17.0
|
|
||||||
init_depth_noise_std: 0.0005
|
|
||||||
hidden_size: 16
|
|
||||||
n_feature_channels: 256
|
|
||||||
bg_color: null
|
|
||||||
verbose: false
|
|
||||||
renderer_MultiPassEmissionAbsorptionRenderer_args:
|
|
||||||
raymarcher_class_type: EmissionAbsorptionRaymarcher
|
|
||||||
n_pts_per_ray_fine_training: 64
|
|
||||||
n_pts_per_ray_fine_evaluation: 64
|
|
||||||
stratified_sampling_coarse_training: true
|
|
||||||
stratified_sampling_coarse_evaluation: false
|
|
||||||
append_coarse_samples_to_fine: true
|
|
||||||
density_noise_std_train: 0.0
|
|
||||||
return_weights: false
|
|
||||||
raymarcher_CumsumRaymarcher_args:
|
|
||||||
surface_thickness: 1
|
|
||||||
bg_color:
|
|
||||||
- 0.0
|
|
||||||
background_opacity: 0.0
|
|
||||||
density_relu: true
|
|
||||||
blend_output: false
|
|
||||||
raymarcher_EmissionAbsorptionRaymarcher_args:
|
|
||||||
surface_thickness: 1
|
|
||||||
bg_color:
|
|
||||||
- 0.0
|
|
||||||
background_opacity: 10000000000.0
|
|
||||||
density_relu: true
|
|
||||||
blend_output: false
|
|
||||||
renderer_SignedDistanceFunctionRenderer_args:
|
|
||||||
ray_normal_coloring_network_args:
|
|
||||||
feature_vector_size: 3
|
|
||||||
mode: idr
|
|
||||||
d_in: 9
|
|
||||||
d_out: 3
|
|
||||||
dims:
|
|
||||||
- 512
|
|
||||||
- 512
|
|
||||||
- 512
|
|
||||||
- 512
|
|
||||||
weight_norm: true
|
|
||||||
n_harmonic_functions_dir: 0
|
|
||||||
pooled_feature_dim: 0
|
|
||||||
bg_color:
|
|
||||||
- 0.0
|
|
||||||
soft_mask_alpha: 50.0
|
|
||||||
ray_tracer_args:
|
|
||||||
sdf_threshold: 5.0e-05
|
|
||||||
line_search_step: 0.5
|
|
||||||
line_step_iters: 1
|
|
||||||
sphere_tracing_iters: 10
|
|
||||||
n_steps: 100
|
|
||||||
n_secant_steps: 8
|
|
||||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
|
||||||
name: resnet34
|
|
||||||
pretrained: true
|
|
||||||
stages:
|
|
||||||
- 1
|
|
||||||
- 2
|
|
||||||
- 3
|
|
||||||
- 4
|
|
||||||
normalize_image: true
|
|
||||||
image_rescale: 0.16
|
|
||||||
first_max_pool: true
|
|
||||||
proj_dim: 32
|
|
||||||
l2_norm: true
|
|
||||||
add_masks: true
|
|
||||||
add_images: true
|
|
||||||
global_average_pool: false
|
|
||||||
feature_rescale: 1.0
|
|
||||||
view_pooler_args:
|
|
||||||
feature_aggregator_class_type: AngleWeightedReductionFeatureAggregator
|
|
||||||
view_sampler_args:
|
|
||||||
masked_sampling: false
|
|
||||||
sampling_mode: bilinear
|
|
||||||
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
|
|
||||||
exclude_target_view: true
|
|
||||||
exclude_target_view_mask_features: true
|
|
||||||
concatenate_output: true
|
|
||||||
weight_by_ray_angle_gamma: 1.0
|
|
||||||
min_ray_angle_weight: 0.1
|
|
||||||
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
|
|
||||||
exclude_target_view: true
|
|
||||||
exclude_target_view_mask_features: true
|
|
||||||
concatenate_output: true
|
|
||||||
reduction_functions:
|
|
||||||
- AVG
|
|
||||||
- STD
|
|
||||||
weight_by_ray_angle_gamma: 1.0
|
|
||||||
min_ray_angle_weight: 0.1
|
|
||||||
feature_aggregator_IdentityFeatureAggregator_args:
|
|
||||||
exclude_target_view: true
|
|
||||||
exclude_target_view_mask_features: true
|
|
||||||
concatenate_output: true
|
|
||||||
feature_aggregator_ReductionFeatureAggregator_args:
|
|
||||||
exclude_target_view: true
|
|
||||||
exclude_target_view_mask_features: true
|
|
||||||
concatenate_output: true
|
|
||||||
reduction_functions:
|
|
||||||
- AVG
|
|
||||||
- STD
|
|
||||||
implicit_function_IdrFeatureField_args:
|
|
||||||
d_in: 3
|
|
||||||
d_out: 1
|
|
||||||
dims:
|
|
||||||
- 512
|
|
||||||
- 512
|
|
||||||
- 512
|
|
||||||
- 512
|
|
||||||
- 512
|
|
||||||
- 512
|
|
||||||
- 512
|
|
||||||
- 512
|
|
||||||
geometric_init: true
|
|
||||||
bias: 1.0
|
|
||||||
skip_in: []
|
|
||||||
weight_norm: true
|
|
||||||
n_harmonic_functions_xyz: 0
|
|
||||||
pooled_feature_dim: 0
|
|
||||||
implicit_function_NeRFormerImplicitFunction_args:
|
|
||||||
n_harmonic_functions_xyz: 10
|
|
||||||
n_harmonic_functions_dir: 4
|
|
||||||
n_hidden_neurons_dir: 128
|
|
||||||
input_xyz: true
|
|
||||||
xyz_ray_dir_in_camera_coords: false
|
|
||||||
transformer_dim_down_factor: 2.0
|
|
||||||
n_hidden_neurons_xyz: 80
|
|
||||||
n_layers_xyz: 2
|
|
||||||
append_xyz:
|
|
||||||
- 1
|
|
||||||
implicit_function_NeuralRadianceFieldImplicitFunction_args:
|
|
||||||
n_harmonic_functions_xyz: 10
|
|
||||||
n_harmonic_functions_dir: 4
|
|
||||||
n_hidden_neurons_dir: 128
|
|
||||||
input_xyz: true
|
|
||||||
xyz_ray_dir_in_camera_coords: false
|
|
||||||
transformer_dim_down_factor: 1.0
|
|
||||||
n_hidden_neurons_xyz: 256
|
|
||||||
n_layers_xyz: 8
|
|
||||||
append_xyz:
|
|
||||||
- 5
|
|
||||||
implicit_function_SRNHyperNetImplicitFunction_args:
|
|
||||||
hypernet_args:
|
|
||||||
n_harmonic_functions: 3
|
|
||||||
n_hidden_units: 256
|
|
||||||
n_layers: 2
|
|
||||||
n_hidden_units_hypernet: 256
|
|
||||||
n_layers_hypernet: 1
|
|
||||||
in_features: 3
|
|
||||||
out_features: 256
|
|
||||||
xyz_in_camera_coords: false
|
|
||||||
pixel_generator_args:
|
|
||||||
n_harmonic_functions: 4
|
|
||||||
n_hidden_units: 256
|
|
||||||
n_hidden_units_color: 128
|
|
||||||
n_layers: 2
|
|
||||||
in_features: 256
|
|
||||||
out_features: 3
|
|
||||||
ray_dir_in_camera_coords: false
|
|
||||||
implicit_function_SRNImplicitFunction_args:
|
|
||||||
raymarch_function_args:
|
|
||||||
n_harmonic_functions: 3
|
|
||||||
n_hidden_units: 256
|
|
||||||
n_layers: 2
|
|
||||||
in_features: 3
|
|
||||||
out_features: 256
|
|
||||||
xyz_in_camera_coords: false
|
|
||||||
raymarch_function: null
|
|
||||||
pixel_generator_args:
|
|
||||||
n_harmonic_functions: 4
|
|
||||||
n_hidden_units: 256
|
|
||||||
n_hidden_units_color: 128
|
|
||||||
n_layers: 2
|
|
||||||
in_features: 256
|
|
||||||
out_features: 3
|
|
||||||
ray_dir_in_camera_coords: false
|
|
||||||
view_metrics_ViewMetrics_args: {}
|
|
||||||
regularization_metrics_RegularizationMetrics_args: {}
|
|
||||||
optimizer_factory_ImplicitronOptimizerFactory_args:
|
|
||||||
betas:
|
|
||||||
- 0.9
|
|
||||||
- 0.999
|
|
||||||
breed: Adam
|
|
||||||
exponential_lr_step_size: 250
|
|
||||||
gamma: 0.1
|
|
||||||
lr: 0.0005
|
|
||||||
lr_policy: MultiStepLR
|
|
||||||
momentum: 0.9
|
|
||||||
multistep_lr_milestones: []
|
|
||||||
weight_decay: 0.0
|
|
||||||
linear_exponential_lr_milestone: 200
|
|
||||||
linear_exponential_start_gamma: 0.1
|
|
||||||
training_loop_ImplicitronTrainingLoop_args:
|
|
||||||
eval_only: false
|
|
||||||
evaluator_class_type: ImplicitronEvaluator
|
|
||||||
max_epochs: 1000
|
|
||||||
store_checkpoints: true
|
|
||||||
store_checkpoints_purge: 1
|
|
||||||
test_interval: -1
|
|
||||||
test_when_finished: false
|
|
||||||
validation_interval: 1
|
|
||||||
clip_grad: 0.0
|
|
||||||
metric_print_interval: 5
|
|
||||||
visualize_interval: 1000
|
|
||||||
visdom_env: ''
|
|
||||||
visdom_port: 8097
|
|
||||||
visdom_server: http://127.0.0.1
|
|
||||||
evaluator_ImplicitronEvaluator_args:
|
|
||||||
camera_difficulty_bin_breaks:
|
|
||||||
- 0.97
|
|
||||||
- 0.98
|
|
||||||
is_multisequence: false
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
|
||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -13,7 +12,6 @@ from hydra import compose, initialize_config_dir
|
|||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
from .. import experiment
|
from .. import experiment
|
||||||
from .utils import intercept_logs
|
|
||||||
|
|
||||||
|
|
||||||
def interactive_testing_requested() -> bool:
|
def interactive_testing_requested() -> bool:
|
||||||
@@ -35,10 +33,7 @@ DEBUG: bool = False
|
|||||||
# TODO:
|
# TODO:
|
||||||
# - add enough files to skateboard_first_5 that this works on RE.
|
# - add enough files to skateboard_first_5 that this works on RE.
|
||||||
# - share common code with PyTorch3D tests?
|
# - share common code with PyTorch3D tests?
|
||||||
|
# - deal with the temporary output files this test creates
|
||||||
|
|
||||||
def _parse_float_from_log(line):
|
|
||||||
return float(line.split()[-1])
|
|
||||||
|
|
||||||
|
|
||||||
class TestExperiment(unittest.TestCase):
|
class TestExperiment(unittest.TestCase):
|
||||||
@@ -49,18 +44,15 @@ class TestExperiment(unittest.TestCase):
|
|||||||
# Test making minimal changes to the dataclass defaults.
|
# Test making minimal changes to the dataclass defaults.
|
||||||
if not interactive_testing_requested() or not internal:
|
if not interactive_testing_requested() or not internal:
|
||||||
return
|
return
|
||||||
|
cfg = OmegaConf.structured(experiment.ExperimentConfig)
|
||||||
# Manually override config values. Note that this is not necessary out-
|
cfg.data_source_args.dataset_map_provider_class_type = (
|
||||||
# side of the tests!
|
|
||||||
cfg = OmegaConf.structured(experiment.Experiment)
|
|
||||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
|
||||||
"JsonIndexDatasetMapProvider"
|
"JsonIndexDatasetMapProvider"
|
||||||
)
|
)
|
||||||
dataset_args = (
|
dataset_args = (
|
||||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
cfg.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||||
)
|
)
|
||||||
dataloader_args = (
|
dataloader_args = (
|
||||||
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
cfg.data_source_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
||||||
)
|
)
|
||||||
dataset_args.category = "skateboard"
|
dataset_args.category = "skateboard"
|
||||||
dataset_args.test_restrict_sequence_id = 0
|
dataset_args.test_restrict_sequence_id = 0
|
||||||
@@ -70,80 +62,18 @@ class TestExperiment(unittest.TestCase):
|
|||||||
dataset_args.dataset_JsonIndexDataset_args.image_width = 80
|
dataset_args.dataset_JsonIndexDataset_args.image_width = 80
|
||||||
dataloader_args.dataset_length_train = 1
|
dataloader_args.dataset_length_train = 1
|
||||||
dataloader_args.dataset_length_val = 1
|
dataloader_args.dataset_length_val = 1
|
||||||
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
|
cfg.solver_args.max_epochs = 2
|
||||||
cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
|
|
||||||
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.multistep_lr_milestones = [
|
|
||||||
0,
|
|
||||||
1,
|
|
||||||
]
|
|
||||||
|
|
||||||
if DEBUG:
|
experiment.run_training(cfg)
|
||||||
experiment.dump_cfg(cfg)
|
|
||||||
with intercept_logs(
|
|
||||||
logger_name="projects.implicitron_trainer.impl.training_loop",
|
|
||||||
regexp="LR change!",
|
|
||||||
) as intercepted_logs:
|
|
||||||
experiment_runner = experiment.Experiment(**cfg)
|
|
||||||
experiment_runner.run()
|
|
||||||
|
|
||||||
# Make sure LR decreased on 0th and 1st epoch 10fold.
|
|
||||||
self.assertEqual(intercepted_logs[0].split()[-1], "5e-06")
|
|
||||||
|
|
||||||
def test_exponential_lr(self):
|
|
||||||
# Test making minimal changes to the dataclass defaults.
|
|
||||||
if not interactive_testing_requested():
|
|
||||||
return
|
|
||||||
cfg = OmegaConf.structured(experiment.Experiment)
|
|
||||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
|
|
||||||
"JsonIndexDatasetMapProvider"
|
|
||||||
)
|
|
||||||
dataset_args = (
|
|
||||||
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
|
||||||
)
|
|
||||||
dataloader_args = (
|
|
||||||
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
|
||||||
)
|
|
||||||
dataset_args.category = "skateboard"
|
|
||||||
dataset_args.test_restrict_sequence_id = 0
|
|
||||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
|
||||||
dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5
|
|
||||||
dataset_args.dataset_JsonIndexDataset_args.image_height = 80
|
|
||||||
dataset_args.dataset_JsonIndexDataset_args.image_width = 80
|
|
||||||
dataloader_args.dataset_length_train = 1
|
|
||||||
dataloader_args.dataset_length_val = 1
|
|
||||||
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
|
|
||||||
cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
|
|
||||||
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.lr_policy = "Exponential"
|
|
||||||
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = (
|
|
||||||
2
|
|
||||||
)
|
|
||||||
|
|
||||||
if DEBUG:
|
|
||||||
experiment.dump_cfg(cfg)
|
|
||||||
with intercept_logs(
|
|
||||||
logger_name="projects.implicitron_trainer.impl.training_loop",
|
|
||||||
regexp="LR change!",
|
|
||||||
) as intercepted_logs:
|
|
||||||
experiment_runner = experiment.Experiment(**cfg)
|
|
||||||
experiment_runner.run()
|
|
||||||
|
|
||||||
# Make sure we followed the exponential lr schedule with gamma=0.1,
|
|
||||||
# exponential_lr_step_size=2 -- so after two epochs, should
|
|
||||||
# decrease lr 10x to 5e-5.
|
|
||||||
self.assertEqual(intercepted_logs[0].split()[-1], "0.00015811388300841897")
|
|
||||||
self.assertEqual(intercepted_logs[1].split()[-1], "5e-05")
|
|
||||||
|
|
||||||
def test_yaml_contents(self):
|
def test_yaml_contents(self):
|
||||||
# Check that the default config values, defined by Experiment and its
|
cfg = OmegaConf.structured(experiment.ExperimentConfig)
|
||||||
# members, is what we expect it to be.
|
|
||||||
cfg = OmegaConf.structured(experiment.Experiment)
|
|
||||||
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
|
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
(DATA_DIR / "experiment.yaml").write_text(yaml)
|
(DATA_DIR / "experiment.yaml").write_text(yaml)
|
||||||
self.assertEqual(yaml, (DATA_DIR / "experiment.yaml").read_text())
|
self.assertEqual(yaml, (DATA_DIR / "experiment.yaml").read_text())
|
||||||
|
|
||||||
def test_load_configs(self):
|
def test_load_configs(self):
|
||||||
# Check that all the pre-prepared configs are valid.
|
|
||||||
config_files = []
|
config_files = []
|
||||||
|
|
||||||
for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"):
|
for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"):
|
||||||
@@ -159,78 +89,3 @@ class TestExperiment(unittest.TestCase):
|
|||||||
with self.subTest(file.name):
|
with self.subTest(file.name):
|
||||||
with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
|
with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
|
||||||
compose(file.name)
|
compose(file.name)
|
||||||
|
|
||||||
|
|
||||||
class TestNerfRepro(unittest.TestCase):
|
|
||||||
@unittest.skip("This test runs full blender training.")
|
|
||||||
def test_nerf_blender(self):
|
|
||||||
# Train vanilla NERF.
|
|
||||||
# Set env vars BLENDER_DATASET_ROOT and BLENDER_SINGLESEQ_CLASS first!
|
|
||||||
if not interactive_testing_requested():
|
|
||||||
return
|
|
||||||
with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
|
|
||||||
cfg = compose(config_name="repro_singleseq_nerf_blender", overrides=[])
|
|
||||||
experiment_runner = experiment.Experiment(**cfg)
|
|
||||||
experiment.dump_cfg(cfg)
|
|
||||||
experiment_runner.run()
|
|
||||||
|
|
||||||
@unittest.skip("This test runs full llff training.")
|
|
||||||
def test_nerf_llff(self):
|
|
||||||
# Train vanilla NERF.
|
|
||||||
# Set env vars LLFF_DATASET_ROOT and LLFF_SINGLESEQ_CLASS first!
|
|
||||||
LLFF_SINGLESEQ_CLASS = os.environ["LLFF_SINGLESEQ_CLASS"]
|
|
||||||
if not interactive_testing_requested():
|
|
||||||
return
|
|
||||||
with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
|
|
||||||
cfg = compose(
|
|
||||||
config_name=f"repro_singleseq_nerf_llff_{LLFF_SINGLESEQ_CLASS}",
|
|
||||||
overrides=[],
|
|
||||||
)
|
|
||||||
experiment_runner = experiment.Experiment(**cfg)
|
|
||||||
experiment.dump_cfg(cfg)
|
|
||||||
experiment_runner.run()
|
|
||||||
|
|
||||||
@unittest.skip("This test checks resuming of the NeRF training.")
|
|
||||||
def test_nerf_blender_resume(self):
|
|
||||||
# Train one train batch of NeRF, then resume for one more batch.
|
|
||||||
# Set env vars BLENDER_DATASET_ROOT and BLENDER_SINGLESEQ_CLASS first!
|
|
||||||
if not interactive_testing_requested():
|
|
||||||
return
|
|
||||||
with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
|
|
||||||
with tempfile.TemporaryDirectory() as exp_dir:
|
|
||||||
cfg = compose(config_name="repro_singleseq_nerf_blender", overrides=[])
|
|
||||||
cfg.exp_dir = exp_dir
|
|
||||||
|
|
||||||
# set dataset len to 1
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
(
|
|
||||||
cfg
|
|
||||||
.data_source_ImplicitronDataSource_args
|
|
||||||
.data_loader_map_provider_SequenceDataLoaderMapProvider_args
|
|
||||||
.dataset_length_train
|
|
||||||
) = 1
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
# run for one epoch
|
|
||||||
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 1
|
|
||||||
experiment_runner = experiment.Experiment(**cfg)
|
|
||||||
experiment.dump_cfg(cfg)
|
|
||||||
experiment_runner.run()
|
|
||||||
|
|
||||||
# update num epochs + 2, let the optimizer resume
|
|
||||||
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 3
|
|
||||||
experiment_runner = experiment.Experiment(**cfg)
|
|
||||||
experiment_runner.run()
|
|
||||||
|
|
||||||
# start from scratch
|
|
||||||
cfg.model_factory_ImplicitronModelFactory_args.resume = False
|
|
||||||
experiment_runner = experiment.Experiment(**cfg)
|
|
||||||
experiment_runner.run()
|
|
||||||
|
|
||||||
# force resume from epoch 1
|
|
||||||
cfg.model_factory_ImplicitronModelFactory_args.resume = True
|
|
||||||
cfg.model_factory_ImplicitronModelFactory_args.force_resume = True
|
|
||||||
cfg.model_factory_ImplicitronModelFactory_args.resume_epoch = 1
|
|
||||||
experiment_runner = experiment.Experiment(**cfg)
|
|
||||||
experiment_runner.run()
|
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the BSD-style license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import contextlib
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def intercept_logs(logger_name: str, regexp: str):
|
|
||||||
# Intercept logs that match a regexp, from a given logger.
|
|
||||||
intercepted_messages = []
|
|
||||||
logger = logging.getLogger(logger_name)
|
|
||||||
|
|
||||||
class LoggerInterceptor(logging.Filter):
|
|
||||||
def filter(self, record):
|
|
||||||
message = record.getMessage()
|
|
||||||
if re.search(regexp, message):
|
|
||||||
intercepted_messages.append(message)
|
|
||||||
return True
|
|
||||||
|
|
||||||
interceptor = LoggerInterceptor()
|
|
||||||
logger.addFilter(interceptor)
|
|
||||||
try:
|
|
||||||
yield intercepted_messages
|
|
||||||
finally:
|
|
||||||
logger.removeFilter(interceptor)
|
|
||||||
@@ -22,6 +22,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as Fu
|
import torch.nn.functional as Fu
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
||||||
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
||||||
from pytorch3d.implicitron.models.base_model import EvaluationMode
|
from pytorch3d.implicitron.models.base_model import EvaluationMode
|
||||||
@@ -36,7 +37,7 @@ from pytorch3d.implicitron.tools.vis_utils import (
|
|||||||
)
|
)
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from .experiment import Experiment
|
from .experiment import init_model
|
||||||
|
|
||||||
|
|
||||||
def render_sequence(
|
def render_sequence(
|
||||||
@@ -343,14 +344,13 @@ def export_scenes(
|
|||||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
|
||||||
|
|
||||||
# Load the previously trained model
|
# Load the previously trained model
|
||||||
experiment = Experiment(config)
|
model, _, _ = init_model(cfg=config, force_load=True, load_model_only=True)
|
||||||
model = experiment.model_factory(force_resume=True)
|
|
||||||
model.cuda()
|
model.cuda()
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# Setup the dataset
|
# Setup the dataset
|
||||||
data_source = experiment.data_source
|
datasource = ImplicitronDataSource(**config.data_source_args)
|
||||||
dataset_map, _ = data_source.get_datasets_and_dataloaders()
|
dataset_map = datasource.dataset_map_provider.get_dataset_map()
|
||||||
dataset = dataset_map[split]
|
dataset = dataset_map[split]
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
raise ValueError(f"{split} dataset not provided")
|
raise ValueError(f"{split} dataset not provided")
|
||||||
|
|||||||
@@ -4,4 +4,4 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
__version__ = "0.7.0"
|
__version__ = "0.6.2"
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import torch
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Some functions which depend on PyTorch or Python versions.
|
Some functions which depend on PyTorch versions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@@ -79,12 +79,3 @@ def meshgrid_ij(
|
|||||||
# pyre-fixme[6]: For 1st param expected `Union[List[Tensor], Tensor]` but got
|
# pyre-fixme[6]: For 1st param expected `Union[List[Tensor], Tensor]` but got
|
||||||
# `Union[Sequence[Tensor], Tensor]`.
|
# `Union[Sequence[Tensor], Tensor]`.
|
||||||
return torch.meshgrid(*A)
|
return torch.meshgrid(*A)
|
||||||
|
|
||||||
|
|
||||||
def prod(iterable, *, start=1):
|
|
||||||
"""
|
|
||||||
Like math.prod in Python 3.8 and later.
|
|
||||||
"""
|
|
||||||
for i in iterable:
|
|
||||||
start *= i
|
|
||||||
return start
|
|
||||||
|
|||||||
@@ -61,84 +61,6 @@ class DataLoaderMapProviderBase(ReplaceableBase):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class SimpleDataLoaderMapProvider(DataLoaderMapProviderBase):
|
|
||||||
"""
|
|
||||||
Trivial implementation of DataLoaderMapProviderBase.
|
|
||||||
|
|
||||||
If a dataset returns batches from get_eval_batches(), then
|
|
||||||
they will be what the corresponding dataloader returns,
|
|
||||||
independently of any of the fields on this class.
|
|
||||||
|
|
||||||
Otherwise, returns shuffled batches.
|
|
||||||
"""
|
|
||||||
|
|
||||||
batch_size: int = 1
|
|
||||||
num_workers: int = 0
|
|
||||||
dataset_length_train: int = 0
|
|
||||||
dataset_length_val: int = 0
|
|
||||||
dataset_length_test: int = 0
|
|
||||||
|
|
||||||
def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
|
|
||||||
"""
|
|
||||||
Returns a collection of data loaders for a given collection of datasets.
|
|
||||||
"""
|
|
||||||
return DataLoaderMap(
|
|
||||||
train=self._make_data_loader(
|
|
||||||
datasets.train,
|
|
||||||
self.dataset_length_train,
|
|
||||||
),
|
|
||||||
val=self._make_data_loader(
|
|
||||||
datasets.val,
|
|
||||||
self.dataset_length_val,
|
|
||||||
),
|
|
||||||
test=self._make_data_loader(
|
|
||||||
datasets.test,
|
|
||||||
self.dataset_length_test,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _make_data_loader(
|
|
||||||
self,
|
|
||||||
dataset: Optional[DatasetBase],
|
|
||||||
num_batches: int,
|
|
||||||
) -> Optional[DataLoader[FrameData]]:
|
|
||||||
"""
|
|
||||||
Returns the dataloader for a dataset.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
dataset: the dataset
|
|
||||||
num_batches: possible ceiling on number of batches per epoch
|
|
||||||
"""
|
|
||||||
if dataset is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
data_loader_kwargs = {
|
|
||||||
"num_workers": self.num_workers,
|
|
||||||
"collate_fn": dataset.frame_data_type.collate,
|
|
||||||
}
|
|
||||||
|
|
||||||
eval_batches = dataset.get_eval_batches()
|
|
||||||
if eval_batches is not None:
|
|
||||||
return DataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_sampler=eval_batches,
|
|
||||||
**data_loader_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if num_batches > 0:
|
|
||||||
num_samples = self.batch_size * num_batches
|
|
||||||
else:
|
|
||||||
num_samples = None
|
|
||||||
sampler = RandomSampler(dataset, replacement=True, num_samples=num_samples)
|
|
||||||
batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
|
|
||||||
return DataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_sampler=batch_sampler,
|
|
||||||
**data_loader_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DoublePoolBatchSampler(Sampler[List[int]]):
|
class DoublePoolBatchSampler(Sampler[List[int]]):
|
||||||
"""
|
"""
|
||||||
Batch sampler for making random batches of a single frame
|
Batch sampler for making random batches of a single frame
|
||||||
@@ -199,7 +121,7 @@ class DoublePoolBatchSampler(Sampler[List[int]]):
|
|||||||
torch.randperm(len(self.first_indices), generator=self.generator)
|
torch.randperm(len(self.first_indices), generator=self.generator)
|
||||||
for _ in range(n_copies)
|
for _ in range(n_copies)
|
||||||
]
|
]
|
||||||
i_first = torch.cat(raw_indices)[:num_batches]
|
i_first = torch.concat(raw_indices)[:num_batches]
|
||||||
else:
|
else:
|
||||||
i_first = torch.randperm(len(self.first_indices), generator=self.generator)
|
i_first = torch.randperm(len(self.first_indices), generator=self.generator)
|
||||||
first_indices = [self.first_indices[i] for i in i_first]
|
first_indices = [self.first_indices[i] for i in i_first]
|
||||||
|
|||||||
@@ -15,11 +15,10 @@ from pytorch3d.renderer.cameras import CamerasBase
|
|||||||
|
|
||||||
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
|
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
|
||||||
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
||||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase
|
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
|
||||||
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
|
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
|
||||||
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
|
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
|
||||||
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
|
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
|
||||||
from .rendered_mesh_dataset_map_provider import RenderedMeshDatasetMapProvider # noqa
|
|
||||||
|
|
||||||
|
|
||||||
class DataSourceBase(ReplaceableBase):
|
class DataSourceBase(ReplaceableBase):
|
||||||
@@ -68,6 +67,9 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]
|
|||||||
dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
|
dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
|
||||||
return datasets, dataloaders
|
return datasets, dataloaders
|
||||||
|
|
||||||
|
def get_task(self) -> Task:
|
||||||
|
return self.dataset_map_provider.get_task()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def all_train_cameras(self) -> Optional[CamerasBase]:
|
def all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
if self._all_train_cameras_cache is None: # pyre-ignore[16]
|
if self._all_train_cameras_cache is None: # pyre-ignore[16]
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
from typing import Iterator, Optional
|
from typing import Iterator, Optional
|
||||||
|
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
@@ -52,6 +53,11 @@ class DatasetMap:
|
|||||||
yield self.test
|
yield self.test
|
||||||
|
|
||||||
|
|
||||||
|
class Task(Enum):
|
||||||
|
SINGLE_SEQUENCE = "singlesequence"
|
||||||
|
MULTI_SEQUENCE = "multisequence"
|
||||||
|
|
||||||
|
|
||||||
class DatasetMapProviderBase(ReplaceableBase):
|
class DatasetMapProviderBase(ReplaceableBase):
|
||||||
"""
|
"""
|
||||||
Base class for a provider of training / validation and testing
|
Base class for a provider of training / validation and testing
|
||||||
@@ -65,6 +71,9 @@ class DatasetMapProviderBase(ReplaceableBase):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def get_task(self) -> Task:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
"""
|
"""
|
||||||
If the data is all for a single scene, returns a list
|
If the data is all for a single scene, returns a list
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from typing import (
|
|||||||
Sequence,
|
Sequence,
|
||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TYPE_CHECKING,
|
TypedDict,
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -45,16 +45,10 @@ from .utils import is_known_frame_scalar
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
class FrameAnnotsEntry(TypedDict):
|
||||||
from typing import TypedDict
|
|
||||||
|
|
||||||
class FrameAnnotsEntry(TypedDict):
|
|
||||||
subset: Optional[str]
|
subset: Optional[str]
|
||||||
frame_annotation: types.FrameAnnotation
|
frame_annotation: types.FrameAnnotation
|
||||||
|
|
||||||
else:
|
|
||||||
FrameAnnotsEntry = dict
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||||
@@ -118,11 +112,6 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
eval_batches: A list of batches that form the evaluation set;
|
eval_batches: A list of batches that form the evaluation set;
|
||||||
list of batch-sized lists of indices corresponding to __getitem__
|
list of batch-sized lists of indices corresponding to __getitem__
|
||||||
of this class, thus it can be used directly as a batch sampler.
|
of this class, thus it can be used directly as a batch sampler.
|
||||||
eval_batch_index:
|
|
||||||
( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] )
|
|
||||||
A list of batches of frames described as (sequence_name, frame_idx)
|
|
||||||
that can form the evaluation set, `eval_batches` will be set from this.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
frame_annotations_type: ClassVar[
|
frame_annotations_type: ClassVar[
|
||||||
@@ -158,7 +147,6 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
seed: int = 0
|
seed: int = 0
|
||||||
sort_frames: bool = False
|
sort_frames: bool = False
|
||||||
eval_batches: Any = None
|
eval_batches: Any = None
|
||||||
eval_batch_index: Any = None
|
|
||||||
# frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
# frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
||||||
# seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
# seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
||||||
|
|
||||||
@@ -171,22 +159,8 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
self._sort_frames()
|
self._sort_frames()
|
||||||
self._load_subset_lists()
|
self._load_subset_lists()
|
||||||
self._filter_db() # also computes sequence indices
|
self._filter_db() # also computes sequence indices
|
||||||
self._extract_and_set_eval_batches()
|
|
||||||
logger.info(str(self))
|
logger.info(str(self))
|
||||||
|
|
||||||
def _extract_and_set_eval_batches(self):
|
|
||||||
"""
|
|
||||||
Sets eval_batches based on input eval_batch_index.
|
|
||||||
"""
|
|
||||||
if self.eval_batch_index is not None:
|
|
||||||
if self.eval_batches is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"Cannot define both eval_batch_index and eval_batches."
|
|
||||||
)
|
|
||||||
self.eval_batches = self.seq_frame_index_to_dataset_index(
|
|
||||||
self.eval_batch_index
|
|
||||||
)
|
|
||||||
|
|
||||||
def is_filtered(self):
|
def is_filtered(self):
|
||||||
"""
|
"""
|
||||||
Returns `True` in case the dataset has been filtered and thus some frame annotations
|
Returns `True` in case the dataset has been filtered and thus some frame annotations
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
from typing import Dict, List, Optional, Tuple, Type
|
from typing import Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig, open_dict
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
expand_args_fields,
|
expand_args_fields,
|
||||||
registry,
|
registry,
|
||||||
@@ -17,7 +17,12 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
)
|
)
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
|
from .dataset_map_provider import (
|
||||||
|
DatasetMap,
|
||||||
|
DatasetMapProviderBase,
|
||||||
|
PathManagerFactory,
|
||||||
|
Task,
|
||||||
|
)
|
||||||
from .json_index_dataset import JsonIndexDataset
|
from .json_index_dataset import JsonIndexDataset
|
||||||
|
|
||||||
from .utils import (
|
from .utils import (
|
||||||
@@ -52,7 +57,6 @@ _CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
|
|||||||
_NEED_CONTROL: Tuple[str, ...] = (
|
_NEED_CONTROL: Tuple[str, ...] = (
|
||||||
"dataset_root",
|
"dataset_root",
|
||||||
"eval_batches",
|
"eval_batches",
|
||||||
"eval_batch_index",
|
|
||||||
"n_frames_per_sequence",
|
"n_frames_per_sequence",
|
||||||
"path_manager",
|
"path_manager",
|
||||||
"pick_sequence",
|
"pick_sequence",
|
||||||
@@ -113,6 +117,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
Called by get_default_args(JsonIndexDatasetMapProvider) to
|
Called by get_default_args(JsonIndexDatasetMapProvider) to
|
||||||
not expose certain fields of each dataset class.
|
not expose certain fields of each dataset class.
|
||||||
"""
|
"""
|
||||||
|
with open_dict(args):
|
||||||
for key in _NEED_CONTROL:
|
for key in _NEED_CONTROL:
|
||||||
del args[key]
|
del args[key]
|
||||||
|
|
||||||
@@ -154,7 +159,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
# This maps the common names of the dataset subsets ("train"/"val"/"test")
|
# This maps the common names of the dataset subsets ("train"/"val"/"test")
|
||||||
# to the names of the subsets in the CO3D dataset.
|
# to the names of the subsets in the CO3D dataset.
|
||||||
set_names_mapping = _get_co3d_set_names_mapping(
|
set_names_mapping = _get_co3d_set_names_mapping(
|
||||||
self.task_str,
|
self.get_task(),
|
||||||
self.test_on_train,
|
self.test_on_train,
|
||||||
self.only_test_set,
|
self.only_test_set,
|
||||||
)
|
)
|
||||||
@@ -179,7 +184,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
eval_batch_index = json.load(f)
|
eval_batch_index = json.load(f)
|
||||||
restrict_sequence_name = self.restrict_sequence_name
|
restrict_sequence_name = self.restrict_sequence_name
|
||||||
|
|
||||||
if self.task_str == "singlesequence":
|
if self.get_task() == Task.SINGLE_SEQUENCE:
|
||||||
if (
|
if (
|
||||||
self.test_restrict_sequence_id is None
|
self.test_restrict_sequence_id is None
|
||||||
or self.test_restrict_sequence_id < 0
|
or self.test_restrict_sequence_id < 0
|
||||||
@@ -207,10 +212,6 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
]
|
]
|
||||||
# overwrite the restrict_sequence_name
|
# overwrite the restrict_sequence_name
|
||||||
restrict_sequence_name = [eval_sequence_name]
|
restrict_sequence_name = [eval_sequence_name]
|
||||||
if len(restrict_sequence_name) > 0:
|
|
||||||
eval_batch_index = [
|
|
||||||
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
|
|
||||||
]
|
|
||||||
|
|
||||||
dataset_type: Type[JsonIndexDataset] = registry.get(
|
dataset_type: Type[JsonIndexDataset] = registry.get(
|
||||||
JsonIndexDataset, self.dataset_class_type
|
JsonIndexDataset, self.dataset_class_type
|
||||||
@@ -238,9 +239,15 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
n_frames_per_sequence=-1,
|
n_frames_per_sequence=-1,
|
||||||
subsets=set_names_mapping["test"],
|
subsets=set_names_mapping["test"],
|
||||||
pick_sequence=restrict_sequence_name,
|
pick_sequence=restrict_sequence_name,
|
||||||
eval_batch_index=eval_batch_index,
|
|
||||||
**common_kwargs,
|
**common_kwargs,
|
||||||
)
|
)
|
||||||
|
if len(restrict_sequence_name) > 0:
|
||||||
|
eval_batch_index = [
|
||||||
|
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
|
||||||
|
]
|
||||||
|
test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
|
||||||
|
eval_batch_index
|
||||||
|
)
|
||||||
dataset_map = DatasetMap(
|
dataset_map = DatasetMap(
|
||||||
train=train_dataset, val=val_dataset, test=test_dataset
|
train=train_dataset, val=val_dataset, test=test_dataset
|
||||||
)
|
)
|
||||||
@@ -261,11 +268,12 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
return self.dataset_map
|
return self.dataset_map
|
||||||
|
|
||||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
def get_task(self) -> Task:
|
||||||
if self.task_str == "multisequence":
|
return Task(self.task_str)
|
||||||
return None
|
|
||||||
|
|
||||||
assert self.task_str == "singlesequence"
|
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
|
if Task(self.task_str) == Task.MULTI_SEQUENCE:
|
||||||
|
return None
|
||||||
|
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
train_dataset = self.dataset_map.train
|
train_dataset = self.dataset_map.train
|
||||||
@@ -274,7 +282,7 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
|
|
||||||
|
|
||||||
def _get_co3d_set_names_mapping(
|
def _get_co3d_set_names_mapping(
|
||||||
task_str: str,
|
task: Task,
|
||||||
test_on_train: bool,
|
test_on_train: bool,
|
||||||
only_test: bool,
|
only_test: bool,
|
||||||
) -> Dict[str, List[str]]:
|
) -> Dict[str, List[str]]:
|
||||||
@@ -288,7 +296,7 @@ def _get_co3d_set_names_mapping(
|
|||||||
- val (if not test_on_train)
|
- val (if not test_on_train)
|
||||||
- test (if not test_on_train)
|
- test (if not test_on_train)
|
||||||
"""
|
"""
|
||||||
single_seq = task_str == "singlesequence"
|
single_seq = task == Task.SINGLE_SEQUENCE
|
||||||
|
|
||||||
if only_test:
|
if only_test:
|
||||||
set_names_mapping = {}
|
set_names_mapping = {}
|
||||||
|
|||||||
@@ -9,13 +9,13 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, List, Optional, Tuple, Type
|
from typing import Dict, List, Optional, Type
|
||||||
|
|
||||||
from omegaconf import DictConfig
|
|
||||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||||
DatasetMap,
|
DatasetMap,
|
||||||
DatasetMapProviderBase,
|
DatasetMapProviderBase,
|
||||||
PathManagerFactory,
|
PathManagerFactory,
|
||||||
|
Task,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
@@ -29,19 +29,6 @@ from pytorch3d.renderer.cameras import CamerasBase
|
|||||||
|
|
||||||
_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
|
_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
|
||||||
|
|
||||||
# _NEED_CONTROL is a list of those elements of JsonIndexDataset which
|
|
||||||
# are not directly specified for it in the config but come from the
|
|
||||||
# DatasetMapProvider.
|
|
||||||
_NEED_CONTROL: Tuple[str, ...] = (
|
|
||||||
"dataset_root",
|
|
||||||
"eval_batches",
|
|
||||||
"eval_batch_index",
|
|
||||||
"path_manager",
|
|
||||||
"subsets",
|
|
||||||
"frame_annotations_file",
|
|
||||||
"sequence_annotations_file",
|
|
||||||
"subset_lists_file",
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -189,20 +176,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
|
|
||||||
path_manager = self.path_manager_factory.get()
|
path_manager = self.path_manager_factory.get()
|
||||||
|
|
||||||
if path_manager is not None:
|
|
||||||
path_managed_frame_file = path_manager.get_local_path(frame_file)
|
|
||||||
else:
|
|
||||||
path_managed_frame_file = frame_file
|
|
||||||
if not os.path.isfile(path_managed_frame_file):
|
|
||||||
# The frame_file does not exist.
|
|
||||||
# Most probably the user has not specified the root folder.
|
|
||||||
raise ValueError(
|
|
||||||
f"Looking for frame annotations in {path_managed_frame_file}."
|
|
||||||
+ " Please specify a correct dataset_root folder."
|
|
||||||
+ " Note: By default the root folder is taken from the"
|
|
||||||
+ " CO3DV2_DATASET_ROOT environment variable."
|
|
||||||
)
|
|
||||||
|
|
||||||
# setup the common dataset arguments
|
# setup the common dataset arguments
|
||||||
common_dataset_kwargs = getattr(self, f"dataset_{self.dataset_class_type}_args")
|
common_dataset_kwargs = getattr(self, f"dataset_{self.dataset_class_type}_args")
|
||||||
common_dataset_kwargs = {
|
common_dataset_kwargs = {
|
||||||
@@ -296,15 +269,6 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
train=train_dataset, val=val_dataset, test=test_dataset
|
train=train_dataset, val=val_dataset, test=test_dataset
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
|
|
||||||
"""
|
|
||||||
Called by get_default_args(JsonIndexDatasetMapProviderV2) to
|
|
||||||
not expose certain fields of each dataset class.
|
|
||||||
"""
|
|
||||||
for key in _NEED_CONTROL:
|
|
||||||
del args[key]
|
|
||||||
|
|
||||||
def create_dataset(self):
|
def create_dataset(self):
|
||||||
# The dataset object is created inside `self.get_dataset_map`
|
# The dataset object is created inside `self.get_dataset_map`
|
||||||
pass
|
pass
|
||||||
@@ -335,6 +299,12 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
)
|
)
|
||||||
return category_to_subset_name_list
|
return category_to_subset_name_list
|
||||||
|
|
||||||
|
def get_task(self) -> Task: # TODO: we plan to get rid of tasks
|
||||||
|
return {
|
||||||
|
"manyview": Task.SINGLE_SEQUENCE,
|
||||||
|
"fewview": Task.MULTI_SEQUENCE,
|
||||||
|
}[self.subset_name.split("_")[0]]
|
||||||
|
|
||||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
train_dataset = self.dataset_map.train
|
train_dataset = self.dataset_map.train
|
||||||
|
|||||||
@@ -32,21 +32,17 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
|
|||||||
and test datasets, and this many random training frames are added to
|
and test datasets, and this many random training frames are added to
|
||||||
each test batch. If not set, test batches each contain just a single
|
each test batch. If not set, test batches each contain just a single
|
||||||
testing frame.
|
testing frame.
|
||||||
downscale_factor: determines image sizes.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
downscale_factor: int = 4
|
|
||||||
|
|
||||||
def _load_data(self) -> None:
|
def _load_data(self) -> None:
|
||||||
path_manager = self.path_manager_factory.get()
|
path_manager = self.path_manager_factory.get()
|
||||||
images, poses, _ = load_llff_data(
|
images, poses, _ = load_llff_data(
|
||||||
self.base_dir, factor=self.downscale_factor, path_manager=path_manager
|
self.base_dir, factor=8, path_manager=path_manager
|
||||||
)
|
)
|
||||||
hwf = poses[0, :3, -1]
|
hwf = poses[0, :3, -1]
|
||||||
poses = poses[:, :3, :4]
|
poses = poses[:, :3, :4]
|
||||||
|
|
||||||
llffhold = 8
|
i_test = np.arange(images.shape[0])[::8]
|
||||||
i_test = np.arange(images.shape[0])[::llffhold]
|
|
||||||
i_test_index = set(i_test.tolist())
|
i_test_index = set(i_test.tolist())
|
||||||
i_train = np.array(
|
i_train = np.array(
|
||||||
[i for i in np.arange(images.shape[0]) if i not in i_test_index]
|
[i for i in np.arange(images.shape[0]) if i not in i_test_index]
|
||||||
|
|||||||
@@ -294,7 +294,7 @@ def _local_path(path_manager, path):
|
|||||||
|
|
||||||
def _ls(path_manager, path):
|
def _ls(path_manager, path):
|
||||||
if path_manager is None:
|
if path_manager is None:
|
||||||
return os.listdir(path)
|
return os.path.listdir(path)
|
||||||
return path_manager.ls(path)
|
return path_manager.ls(path)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,211 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the BSD-style license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
from os.path import dirname, join, realpath
|
|
||||||
from typing import Optional, Tuple
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from pytorch3d.implicitron.tools.config import (
|
|
||||||
expand_args_fields,
|
|
||||||
registry,
|
|
||||||
run_auto_creation,
|
|
||||||
)
|
|
||||||
from pytorch3d.io import IO
|
|
||||||
from pytorch3d.renderer import (
|
|
||||||
AmbientLights,
|
|
||||||
BlendParams,
|
|
||||||
CamerasBase,
|
|
||||||
FoVPerspectiveCameras,
|
|
||||||
HardPhongShader,
|
|
||||||
look_at_view_transform,
|
|
||||||
MeshRasterizer,
|
|
||||||
MeshRendererWithFragments,
|
|
||||||
PointLights,
|
|
||||||
RasterizationSettings,
|
|
||||||
)
|
|
||||||
from pytorch3d.structures.meshes import Meshes
|
|
||||||
|
|
||||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
|
|
||||||
from .single_sequence_dataset import SingleSceneDataset
|
|
||||||
from .utils import DATASET_TYPE_KNOWN
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class RenderedMeshDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|
||||||
"""
|
|
||||||
A simple single-scene dataset based on PyTorch3D renders of a mesh.
|
|
||||||
Provides `num_views` renders of the mesh as train, with no val
|
|
||||||
and test. The renders are generated from viewpoints sampled at uniformly
|
|
||||||
distributed azimuth intervals. The elevation is kept constant so that the
|
|
||||||
camera's vertical position coincides with the equator.
|
|
||||||
|
|
||||||
By default, uses Keenan Crane's cow model, and the camera locations are
|
|
||||||
set to make sense for that.
|
|
||||||
|
|
||||||
Although the rendering used to generate this dataset will use a GPU
|
|
||||||
if one is available, the data it produces is on the CPU just like
|
|
||||||
the data returned by implicitron's other dataset map providers.
|
|
||||||
This is because both datasets and models can be large, so implicitron's
|
|
||||||
GenericModel.forward (etc) expects data on the CPU and only moves
|
|
||||||
what it needs to the device.
|
|
||||||
|
|
||||||
For a more detailed explanation of this code, please refer to the
|
|
||||||
docs/tutorials/fit_textured_mesh.ipynb notebook.
|
|
||||||
|
|
||||||
Members:
|
|
||||||
num_views: The number of generated renders.
|
|
||||||
data_file: The folder that contains the mesh file. By default, finds
|
|
||||||
the cow mesh in the same repo as this code.
|
|
||||||
azimuth_range: number of degrees on each side of the start position to
|
|
||||||
take samples
|
|
||||||
resolution: the common height and width of the output images.
|
|
||||||
use_point_light: whether to use a particular point light as opposed
|
|
||||||
to ambient white.
|
|
||||||
"""
|
|
||||||
|
|
||||||
num_views: int = 40
|
|
||||||
data_file: Optional[str] = None
|
|
||||||
azimuth_range: float = 180
|
|
||||||
resolution: int = 128
|
|
||||||
use_point_light: bool = True
|
|
||||||
path_manager_factory: PathManagerFactory
|
|
||||||
path_manager_factory_class_type: str = "PathManagerFactory"
|
|
||||||
|
|
||||||
def get_dataset_map(self) -> DatasetMap:
|
|
||||||
# pyre-ignore[16]
|
|
||||||
return DatasetMap(train=self.train_dataset, val=None, test=None)
|
|
||||||
|
|
||||||
def get_all_train_cameras(self) -> CamerasBase:
|
|
||||||
# pyre-ignore[16]
|
|
||||||
return self.poses
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
|
||||||
super().__init__()
|
|
||||||
run_auto_creation(self)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device("cuda:0")
|
|
||||||
else:
|
|
||||||
device = torch.device("cpu")
|
|
||||||
if self.data_file is None:
|
|
||||||
data_file = join(
|
|
||||||
dirname(dirname(dirname(dirname(realpath(__file__))))),
|
|
||||||
"docs",
|
|
||||||
"tutorials",
|
|
||||||
"data",
|
|
||||||
"cow_mesh",
|
|
||||||
"cow.obj",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
data_file = self.data_file
|
|
||||||
io = IO(path_manager=self.path_manager_factory.get())
|
|
||||||
mesh = io.load_mesh(data_file, device=device)
|
|
||||||
poses, images, masks = _generate_cow_renders(
|
|
||||||
num_views=self.num_views,
|
|
||||||
mesh=mesh,
|
|
||||||
azimuth_range=self.azimuth_range,
|
|
||||||
resolution=self.resolution,
|
|
||||||
device=device,
|
|
||||||
use_point_light=self.use_point_light,
|
|
||||||
)
|
|
||||||
# pyre-ignore[16]
|
|
||||||
self.poses = poses.cpu()
|
|
||||||
expand_args_fields(SingleSceneDataset)
|
|
||||||
# pyre-ignore[16]
|
|
||||||
self.train_dataset = SingleSceneDataset( # pyre-ignore[28]
|
|
||||||
object_name="cow",
|
|
||||||
images=list(images.permute(0, 3, 1, 2).cpu()),
|
|
||||||
fg_probabilities=list(masks[:, None].cpu()),
|
|
||||||
poses=[self.poses[i] for i in range(len(poses))],
|
|
||||||
frame_types=[DATASET_TYPE_KNOWN] * len(poses),
|
|
||||||
eval_batches=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def _generate_cow_renders(
|
|
||||||
*,
|
|
||||||
num_views: int,
|
|
||||||
mesh: Meshes,
|
|
||||||
azimuth_range: float,
|
|
||||||
resolution: int,
|
|
||||||
device: torch.device,
|
|
||||||
use_point_light: bool,
|
|
||||||
) -> Tuple[CamerasBase, torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Returns:
|
|
||||||
cameras: A batch of `num_views` `FoVPerspectiveCameras` from which the
|
|
||||||
images are rendered.
|
|
||||||
images: A tensor of shape `(num_views, height, width, 3)` containing
|
|
||||||
the rendered images.
|
|
||||||
silhouettes: A tensor of shape `(num_views, height, width)` containing
|
|
||||||
the rendered silhouettes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Load obj file
|
|
||||||
|
|
||||||
# We scale normalize and center the target mesh to fit in a sphere of radius 1
|
|
||||||
# centered at (0,0,0). (scale, center) will be used to bring the predicted mesh
|
|
||||||
# to its original center and scale. Note that normalizing the target mesh,
|
|
||||||
# speeds up the optimization but is not necessary!
|
|
||||||
verts = mesh.verts_packed()
|
|
||||||
N = verts.shape[0]
|
|
||||||
center = verts.mean(0)
|
|
||||||
scale = max((verts - center).abs().max(0)[0])
|
|
||||||
mesh.offset_verts_(-(center.expand(N, 3)))
|
|
||||||
mesh.scale_verts_((1.0 / float(scale)))
|
|
||||||
|
|
||||||
# Get a batch of viewing angles.
|
|
||||||
elev = torch.linspace(0, 0, num_views) # keep constant
|
|
||||||
azim = torch.linspace(-azimuth_range, azimuth_range, num_views) + 180.0
|
|
||||||
|
|
||||||
# Place a point light in front of the object. As mentioned above, the front of
|
|
||||||
# the cow is facing the -z direction.
|
|
||||||
if use_point_light:
|
|
||||||
lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])
|
|
||||||
else:
|
|
||||||
lights = AmbientLights(device=device)
|
|
||||||
|
|
||||||
# Initialize an OpenGL perspective camera that represents a batch of different
|
|
||||||
# viewing angles. All the cameras helper methods support mixed type inputs and
|
|
||||||
# broadcasting. So we can view the camera from the a distance of dist=2.7, and
|
|
||||||
# then specify elevation and azimuth angles for each viewpoint as tensors.
|
|
||||||
R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
|
|
||||||
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
|
|
||||||
|
|
||||||
# Define the settings for rasterization and shading.
|
|
||||||
# As we are rendering images for visualization
|
|
||||||
# purposes only we will set faces_per_pixel=1 and blur_radius=0.0. Refer to
|
|
||||||
# rasterize_meshes.py for explanations of these parameters. We also leave
|
|
||||||
# bin_size and max_faces_per_bin to their default values of None, which sets
|
|
||||||
# their values using heuristics and ensures that the faster coarse-to-fine
|
|
||||||
# rasterization method is used. Refer to docs/notes/renderer.md for an
|
|
||||||
# explanation of the difference between naive and coarse-to-fine rasterization.
|
|
||||||
raster_settings = RasterizationSettings(
|
|
||||||
image_size=resolution, blur_radius=0.0, faces_per_pixel=1
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a Phong renderer by composing a rasterizer and a shader. The textured
|
|
||||||
# Phong shader will interpolate the texture uv coordinates for each vertex,
|
|
||||||
# sample from a texture image and apply the Phong lighting model
|
|
||||||
blend_params = BlendParams(sigma=1e-4, gamma=1e-4, background_color=(0.0, 0.0, 0.0))
|
|
||||||
rasterizer_type = MeshRasterizer
|
|
||||||
renderer = MeshRendererWithFragments(
|
|
||||||
rasterizer=rasterizer_type(cameras=cameras, raster_settings=raster_settings),
|
|
||||||
shader=HardPhongShader(
|
|
||||||
device=device, cameras=cameras, lights=lights, blend_params=blend_params
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a batch of meshes by repeating the cow mesh and associated textures.
|
|
||||||
# Meshes has a useful `extend` method which allows us do this very easily.
|
|
||||||
# This also extends the textures.
|
|
||||||
meshes = mesh.extend(num_views)
|
|
||||||
|
|
||||||
# Render the cow mesh from each viewing angle
|
|
||||||
target_images, fragments = renderer(meshes, cameras=cameras, lights=lights)
|
|
||||||
silhouette_binary = (fragments.pix_to_face[..., 0] >= 0).float()
|
|
||||||
|
|
||||||
return cameras, target_images[..., :3], silhouette_binary
|
|
||||||
@@ -21,13 +21,17 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
|
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
|
||||||
|
|
||||||
from .dataset_base import DatasetBase, FrameData
|
from .dataset_base import DatasetBase, FrameData
|
||||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, PathManagerFactory
|
from .dataset_map_provider import (
|
||||||
|
DatasetMap,
|
||||||
|
DatasetMapProviderBase,
|
||||||
|
PathManagerFactory,
|
||||||
|
Task,
|
||||||
|
)
|
||||||
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
|
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
|
||||||
|
|
||||||
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
|
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
|
||||||
|
|
||||||
|
|
||||||
@expand_args_fields
|
|
||||||
class SingleSceneDataset(DatasetBase, Configurable):
|
class SingleSceneDataset(DatasetBase, Configurable):
|
||||||
"""
|
"""
|
||||||
A dataset from images from a single scene.
|
A dataset from images from a single scene.
|
||||||
@@ -111,6 +115,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
|||||||
def _get_dataset(
|
def _get_dataset(
|
||||||
self, split_idx: int, frame_type: str, set_eval_batches: bool = False
|
self, split_idx: int, frame_type: str, set_eval_batches: bool = False
|
||||||
) -> SingleSceneDataset:
|
) -> SingleSceneDataset:
|
||||||
|
expand_args_fields(SingleSceneDataset)
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
split = self.i_split[split_idx]
|
split = self.i_split[split_idx]
|
||||||
frame_types = [frame_type] * len(split)
|
frame_types = [frame_type] * len(split)
|
||||||
@@ -154,6 +159,9 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
|||||||
test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
|
test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_task(self) -> Task:
|
||||||
|
return Task.SINGLE_SEQUENCE
|
||||||
|
|
||||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
cameras = [self.poses[i] for i in self.i_split[0]]
|
cameras = [self.poses[i] for i in self.i_split[0]]
|
||||||
|
|||||||
@@ -7,12 +7,11 @@
|
|||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import os
|
import os
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, cast, Dict, List, Optional, Tuple
|
from typing import Any, cast, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import lpips
|
import lpips
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||||
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
|
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
|
||||||
CO3D_CATEGORIES,
|
CO3D_CATEGORIES,
|
||||||
@@ -28,11 +27,6 @@ from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
class Task(Enum):
|
|
||||||
SINGLE_SEQUENCE = "singlesequence"
|
|
||||||
MULTI_SEQUENCE = "multisequence"
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
"""
|
"""
|
||||||
Evaluates new view synthesis metrics of a simple depth-based image rendering
|
Evaluates new view synthesis metrics of a simple depth-based image rendering
|
||||||
@@ -159,15 +153,11 @@ def evaluate_dbir_for_category(
|
|||||||
|
|
||||||
if task == Task.SINGLE_SEQUENCE:
|
if task == Task.SINGLE_SEQUENCE:
|
||||||
camera_difficulty_bin_breaks = 0.97, 0.98
|
camera_difficulty_bin_breaks = 0.97, 0.98
|
||||||
multisequence_evaluation = False
|
|
||||||
else:
|
else:
|
||||||
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
|
camera_difficulty_bin_breaks = 2.0 / 3, 5.0 / 6
|
||||||
multisequence_evaluation = True
|
|
||||||
|
|
||||||
category_result_flat, category_result = summarize_nvs_eval_results(
|
category_result_flat, category_result = summarize_nvs_eval_results(
|
||||||
per_batch_eval_results,
|
per_batch_eval_results, task, camera_difficulty_bin_breaks
|
||||||
camera_difficulty_bin_breaks=camera_difficulty_bin_breaks,
|
|
||||||
is_multisequence=multisequence_evaluation,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return category_result["results"]
|
return category_result["results"]
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from pytorch3d.implicitron.dataset.data_source import Task
|
||||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||||
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
|
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
|
||||||
from pytorch3d.implicitron.models.base_model import ImplicitronRender
|
from pytorch3d.implicitron.models.base_model import ImplicitronRender
|
||||||
@@ -242,26 +243,10 @@ def eval_batch(
|
|||||||
if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0:
|
if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0:
|
||||||
warnings.warn("Empty or missing depth map in evaluation!")
|
warnings.warn("Empty or missing depth map in evaluation!")
|
||||||
|
|
||||||
if frame_data.mask_crop is None:
|
|
||||||
warnings.warn("mask_crop is None, assuming the whole image is valid.")
|
|
||||||
|
|
||||||
if frame_data.fg_probability is None:
|
|
||||||
warnings.warn("fg_probability is None, assuming the whole image is fg.")
|
|
||||||
|
|
||||||
# threshold the masks to make ground truth binary masks
|
# threshold the masks to make ground truth binary masks
|
||||||
mask_fg = (
|
mask_fg, mask_crop = [
|
||||||
frame_data.fg_probability >= mask_thr
|
(getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop")
|
||||||
if frame_data.fg_probability is not None
|
]
|
||||||
# pyre-ignore [16]
|
|
||||||
else torch.ones_like(frame_data.image_rgb[:, :1, ...]).bool()
|
|
||||||
)
|
|
||||||
|
|
||||||
mask_crop = (
|
|
||||||
frame_data.mask_crop
|
|
||||||
if frame_data.mask_crop is not None
|
|
||||||
else torch.ones_like(mask_fg)
|
|
||||||
)
|
|
||||||
|
|
||||||
image_rgb_masked = mask_background(
|
image_rgb_masked = mask_background(
|
||||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
||||||
# `Optional[torch.Tensor]`.
|
# `Optional[torch.Tensor]`.
|
||||||
@@ -281,6 +266,7 @@ def eval_batch(
|
|||||||
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
|
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
|
||||||
# `Optional[torch.Tensor]`.
|
# `Optional[torch.Tensor]`.
|
||||||
depth_map=frame_data.depth_map,
|
depth_map=frame_data.depth_map,
|
||||||
|
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
|
||||||
depth_mask=frame_data.depth_mask[:1],
|
depth_mask=frame_data.depth_mask[:1],
|
||||||
visdom_env=visualize_visdom_env,
|
visdom_env=visualize_visdom_env,
|
||||||
)
|
)
|
||||||
@@ -312,7 +298,7 @@ def eval_batch(
|
|||||||
results[metric_name].item(), metric_name, loss_mask_now
|
results[metric_name].item(), metric_name, loss_mask_now
|
||||||
)
|
)
|
||||||
|
|
||||||
if name_postfix == "_fg" and frame_data.depth_map is not None:
|
if name_postfix == "_fg":
|
||||||
# only record depth metrics for the foreground
|
# only record depth metrics for the foreground
|
||||||
_, abs_ = eval_depth(
|
_, abs_ = eval_depth(
|
||||||
cloned_render["depth_render"],
|
cloned_render["depth_render"],
|
||||||
@@ -328,7 +314,9 @@ def eval_batch(
|
|||||||
if visualize:
|
if visualize:
|
||||||
visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now)
|
visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now)
|
||||||
if break_after_visualising:
|
if break_after_visualising:
|
||||||
breakpoint() # noqa: B601
|
import pdb # noqa: B602
|
||||||
|
|
||||||
|
pdb.set_trace()
|
||||||
|
|
||||||
if lpips_model is not None:
|
if lpips_model is not None:
|
||||||
im1, im2 = [
|
im1, im2 = [
|
||||||
@@ -432,16 +420,16 @@ def _get_camera_difficulty_bin_edges(camera_difficulty_bin_breaks: Tuple[float,
|
|||||||
|
|
||||||
def summarize_nvs_eval_results(
|
def summarize_nvs_eval_results(
|
||||||
per_batch_eval_results: List[Dict[str, Any]],
|
per_batch_eval_results: List[Dict[str, Any]],
|
||||||
is_multisequence: bool,
|
task: Task,
|
||||||
camera_difficulty_bin_breaks: Tuple[float, float],
|
camera_difficulty_bin_breaks: Tuple[float, float] = (0.97, 0.98),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Compile the per-batch evaluation results `per_batch_eval_results` into
|
Compile the per-batch evaluation results `per_batch_eval_results` into
|
||||||
a set of aggregate metrics. The produced metrics depend on is_multisequence.
|
a set of aggregate metrics. The produced metrics depend on the task.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
per_batch_eval_results: Metrics of each per-batch evaluation.
|
per_batch_eval_results: Metrics of each per-batch evaluation.
|
||||||
is_multisequence: Whether to evaluate as a multisequence task
|
task: The type of the new-view synthesis task.
|
||||||
camera_difficulty_bin_breaks: edge hard-medium and medium-easy
|
camera_difficulty_bin_breaks: edge hard-medium and medium-easy
|
||||||
|
|
||||||
|
|
||||||
@@ -451,9 +439,14 @@ def summarize_nvs_eval_results(
|
|||||||
"""
|
"""
|
||||||
n_batches = len(per_batch_eval_results)
|
n_batches = len(per_batch_eval_results)
|
||||||
eval_sets: List[Optional[str]] = []
|
eval_sets: List[Optional[str]] = []
|
||||||
|
if task == Task.SINGLE_SEQUENCE:
|
||||||
eval_sets = [None]
|
eval_sets = [None]
|
||||||
if is_multisequence:
|
# assert n_batches==100
|
||||||
|
elif task == Task.MULTI_SEQUENCE:
|
||||||
eval_sets = ["train", "test"]
|
eval_sets = ["train", "test"]
|
||||||
|
# assert n_batches==1000
|
||||||
|
else:
|
||||||
|
raise ValueError(task)
|
||||||
batch_sizes = torch.tensor(
|
batch_sizes = torch.tensor(
|
||||||
[r["meta"]["batch_size"] for r in per_batch_eval_results]
|
[r["meta"]["batch_size"] for r in per_batch_eval_results]
|
||||||
).long()
|
).long()
|
||||||
@@ -473,9 +466,11 @@ def summarize_nvs_eval_results(
|
|||||||
# add per set averages
|
# add per set averages
|
||||||
for SET in eval_sets:
|
for SET in eval_sets:
|
||||||
if SET is None:
|
if SET is None:
|
||||||
|
assert task == Task.SINGLE_SEQUENCE
|
||||||
ok_set = torch.ones(n_batches, dtype=torch.bool)
|
ok_set = torch.ones(n_batches, dtype=torch.bool)
|
||||||
set_name = "test"
|
set_name = "test"
|
||||||
else:
|
else:
|
||||||
|
assert task == Task.MULTI_SEQUENCE
|
||||||
ok_set = is_train == int(SET == "train")
|
ok_set = is_train == int(SET == "train")
|
||||||
set_name = SET
|
set_name = SET
|
||||||
|
|
||||||
@@ -500,7 +495,7 @@ def summarize_nvs_eval_results(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_multisequence:
|
if task == Task.MULTI_SEQUENCE:
|
||||||
# split based on n_src_views
|
# split based on n_src_views
|
||||||
n_src_views = batch_sizes - 1
|
n_src_views = batch_sizes - 1
|
||||||
for n_src in EVAL_N_SRC_VIEWS:
|
for n_src in EVAL_N_SRC_VIEWS:
|
||||||
|
|||||||
@@ -1,164 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the BSD-style license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import copy
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
import lpips
|
|
||||||
import torch
|
|
||||||
|
|
||||||
import tqdm
|
|
||||||
from pytorch3d.implicitron.dataset import utils as ds_utils
|
|
||||||
|
|
||||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
|
||||||
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
|
|
||||||
from pytorch3d.implicitron.tools.config import (
|
|
||||||
registry,
|
|
||||||
ReplaceableBase,
|
|
||||||
run_auto_creation,
|
|
||||||
)
|
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class EvaluatorBase(ReplaceableBase):
|
|
||||||
"""
|
|
||||||
Evaluate a trained model on given data. Returns a dict of loss/objective
|
|
||||||
names and their values.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def run(
|
|
||||||
self, model: ImplicitronModelBase, dataloader: DataLoader, **kwargs
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Evaluate the results of Implicitron training.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class ImplicitronEvaluator(EvaluatorBase):
|
|
||||||
"""
|
|
||||||
Evaluate the results of Implicitron training.
|
|
||||||
|
|
||||||
Members:
|
|
||||||
camera_difficulty_bin_breaks: low/medium vals to divide camera difficulties into
|
|
||||||
[0-eps, low, medium, 1+eps].
|
|
||||||
"""
|
|
||||||
|
|
||||||
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
|
|
||||||
is_multisequence: bool = False
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
run_auto_creation(self)
|
|
||||||
|
|
||||||
def run(
|
|
||||||
self,
|
|
||||||
model: ImplicitronModelBase,
|
|
||||||
dataloader: DataLoader,
|
|
||||||
all_train_cameras: Optional[CamerasBase],
|
|
||||||
device: torch.device,
|
|
||||||
dump_to_json: bool = False,
|
|
||||||
exp_dir: Optional[str] = None,
|
|
||||||
epoch: Optional[int] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Evaluate the results of Implicitron training. Optionally, dump results to
|
|
||||||
exp_dir/results_test.json.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: A (trained) model to evaluate.
|
|
||||||
dataloader: A test dataloader.
|
|
||||||
all_train_cameras: Camera instances we used for training.
|
|
||||||
device: A torch device.
|
|
||||||
dump_to_json: If True, will dump the results to a json file.
|
|
||||||
exp_dir: Root expeirment directory.
|
|
||||||
epoch: Evaluation epoch (to be stored in the results dict).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary of results.
|
|
||||||
"""
|
|
||||||
lpips_model = lpips.LPIPS(net="vgg")
|
|
||||||
lpips_model = lpips_model.to(device)
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
per_batch_eval_results = []
|
|
||||||
logger.info("Evaluating model ...")
|
|
||||||
for frame_data in tqdm.tqdm(dataloader):
|
|
||||||
frame_data = frame_data.to(device)
|
|
||||||
|
|
||||||
# mask out the unknown images so that the model does not see them
|
|
||||||
frame_data_for_eval = _get_eval_frame_data(frame_data)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
preds = model(
|
|
||||||
**{
|
|
||||||
**frame_data_for_eval,
|
|
||||||
"evaluation_mode": EvaluationMode.EVALUATION,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
implicitron_render = copy.deepcopy(preds["implicitron_render"])
|
|
||||||
per_batch_eval_results.append(
|
|
||||||
evaluate.eval_batch(
|
|
||||||
frame_data,
|
|
||||||
implicitron_render,
|
|
||||||
bg_color="black",
|
|
||||||
lpips_model=lpips_model,
|
|
||||||
source_cameras=( # None will make it use batch’s known cameras
|
|
||||||
None if self.is_multisequence else all_train_cameras
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
_, category_result = evaluate.summarize_nvs_eval_results(
|
|
||||||
per_batch_eval_results,
|
|
||||||
self.is_multisequence,
|
|
||||||
self.camera_difficulty_bin_breaks,
|
|
||||||
)
|
|
||||||
|
|
||||||
results = category_result["results"]
|
|
||||||
evaluate.pretty_print_nvs_metrics(results)
|
|
||||||
if dump_to_json:
|
|
||||||
_dump_to_json(epoch, exp_dir, results)
|
|
||||||
|
|
||||||
return category_result["results"]
|
|
||||||
|
|
||||||
|
|
||||||
def _dump_to_json(
|
|
||||||
epoch: Optional[int], exp_dir: Optional[str], results: List[Dict[str, Any]]
|
|
||||||
) -> None:
|
|
||||||
if epoch is not None:
|
|
||||||
for r in results:
|
|
||||||
r["eval_epoch"] = int(epoch)
|
|
||||||
logger.info("Evaluation results")
|
|
||||||
|
|
||||||
if exp_dir is None:
|
|
||||||
raise ValueError("Cannot save results to json without a specified save path.")
|
|
||||||
with open(os.path.join(exp_dir, "results_test.json"), "w") as f:
|
|
||||||
json.dump(results, f)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_eval_frame_data(frame_data: Any) -> Any:
|
|
||||||
"""
|
|
||||||
Masks the unknown image data to make sure we cannot use it at model evaluation time.
|
|
||||||
"""
|
|
||||||
frame_data_for_eval = copy.deepcopy(frame_data)
|
|
||||||
is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
|
|
||||||
frame_data.image_rgb
|
|
||||||
)[:, None, None, None]
|
|
||||||
for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
|
|
||||||
value = getattr(frame_data_for_eval, k)
|
|
||||||
value_masked = value.clone() * is_known if value is not None else None
|
|
||||||
setattr(frame_data_for_eval, k, value_masked)
|
|
||||||
return frame_data_for_eval
|
|
||||||
@@ -4,7 +4,7 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -37,18 +37,12 @@ class ImplicitronRender:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
|
class ImplicitronModelBase(ReplaceableBase):
|
||||||
"""
|
"""
|
||||||
Replaceable abstract base for all image generation / rendering models.
|
Replaceable abstract base for all image generation / rendering models.
|
||||||
`forward()` method produces a render with a depth map. Derives from Module
|
`forward()` method produces a render with a depth map.
|
||||||
so we can rely on basic functionality provided to torch for model
|
|
||||||
optimization.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# The keys from `preds` (output of ImplicitronModelBase.forward) to be logged in
|
|
||||||
# the training loop.
|
|
||||||
log_vars: List[str] = field(default_factory=lambda: ["objective"])
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|||||||
@@ -16,10 +16,10 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
import tqdm
|
||||||
from omegaconf import DictConfig
|
from pytorch3d.implicitron.models.metrics import ( # noqa
|
||||||
from pytorch3d.common.compat import prod
|
RegularizationMetrics,
|
||||||
from pytorch3d.implicitron.models.metrics import (
|
|
||||||
RegularizationMetricsBase,
|
RegularizationMetricsBase,
|
||||||
|
ViewMetrics,
|
||||||
ViewMetricsBase,
|
ViewMetricsBase,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.tools import image_utils, vis_utils
|
from pytorch3d.implicitron.tools import image_utils, vis_utils
|
||||||
@@ -29,7 +29,7 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
run_auto_creation,
|
run_auto_creation,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
|
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
|
||||||
from pytorch3d.implicitron.tools.utils import cat_dataclass
|
from pytorch3d.implicitron.tools.utils import cat_dataclass, setattr_if_hasattr
|
||||||
from pytorch3d.renderer import RayBundle, utils as rend_utils
|
from pytorch3d.renderer import RayBundle, utils as rend_utils
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
from visdom import Visdom
|
from visdom import Visdom
|
||||||
@@ -67,7 +67,7 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13
|
||||||
"""
|
"""
|
||||||
GenericModel is a wrapper for the neural implicit
|
GenericModel is a wrapper for the neural implicit
|
||||||
rendering and reconstruction pipeline which consists
|
rendering and reconstruction pipeline which consists
|
||||||
@@ -148,9 +148,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
thresholded by this value before being applied to the RGB/Depth images
|
thresholded by this value before being applied to the RGB/Depth images
|
||||||
output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
|
output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
|
||||||
splatting onto an image grid. Default: False.
|
splatting onto an image grid. Default: False.
|
||||||
bg_color: RGB values for setting the background color of input image
|
bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0)
|
||||||
if mask_images=True. Defaults to (0.0, 0.0, 0.0). Each renderer has its own
|
|
||||||
way to determine the background color of its output, unrelated to this.
|
|
||||||
num_passes: The specified implicit_function is initialized num_passes
|
num_passes: The specified implicit_function is initialized num_passes
|
||||||
times and run sequentially.
|
times and run sequentially.
|
||||||
chunk_size_grid: The total number of points which can be rendered
|
chunk_size_grid: The total number of points which can be rendered
|
||||||
@@ -535,7 +533,6 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
return None
|
return None
|
||||||
loss = sum(losses_weighted)
|
loss = sum(losses_weighted)
|
||||||
assert torch.is_tensor(loss)
|
assert torch.is_tensor(loss)
|
||||||
# pyre-fixme[7]: Expected `Optional[Tensor]` but got `int`.
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def visualize(
|
def visualize(
|
||||||
@@ -620,57 +617,51 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
self.image_feature_extractor.get_feat_dims()
|
self.image_feature_extractor.get_feat_dims()
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def raysampler_tweak_args(cls, type, args: DictConfig) -> None:
|
|
||||||
"""
|
|
||||||
We don't expose certain fields of the raysampler because we want to set
|
|
||||||
them from our own members.
|
|
||||||
"""
|
|
||||||
del args["sampling_mode_training"]
|
|
||||||
del args["sampling_mode_evaluation"]
|
|
||||||
del args["image_width"]
|
|
||||||
del args["image_height"]
|
|
||||||
|
|
||||||
def create_raysampler(self):
|
def create_raysampler(self):
|
||||||
extra_args = {
|
|
||||||
"sampling_mode_training": self.sampling_mode_training,
|
|
||||||
"sampling_mode_evaluation": self.sampling_mode_evaluation,
|
|
||||||
"image_width": self.render_image_width,
|
|
||||||
"image_height": self.render_image_height,
|
|
||||||
}
|
|
||||||
raysampler_args = getattr(
|
raysampler_args = getattr(
|
||||||
self, "raysampler_" + self.raysampler_class_type + "_args"
|
self, "raysampler_" + self.raysampler_class_type + "_args"
|
||||||
)
|
)
|
||||||
|
setattr_if_hasattr(
|
||||||
|
raysampler_args, "sampling_mode_training", self.sampling_mode_training
|
||||||
|
)
|
||||||
|
setattr_if_hasattr(
|
||||||
|
raysampler_args, "sampling_mode_evaluation", self.sampling_mode_evaluation
|
||||||
|
)
|
||||||
|
setattr_if_hasattr(raysampler_args, "image_width", self.render_image_width)
|
||||||
|
setattr_if_hasattr(raysampler_args, "image_height", self.render_image_height)
|
||||||
self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
|
self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
|
||||||
**raysampler_args, **extra_args
|
**raysampler_args
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def renderer_tweak_args(cls, type, args: DictConfig) -> None:
|
|
||||||
"""
|
|
||||||
We don't expose certain fields of the renderer because we want to set
|
|
||||||
them based on other inputs.
|
|
||||||
"""
|
|
||||||
args.pop("render_features_dimensions", None)
|
|
||||||
args.pop("object_bounding_sphere", None)
|
|
||||||
|
|
||||||
def create_renderer(self):
|
def create_renderer(self):
|
||||||
extra_args = {}
|
raysampler_args = getattr(
|
||||||
|
self, "raysampler_" + self.raysampler_class_type + "_args"
|
||||||
|
)
|
||||||
|
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
|
||||||
|
"stratified_sampling_coarse_training"
|
||||||
|
] = raysampler_args["stratified_point_sampling_training"]
|
||||||
|
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
|
||||||
|
"stratified_sampling_coarse_evaluation"
|
||||||
|
] = raysampler_args["stratified_point_sampling_evaluation"]
|
||||||
|
self.renderer_SignedDistanceFunctionRenderer_args[
|
||||||
|
"render_features_dimensions"
|
||||||
|
] = self.render_features_dimensions
|
||||||
|
|
||||||
if self.renderer_class_type == "SignedDistanceFunctionRenderer":
|
if self.renderer_class_type == "SignedDistanceFunctionRenderer":
|
||||||
extra_args["render_features_dimensions"] = self.render_features_dimensions
|
if "scene_extent" not in raysampler_args:
|
||||||
if not hasattr(self.raysampler, "scene_extent"):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"SignedDistanceFunctionRenderer requires"
|
"SignedDistanceFunctionRenderer requires"
|
||||||
+ " a raysampler that defines the 'scene_extent' field"
|
+ " a raysampler that defines the 'scene_extent' field"
|
||||||
+ " (this field is supported by, e.g., the adaptive raysampler - "
|
+ " (this field is supported by, e.g., the adaptive raysampler - "
|
||||||
+ " self.raysampler_class_type='AdaptiveRaySampler')."
|
+ " self.raysampler_class_type='AdaptiveRaySampler')."
|
||||||
)
|
)
|
||||||
extra_args["object_bounding_sphere"] = self.raysampler.scene_extent
|
self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
|
||||||
|
"object_bounding_sphere"
|
||||||
|
] = self.raysampler_AdaptiveRaySampler_args["scene_extent"]
|
||||||
|
|
||||||
renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
|
renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
|
||||||
self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
|
self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
|
||||||
**renderer_args, **extra_args
|
**renderer_args
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_implicit_function(self) -> None:
|
def create_implicit_function(self) -> None:
|
||||||
@@ -681,18 +672,6 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def implicit_function_tweak_args(cls, type, args: DictConfig) -> None:
|
|
||||||
"""
|
|
||||||
We don't expose certain implicit_function fields because we want to set
|
|
||||||
them based on other inputs.
|
|
||||||
"""
|
|
||||||
args.pop("feature_vector_size", None)
|
|
||||||
args.pop("encoding_dim", None)
|
|
||||||
args.pop("latent_dim", None)
|
|
||||||
args.pop("latent_dim_hypernet", None)
|
|
||||||
args.pop("color_dim", None)
|
|
||||||
|
|
||||||
def _construct_implicit_functions(self):
|
def _construct_implicit_functions(self):
|
||||||
"""
|
"""
|
||||||
After run_auto_creation has been called, the arguments
|
After run_auto_creation has been called, the arguments
|
||||||
@@ -702,31 +681,32 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
implicit function method. Then the required implicit
|
implicit function method. Then the required implicit
|
||||||
function(s) are initialized.
|
function(s) are initialized.
|
||||||
"""
|
"""
|
||||||
extra_args = {}
|
# nerf preprocessing
|
||||||
if self.implicit_function_class_type in (
|
nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
|
||||||
"NeuralRadianceFieldImplicitFunction",
|
nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
|
||||||
"NeRFormerImplicitFunction",
|
nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
|
||||||
):
|
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
|
||||||
extra_args["latent_dim"] = (
|
|
||||||
self._get_viewpooled_feature_dim()
|
|
||||||
+ self._get_global_encoder_encoding_dim()
|
|
||||||
)
|
)
|
||||||
extra_args["color_dim"] = self.render_features_dimensions
|
nerf_args["color_dim"] = nerformer_args[
|
||||||
|
"color_dim"
|
||||||
|
] = self.render_features_dimensions
|
||||||
|
|
||||||
if self.implicit_function_class_type == "IdrFeatureField":
|
# idr preprocessing
|
||||||
extra_args["feature_vector_size"] = self.render_features_dimensions
|
idr = self.implicit_function_IdrFeatureField_args
|
||||||
extra_args["encoding_dim"] = self._get_global_encoder_encoding_dim()
|
idr["feature_vector_size"] = self.render_features_dimensions
|
||||||
|
idr["encoding_dim"] = self._get_global_encoder_encoding_dim()
|
||||||
|
|
||||||
if self.implicit_function_class_type == "SRNImplicitFunction":
|
# srn preprocessing
|
||||||
extra_args["latent_dim"] = (
|
srn = self.implicit_function_SRNImplicitFunction_args
|
||||||
self._get_viewpooled_feature_dim()
|
srn.raymarch_function_args.latent_dim = (
|
||||||
+ self._get_global_encoder_encoding_dim()
|
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
|
||||||
)
|
)
|
||||||
|
|
||||||
# srn_hypernet preprocessing
|
# srn_hypernet preprocessing
|
||||||
if self.implicit_function_class_type == "SRNHyperNetImplicitFunction":
|
srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
|
||||||
extra_args["latent_dim"] = self._get_viewpooled_feature_dim()
|
srn_hypernet_args = srn_hypernet.hypernet_args
|
||||||
extra_args["latent_dim_hypernet"] = self._get_global_encoder_encoding_dim()
|
srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim()
|
||||||
|
srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
|
||||||
|
|
||||||
# check that for srn, srn_hypernet, idr we have self.num_passes=1
|
# check that for srn, srn_hypernet, idr we have self.num_passes=1
|
||||||
implicit_function_type = registry.get(
|
implicit_function_type = registry.get(
|
||||||
@@ -749,7 +729,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
if config is None:
|
if config is None:
|
||||||
raise ValueError(f"{config_name} not present")
|
raise ValueError(f"{config_name} not present")
|
||||||
implicit_functions_list = [
|
implicit_functions_list = [
|
||||||
ImplicitFunctionWrapper(implicit_function_type(**config, **extra_args))
|
ImplicitFunctionWrapper(implicit_function_type(**config))
|
||||||
for _ in range(self.num_passes)
|
for _ in range(self.num_passes)
|
||||||
]
|
]
|
||||||
return torch.nn.ModuleList(implicit_functions_list)
|
return torch.nn.ModuleList(implicit_functions_list)
|
||||||
@@ -860,7 +840,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
|
|
||||||
# Estimate the rasterization point radius so that we approximately fill
|
# Estimate the rasterization point radius so that we approximately fill
|
||||||
# the whole image given the number of rasterized points.
|
# the whole image given the number of rasterized points.
|
||||||
pt_radius = 2.0 / math.sqrt(xys.shape[1])
|
pt_radius = 2.0 * math.sqrt(xys.shape[1])
|
||||||
|
|
||||||
# Rasterize the samples.
|
# Rasterize the samples.
|
||||||
features_depth_render, masks_render = rasterize_mc_samples(
|
features_depth_render, masks_render = rasterize_mc_samples(
|
||||||
@@ -920,7 +900,7 @@ def _chunk_generator(
|
|||||||
f"by n_pts_per_ray ({n_pts_per_ray})"
|
f"by n_pts_per_ray ({n_pts_per_ray})"
|
||||||
)
|
)
|
||||||
|
|
||||||
n_rays = prod(spatial_dim)
|
n_rays = math.prod(spatial_dim)
|
||||||
# special handling for raytracing-based methods
|
# special handling for raytracing-based methods
|
||||||
n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
|
n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
|
||||||
chunk_size_in_rays = -(-n_rays // n_chunks)
|
chunk_size_in_rays = -(-n_rays // n_chunks)
|
||||||
@@ -936,9 +916,9 @@ def _chunk_generator(
|
|||||||
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
|
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
|
||||||
:, start_idx:end_idx
|
:, start_idx:end_idx
|
||||||
],
|
],
|
||||||
lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[
|
lengths=ray_bundle.lengths.reshape(
|
||||||
:, start_idx:end_idx
|
batch_size, math.prod(spatial_dim), n_pts_per_ray
|
||||||
],
|
)[:, start_idx:end_idx],
|
||||||
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
|
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
|
||||||
)
|
)
|
||||||
extra_args = kwargs.copy()
|
extra_args = kwargs.copy()
|
||||||
|
|||||||
@@ -24,16 +24,15 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
encoding_dim: int = 0
|
encoding_dim: int = 0
|
||||||
n_instances: int = 1
|
n_instances: int = 0
|
||||||
init_scale: float = 1.0
|
init_scale: float = 1.0
|
||||||
ignore_input: bool = False
|
ignore_input: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if self.n_instances <= 0:
|
if self.n_instances <= 0:
|
||||||
raise ValueError(f"Invalid n_instances {self.n_instances}")
|
# Do not init the codes at all in case we have 0 instances.
|
||||||
|
return
|
||||||
self._autodecoder_codes = torch.nn.Embedding(
|
self._autodecoder_codes = torch.nn.Embedding(
|
||||||
self.n_instances,
|
self.n_instances,
|
||||||
self.encoding_dim,
|
self.encoding_dim,
|
||||||
@@ -71,9 +70,13 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
return key_map
|
return key_map
|
||||||
|
|
||||||
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
|
||||||
|
if self.n_instances <= 0:
|
||||||
|
return None
|
||||||
return (self._autodecoder_codes.weight**2).mean() # pyre-ignore[16]
|
return (self._autodecoder_codes.weight**2).mean() # pyre-ignore[16]
|
||||||
|
|
||||||
def get_encoding_dim(self) -> int:
|
def get_encoding_dim(self) -> int:
|
||||||
|
if self.n_instances <= 0:
|
||||||
|
return 0
|
||||||
return self.encoding_dim
|
return self.encoding_dim
|
||||||
|
|
||||||
def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
|
def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
|
||||||
@@ -87,6 +90,9 @@ class Autodecoder(Configurable, torch.nn.Module):
|
|||||||
codes: A tensor of shape `(N, self.encoding_dim)` containing the
|
codes: A tensor of shape `(N, self.encoding_dim)` containing the
|
||||||
key-specific autodecoder codes.
|
key-specific autodecoder codes.
|
||||||
"""
|
"""
|
||||||
|
if self.n_instances == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
if self.ignore_input:
|
if self.ignore_input:
|
||||||
x = ["singleton"]
|
x = ["singleton"]
|
||||||
|
|
||||||
|
|||||||
@@ -42,13 +42,7 @@ class GlobalEncoderBase(ReplaceableBase):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def forward(
|
def forward(self, **kwargs) -> torch.Tensor:
|
||||||
self,
|
|
||||||
*,
|
|
||||||
frame_timestamp: Optional[torch.Tensor] = None,
|
|
||||||
sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
"""
|
||||||
Given a set of inputs to encode, generates a tensor containing the encoding.
|
Given a set of inputs to encode, generates a tensor containing the encoding.
|
||||||
|
|
||||||
@@ -76,14 +70,9 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module): # pyre-ignore: 1
|
|||||||
return self.autodecoder.get_encoding_dim()
|
return self.autodecoder.get_encoding_dim()
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self, sequence_name: Union[torch.LongTensor, List[str]], **kwargs
|
||||||
*,
|
|
||||||
frame_timestamp: Optional[torch.Tensor] = None,
|
|
||||||
sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if sequence_name is None:
|
|
||||||
raise ValueError("sequence_name must be provided.")
|
|
||||||
# run dtype checks and pass sequence_name to self.autodecoder
|
# run dtype checks and pass sequence_name to self.autodecoder
|
||||||
return self.autodecoder(sequence_name)
|
return self.autodecoder(sequence_name)
|
||||||
|
|
||||||
@@ -112,15 +101,7 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
|
|||||||
def get_encoding_dim(self):
|
def get_encoding_dim(self):
|
||||||
return self._harmonic_embedding.get_output_dim(1)
|
return self._harmonic_embedding.get_output_dim(1)
|
||||||
|
|
||||||
def forward(
|
def forward(self, frame_timestamp: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||||
self,
|
|
||||||
*,
|
|
||||||
frame_timestamp: Optional[torch.Tensor] = None,
|
|
||||||
sequence_name: Optional[Union[torch.LongTensor, List[str]]] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
if frame_timestamp is None:
|
|
||||||
raise ValueError("frame_timestamp must be provided.")
|
|
||||||
if frame_timestamp.shape[-1] != 1:
|
if frame_timestamp.shape[-1] != 1:
|
||||||
raise ValueError("Frame timestamp's last dimensions should be one.")
|
raise ValueError("Frame timestamp's last dimensions should be one.")
|
||||||
time = frame_timestamp / self.time_divisor
|
time = frame_timestamp / self.time_divisor
|
||||||
|
|||||||
@@ -4,7 +4,6 @@
|
|||||||
from typing import Any, cast, Optional, Tuple
|
from typing import Any, cast, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import DictConfig
|
|
||||||
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
||||||
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
|
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
|
||||||
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
|
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
|
||||||
@@ -328,7 +327,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
|
|||||||
@registry.register
|
@registry.register
|
||||||
# pyre-fixme[13]: Uninitialized attribute
|
# pyre-fixme[13]: Uninitialized attribute
|
||||||
class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||||
latent_dim: int = 0
|
|
||||||
raymarch_function: SRNRaymarchFunction
|
raymarch_function: SRNRaymarchFunction
|
||||||
pixel_generator: SRNPixelGenerator
|
pixel_generator: SRNPixelGenerator
|
||||||
|
|
||||||
@@ -336,17 +334,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
def create_raymarch_function(self) -> None:
|
|
||||||
self.raymarch_function = SRNRaymarchFunction(
|
|
||||||
latent_dim=self.latent_dim,
|
|
||||||
# pyre-ignore[32]
|
|
||||||
**self.raymarch_function_args,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def raymarch_function_tweak_args(cls, type, args: DictConfig) -> None:
|
|
||||||
args.pop("latent_dim", None)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: RayBundle,
|
||||||
@@ -384,8 +371,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
the cache.
|
the cache.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
latent_dim_hypernet: int = 0
|
|
||||||
latent_dim: int = 0
|
|
||||||
hypernet: SRNRaymarchHyperNet
|
hypernet: SRNRaymarchHyperNet
|
||||||
pixel_generator: SRNPixelGenerator
|
pixel_generator: SRNPixelGenerator
|
||||||
|
|
||||||
@@ -393,19 +378,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
def create_hypernet(self) -> None:
|
|
||||||
self.hypernet = SRNRaymarchHyperNet(
|
|
||||||
latent_dim=self.latent_dim,
|
|
||||||
latent_dim_hypernet=self.latent_dim_hypernet,
|
|
||||||
# pyre-ignore[32]
|
|
||||||
**self.hypernet_args,
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def hypernet_tweak_args(cls, type, args: DictConfig) -> None:
|
|
||||||
args.pop("latent_dim", None)
|
|
||||||
args.pop("latent_dim_hypernet", None)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: RayBundle,
|
||||||
|
|||||||
@@ -4,10 +4,10 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import math
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.common.compat import prod
|
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ def create_embeddings_for_implicit_function(
|
|||||||
embeds = torch.empty(
|
embeds = torch.empty(
|
||||||
bs,
|
bs,
|
||||||
1,
|
1,
|
||||||
prod(spatial_size),
|
math.prod(spatial_size),
|
||||||
pts_per_ray,
|
pts_per_ray,
|
||||||
0,
|
0,
|
||||||
dtype=xyz_world.dtype,
|
dtype=xyz_world.dtype,
|
||||||
@@ -62,7 +62,7 @@ def create_embeddings_for_implicit_function(
|
|||||||
embeds = xyz_embedding_function(ray_points_for_embed).reshape(
|
embeds = xyz_embedding_function(ray_points_for_embed).reshape(
|
||||||
bs,
|
bs,
|
||||||
1,
|
1,
|
||||||
prod(spatial_size),
|
math.prod(spatial_size),
|
||||||
pts_per_ray,
|
pts_per_ray,
|
||||||
-1,
|
-1,
|
||||||
) # flatten spatial, add n_src dim
|
) # flatten spatial, add n_src dim
|
||||||
@@ -73,7 +73,7 @@ def create_embeddings_for_implicit_function(
|
|||||||
embed_shape = (
|
embed_shape = (
|
||||||
bs,
|
bs,
|
||||||
embeds_viewpooled.shape[1],
|
embeds_viewpooled.shape[1],
|
||||||
prod(spatial_size),
|
math.prod(spatial_size),
|
||||||
pts_per_ray,
|
pts_per_ray,
|
||||||
-1,
|
-1,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from .renderer.base import EvaluationMode
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class ModelDBIR(ImplicitronModelBase):
|
class ModelDBIR(ImplicitronModelBase, torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
A simple depth-based image rendering model.
|
A simple depth-based image rendering model.
|
||||||
|
|
||||||
|
|||||||
@@ -53,12 +53,10 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
|||||||
fine rendering pass during training.
|
fine rendering pass during training.
|
||||||
n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the
|
n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the
|
||||||
fine rendering pass during evaluation.
|
fine rendering pass during evaluation.
|
||||||
stratified_sampling_coarse_training: Enable/disable stratified sampling in the
|
stratified_sampling_coarse_training: Enable/disable stratified sampling during
|
||||||
refiner during training. Only matters if there are multiple implicit
|
training.
|
||||||
functions (i.e. in GenericModel if num_passes>1).
|
stratified_sampling_coarse_evaluation: Enable/disable stratified sampling during
|
||||||
stratified_sampling_coarse_evaluation: Enable/disable stratified sampling in
|
evaluation.
|
||||||
the refiner during evaluation. Only matters if there are multiple implicit
|
|
||||||
functions (i.e. in GenericModel if num_passes>1).
|
|
||||||
append_coarse_samples_to_fine: Add the fine ray points to the coarse points
|
append_coarse_samples_to_fine: Add the fine ray points to the coarse points
|
||||||
after sampling.
|
after sampling.
|
||||||
density_noise_std_train: Standard deviation of the noise added to the
|
density_noise_std_train: Standard deviation of the noise added to the
|
||||||
|
|||||||
@@ -218,7 +218,7 @@ class AdaptiveRaySampler(AbstractMaskRaySampler):
|
|||||||
|
|
||||||
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
|
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
|
||||||
"""
|
"""
|
||||||
Returns the adaptively calculated near/far planes.
|
Returns the adaptivelly calculated near/far planes.
|
||||||
"""
|
"""
|
||||||
min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
|
min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
|
||||||
cameras, self._scene_center, self.scene_extent
|
cameras, self._scene_center, self.scene_extent
|
||||||
|
|||||||
@@ -3,16 +3,12 @@
|
|||||||
# implicit_differentiable_renderer.py
|
# implicit_differentiable_renderer.py
|
||||||
# Copyright (c) 2020 Lior Yariv
|
# Copyright (c) 2020 Lior Yariv
|
||||||
import functools
|
import functools
|
||||||
|
import math
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from pytorch3d.common.compat import prod
|
from pytorch3d.implicitron.tools.config import get_default_args_field, registry
|
||||||
from pytorch3d.implicitron.tools.config import (
|
|
||||||
get_default_args_field,
|
|
||||||
registry,
|
|
||||||
run_auto_creation,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.tools.utils import evaluating
|
from pytorch3d.implicitron.tools.utils import evaluating
|
||||||
from pytorch3d.renderer import RayBundle
|
from pytorch3d.renderer import RayBundle
|
||||||
|
|
||||||
@@ -22,10 +18,9 @@ from .rgb_net import RayNormalColoringNetwork
|
|||||||
|
|
||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ignore[13]
|
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
|
||||||
render_features_dimensions: int = 3
|
render_features_dimensions: int = 3
|
||||||
object_bounding_sphere: float = 1.0
|
ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
|
||||||
ray_tracer: RayTracing
|
|
||||||
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
|
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
|
||||||
RayNormalColoringNetwork
|
RayNormalColoringNetwork
|
||||||
)
|
)
|
||||||
@@ -42,7 +37,8 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
|
|||||||
f"Background color should have {render_features_dimensions} entries."
|
f"Background color should have {render_features_dimensions} entries."
|
||||||
)
|
)
|
||||||
|
|
||||||
run_auto_creation(self)
|
self.ray_tracer = RayTracing(**self.ray_tracer_args)
|
||||||
|
self.object_bounding_sphere = self.ray_tracer_args.get("object_bounding_sphere")
|
||||||
|
|
||||||
self.ray_normal_coloring_network_args[
|
self.ray_normal_coloring_network_args[
|
||||||
"feature_vector_size"
|
"feature_vector_size"
|
||||||
@@ -53,17 +49,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
|
|||||||
|
|
||||||
self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)
|
self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def ray_tracer_tweak_args(cls, type, args: DictConfig) -> None:
|
|
||||||
del args["object_bounding_sphere"]
|
|
||||||
|
|
||||||
def create_ray_tracer(self) -> None:
|
|
||||||
self.ray_tracer = RayTracing(
|
|
||||||
# pyre-ignore[32]
|
|
||||||
**self.ray_tracer_args,
|
|
||||||
object_bounding_sphere=self.object_bounding_sphere,
|
|
||||||
)
|
|
||||||
|
|
||||||
def requires_object_mask(self) -> bool:
|
def requires_object_mask(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@@ -105,13 +90,14 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
|
|||||||
|
|
||||||
# object_mask: silhouette of the object
|
# object_mask: silhouette of the object
|
||||||
batch_size, *spatial_size, _ = ray_bundle.lengths.shape
|
batch_size, *spatial_size, _ = ray_bundle.lengths.shape
|
||||||
num_pixels = prod(spatial_size)
|
num_pixels = math.prod(spatial_size)
|
||||||
|
|
||||||
cam_loc = ray_bundle.origins.reshape(batch_size, -1, 3)
|
cam_loc = ray_bundle.origins.reshape(batch_size, -1, 3)
|
||||||
ray_dirs = ray_bundle.directions.reshape(batch_size, -1, 3)
|
ray_dirs = ray_bundle.directions.reshape(batch_size, -1, 3)
|
||||||
object_mask = object_mask.reshape(batch_size, -1)
|
object_mask = object_mask.reshape(batch_size, -1)
|
||||||
|
|
||||||
with torch.no_grad(), evaluating(implicit_function):
|
with torch.no_grad(), evaluating(implicit_function):
|
||||||
|
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||||
points, network_object_mask, dists = self.ray_tracer(
|
points, network_object_mask, dists = self.ray_tracer(
|
||||||
sdf=lambda x: implicit_function(x)[
|
sdf=lambda x: implicit_function(x)[
|
||||||
:, 0
|
:, 0
|
||||||
@@ -142,6 +128,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
|
|||||||
N = surface_points.shape[0]
|
N = surface_points.shape[0]
|
||||||
|
|
||||||
# Sample points for the eikonal loss
|
# Sample points for the eikonal loss
|
||||||
|
# pyre-fixme[9]
|
||||||
eik_bounding_box: float = self.object_bounding_sphere
|
eik_bounding_box: float = self.object_bounding_sphere
|
||||||
n_eik_points = batch_size * num_pixels // 2
|
n_eik_points = batch_size * num_pixels // 2
|
||||||
eikonal_points = torch.empty(
|
eikonal_points = torch.empty(
|
||||||
|
|||||||
@@ -881,42 +881,6 @@ def get_default_args_field(
|
|||||||
def create():
|
def create():
|
||||||
args = get_default_args(C, _do_not_process=_do_not_process)
|
args = get_default_args(C, _do_not_process=_do_not_process)
|
||||||
if _hook is not None:
|
if _hook is not None:
|
||||||
with open_dict(args):
|
|
||||||
_hook(args)
|
|
||||||
return args
|
|
||||||
|
|
||||||
return dataclasses.field(default_factory=create)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_default_args_field_from_registry(
|
|
||||||
*,
|
|
||||||
base_class_wanted: Type[_X],
|
|
||||||
name: str,
|
|
||||||
_do_not_process: Tuple[type, ...] = (),
|
|
||||||
_hook: Optional[Callable[[DictConfig], None]] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Get a dataclass field which defaults to
|
|
||||||
get_default_args(registry.get(base_class_wanted, name)).
|
|
||||||
|
|
||||||
This is used internally in place of get_default_args_field in
|
|
||||||
order that default values are updated if a class is redefined.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
base_class_wanted: As for registry.get.
|
|
||||||
name: As for registry.get.
|
|
||||||
_do_not_process: As for get_default_args
|
|
||||||
_hook: Function called on the result before returning.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
function to return new DictConfig object
|
|
||||||
"""
|
|
||||||
|
|
||||||
def create():
|
|
||||||
C = registry.get(base_class_wanted=base_class_wanted, name=name)
|
|
||||||
args = get_default_args(C, _do_not_process=_do_not_process)
|
|
||||||
if _hook is not None:
|
|
||||||
with open_dict(args):
|
|
||||||
_hook(args)
|
_hook(args)
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@@ -1014,9 +978,8 @@ def _process_member(
|
|||||||
setattr(
|
setattr(
|
||||||
some_class,
|
some_class,
|
||||||
args_name,
|
args_name,
|
||||||
_get_default_args_field_from_registry(
|
get_default_args_field(
|
||||||
base_class_wanted=type_,
|
derived_type,
|
||||||
name=derived_type.__name__,
|
|
||||||
_do_not_process=_do_not_process + (some_class,),
|
_do_not_process=_do_not_process + (some_class,),
|
||||||
_hook=hook_closed,
|
_hook=hook_closed,
|
||||||
),
|
),
|
||||||
|
|||||||
@@ -74,7 +74,6 @@ class Stats(object):
|
|||||||
"""
|
"""
|
||||||
stats logging object useful for gathering statistics of training a deep net in pytorch
|
stats logging object useful for gathering statistics of training a deep net in pytorch
|
||||||
Example:
|
Example:
|
||||||
```
|
|
||||||
# init stats structure that logs statistics 'objective' and 'top1e'
|
# init stats structure that logs statistics 'objective' and 'top1e'
|
||||||
stats = Stats( ('objective','top1e') )
|
stats = Stats( ('objective','top1e') )
|
||||||
network = init_net() # init a pytorch module (=nueral network)
|
network = init_net() # init a pytorch module (=nueral network)
|
||||||
@@ -95,7 +94,6 @@ class Stats(object):
|
|||||||
# stores the training plots into '/tmp/epoch_stats.pdf'
|
# stores the training plots into '/tmp/epoch_stats.pdf'
|
||||||
# and plots into a visdom server running at localhost (if running)
|
# and plots into a visdom server running at localhost (if running)
|
||||||
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
|
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
|
||||||
```
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -198,6 +196,7 @@ class Stats(object):
|
|||||||
if verbose:
|
if verbose:
|
||||||
print(f"Adding {add_log_var}")
|
print(f"Adding {add_log_var}")
|
||||||
self.log_vars.append(add_log_var)
|
self.log_vars.append(add_log_var)
|
||||||
|
# self.synchronize_logged_vars(self.log_vars, verbose=verbose)
|
||||||
|
|
||||||
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
|
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
|
||||||
|
|
||||||
@@ -229,6 +228,7 @@ class Stats(object):
|
|||||||
elapsed = time.time() - time_start
|
elapsed = time.time() - time_start
|
||||||
time_per_it = float(elapsed) / float(it + 1)
|
time_per_it = float(elapsed) / float(it + 1)
|
||||||
val = time_per_it
|
val = time_per_it
|
||||||
|
# self.stats[stat_set]['sec/it'].update(time_per_it,epoch=epoch,n=1)
|
||||||
else:
|
else:
|
||||||
if stat in preds:
|
if stat in preds:
|
||||||
try:
|
try:
|
||||||
@@ -439,6 +439,7 @@ class Stats(object):
|
|||||||
self.log_vars = log_vars # !!!
|
self.log_vars = log_vars # !!!
|
||||||
|
|
||||||
for stat_set in stat_sets:
|
for stat_set in stat_sets:
|
||||||
|
reference_stat = list(self.stats[stat_set].keys())[0]
|
||||||
for stat in log_vars:
|
for stat in log_vars:
|
||||||
if stat not in self.stats[stat_set]:
|
if stat not in self.stats[stat_set]:
|
||||||
if verbose:
|
if verbose:
|
||||||
@@ -465,11 +466,12 @@ class Stats(object):
|
|||||||
lastep = self.epoch + 1
|
lastep = self.epoch + 1
|
||||||
for ep in range(lastep):
|
for ep in range(lastep):
|
||||||
self.stats[stat_set][stat].update(default_val, n=1, epoch=ep)
|
self.stats[stat_set][stat].update(default_val, n=1, epoch=ep)
|
||||||
|
epoch_self = self.stats[stat_set][reference_stat].get_epoch()
|
||||||
epoch_generated = self.stats[stat_set][stat].get_epoch()
|
epoch_generated = self.stats[stat_set][stat].get_epoch()
|
||||||
assert (
|
assert (
|
||||||
epoch_generated == self.epoch + 1
|
epoch_self == epoch_generated
|
||||||
), "bad epoch of synchronized log_var! %d vs %d" % (
|
), "bad epoch of synchronized log_var! %d vs %d" % (
|
||||||
self.epoch + 1,
|
epoch_self,
|
||||||
epoch_generated,
|
epoch_generated,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -157,6 +157,15 @@ def cat_dataclass(batch, tensor_collator: Callable):
|
|||||||
return type(elem)(**collated)
|
return type(elem)(**collated)
|
||||||
|
|
||||||
|
|
||||||
|
def setattr_if_hasattr(obj, name, value):
|
||||||
|
"""
|
||||||
|
Same as setattr(obj, name, value), but does nothing in case `name` is
|
||||||
|
not an attribe of `obj`.
|
||||||
|
"""
|
||||||
|
if hasattr(obj, name):
|
||||||
|
setattr(obj, name, value)
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
class Timer:
|
||||||
"""
|
"""
|
||||||
A simple class for timing execution.
|
A simple class for timing execution.
|
||||||
|
|||||||
@@ -84,6 +84,8 @@ class VideoWriter:
|
|||||||
or a 2-tuple defining the size of the output image.
|
or a 2-tuple defining the size of the output image.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[6]: For 1st param expected `Union[PathLike[str], str]` but got
|
||||||
|
# `Optional[str]`.
|
||||||
outfile = os.path.join(self.cache_dir, self.regexp % self.frame_num)
|
outfile = os.path.join(self.cache_dir, self.regexp % self.frame_num)
|
||||||
|
|
||||||
if isinstance(frame, matplotlib.figure.Figure):
|
if isinstance(frame, matplotlib.figure.Figure):
|
||||||
@@ -125,6 +127,8 @@ class VideoWriter:
|
|||||||
video_path: The path to the generated video.
|
video_path: The path to the generated video.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# pyre-fixme[6]: For 1st param expected `Union[PathLike[str], str]` but got
|
||||||
|
# `Optional[str]`.
|
||||||
regexp = os.path.join(self.cache_dir, self.regexp)
|
regexp = os.path.join(self.cache_dir, self.regexp)
|
||||||
|
|
||||||
if self.output_format == "visdom": # works for ppt too
|
if self.output_format == "visdom": # works for ppt too
|
||||||
|
|||||||
@@ -14,22 +14,20 @@ from visdom import Visdom
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_visdom_env(visdom_env: str, exp_dir: str) -> str:
|
def get_visdom_env(cfg):
|
||||||
"""
|
"""
|
||||||
Parse out visdom environment name from the input config.
|
Parse out visdom environment name from the input config.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
visdom_env: Name of the wisdom environment, could be empty string.
|
cfg: The global config file.
|
||||||
exp_dir: Root experiment directory.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
visdom_env: The name of the visdom environment. If the given visdom_env is
|
visdom_env: The name of the visdom environment.
|
||||||
empty, return the name of the bottom directory in exp_dir.
|
|
||||||
"""
|
"""
|
||||||
if len(visdom_env) == 0:
|
if len(cfg.visdom_env) == 0:
|
||||||
visdom_env = exp_dir.split("/")[-1]
|
visdom_env = cfg.exp_dir.split("/")[-1]
|
||||||
else:
|
else:
|
||||||
visdom_env = visdom_env
|
visdom_env = cfg.visdom_env
|
||||||
return visdom_env
|
return visdom_env
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -215,6 +215,8 @@ def load_obj(
|
|||||||
"""
|
"""
|
||||||
data_dir = "./"
|
data_dir = "./"
|
||||||
if isinstance(f, (str, bytes, Path)):
|
if isinstance(f, (str, bytes, Path)):
|
||||||
|
# pyre-fixme[6]: For 1st param expected `PathLike[Variable[AnyStr <: [str,
|
||||||
|
# bytes]]]` but got `Union[Path, bytes, str]`.
|
||||||
data_dir = os.path.dirname(f)
|
data_dir = os.path.dirname(f)
|
||||||
if path_manager is None:
|
if path_manager is None:
|
||||||
path_manager = PathManager()
|
path_manager = PathManager()
|
||||||
|
|||||||
@@ -65,6 +65,11 @@ from .mesh import (
|
|||||||
TexturesVertex,
|
TexturesVertex,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from .opengl import EGLContext, global_device_context_store, MeshRasterizerOpenGL
|
||||||
|
except (ImportError, ModuleNotFoundError):
|
||||||
|
pass # opengl or pycuda.gl not available, or pytorch3_opengl not in TARGETS.
|
||||||
|
|
||||||
from .points import (
|
from .points import (
|
||||||
AlphaCompositor,
|
AlphaCompositor,
|
||||||
NormWeightedCompositor,
|
NormWeightedCompositor,
|
||||||
|
|||||||
@@ -1661,9 +1661,9 @@ def look_at_rotation(
|
|||||||
|
|
||||||
|
|
||||||
def look_at_view_transform(
|
def look_at_view_transform(
|
||||||
dist: _BatchFloatType = 1.0,
|
dist: float = 1.0,
|
||||||
elev: _BatchFloatType = 0.0,
|
elev: float = 0.0,
|
||||||
azim: _BatchFloatType = 0.0,
|
azim: float = 0.0,
|
||||||
degrees: bool = True,
|
degrees: bool = True,
|
||||||
eye: Optional[Union[Sequence, torch.Tensor]] = None,
|
eye: Optional[Union[Sequence, torch.Tensor]] = None,
|
||||||
at=((0, 0, 0),), # (1, 3)
|
at=((0, 0, 0),), # (1, 3)
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from ...structures.meshes import Meshes
|
from ...structures.meshes import Meshes
|
||||||
|
from .rasterizer import MeshRasterizer
|
||||||
|
|
||||||
|
|
||||||
# A renderer class should be initialized with a
|
# A renderer class should be initialized with a
|
||||||
# function for rasterization and a function for shading.
|
# function for rasterization and a function for shading.
|
||||||
@@ -30,11 +32,11 @@ from ...structures.meshes import Meshes
|
|||||||
class MeshRenderer(nn.Module):
|
class MeshRenderer(nn.Module):
|
||||||
"""
|
"""
|
||||||
A class for rendering a batch of heterogeneous meshes. The class should
|
A class for rendering a batch of heterogeneous meshes. The class should
|
||||||
be initialized with a rasterizer (a MeshRasterizer or a MeshRasterizerOpenGL)
|
be initialized with a rasterizer and shader class which each have a forward
|
||||||
and shader class which each have a forward function.
|
function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, rasterizer, shader) -> None:
|
def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rasterizer = rasterizer
|
self.rasterizer = rasterizer
|
||||||
self.shader = shader
|
self.shader = shader
|
||||||
@@ -67,8 +69,8 @@ class MeshRenderer(nn.Module):
|
|||||||
class MeshRendererWithFragments(nn.Module):
|
class MeshRendererWithFragments(nn.Module):
|
||||||
"""
|
"""
|
||||||
A class for rendering a batch of heterogeneous meshes. The class should
|
A class for rendering a batch of heterogeneous meshes. The class should
|
||||||
be initialized with a rasterizer (a MeshRasterizer or a MeshRasterizerOpenGL)
|
be initialized with a rasterizer and shader class which each have a forward
|
||||||
and shader class which each have a forward function.
|
function.
|
||||||
|
|
||||||
In the forward pass this class returns the `fragments` from which intermediate
|
In the forward pass this class returns the `fragments` from which intermediate
|
||||||
values such as the depth map can be easily extracted e.g.
|
values such as the depth map can be easily extracted e.g.
|
||||||
@@ -78,7 +80,7 @@ class MeshRendererWithFragments(nn.Module):
|
|||||||
depth = fragments.zbuf
|
depth = fragments.zbuf
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, rasterizer, shader) -> None:
|
def __init__(self, rasterizer: MeshRasterizer, shader) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.rasterizer = rasterizer
|
self.rasterizer = rasterizer
|
||||||
self.shader = shader
|
self.shader = shader
|
||||||
|
|||||||
@@ -130,9 +130,8 @@ class MeshRasterizerOpenGL(nn.Module):
|
|||||||
|
|
||||||
Fragments output by MeshRasterizerOpenGL and MeshRasterizer should have near
|
Fragments output by MeshRasterizerOpenGL and MeshRasterizer should have near
|
||||||
identical pix_to_face, bary_coords and zbuf. However, MeshRasterizerOpenGL does not
|
identical pix_to_face, bary_coords and zbuf. However, MeshRasterizerOpenGL does not
|
||||||
return Fragments.dists which is only relevant to SoftPhongShader and
|
return Fragments.dists which is only relevant to SoftPhongShader which doesn't work
|
||||||
SoftSilhouetteShader. These do not work with MeshRasterizerOpenGL (because it is
|
with MeshRasterizerOpenGL (because it is not differentiable).
|
||||||
not differentiable).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
@@ -474,12 +474,6 @@ class Renderer(torch.nn.Module):
|
|||||||
rot_mat = axis_angle_to_matrix(rot_vec)
|
rot_mat = axis_angle_to_matrix(rot_vec)
|
||||||
if first_R_then_T:
|
if first_R_then_T:
|
||||||
pos_vec = torch.matmul(rot_mat, pos_vec[..., None])[:, :, 0]
|
pos_vec = torch.matmul(rot_mat, pos_vec[..., None])[:, :, 0]
|
||||||
LOGGER.debug(
|
|
||||||
"Camera position: %s, rotation: %s. Focal length: %s.",
|
|
||||||
str(pos_vec),
|
|
||||||
str(rot_vec),
|
|
||||||
str(focal_length),
|
|
||||||
)
|
|
||||||
sensor_dir_x = torch.matmul(
|
sensor_dir_x = torch.matmul(
|
||||||
rot_mat,
|
rot_mat,
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
@@ -500,56 +494,20 @@ class Renderer(torch.nn.Module):
|
|||||||
)[:, :, 0]
|
)[:, :, 0]
|
||||||
if right_handed:
|
if right_handed:
|
||||||
sensor_dir_z *= -1
|
sensor_dir_z *= -1
|
||||||
LOGGER.debug(
|
|
||||||
"Sensor direction vectors: %s, %s, %s.",
|
|
||||||
str(sensor_dir_x),
|
|
||||||
str(sensor_dir_y),
|
|
||||||
str(sensor_dir_z),
|
|
||||||
)
|
|
||||||
if orthogonal:
|
if orthogonal:
|
||||||
sensor_center = pos_vec
|
sensor_center = pos_vec
|
||||||
else:
|
else:
|
||||||
sensor_center = pos_vec + focal_length * sensor_dir_z
|
sensor_center = pos_vec + focal_length * sensor_dir_z
|
||||||
LOGGER.debug("Sensor center: %s.", str(sensor_center))
|
|
||||||
sensor_luc = ( # Sensor left upper corner.
|
sensor_luc = ( # Sensor left upper corner.
|
||||||
sensor_center
|
sensor_center
|
||||||
- sensor_dir_x * (sensor_size_x / 2.0)
|
- sensor_dir_x * (sensor_size_x / 2.0)
|
||||||
- sensor_dir_y * (sensor_size_y / 2.0)
|
- sensor_dir_y * (sensor_size_y / 2.0)
|
||||||
)
|
)
|
||||||
LOGGER.debug("Sensor luc: %s.", str(sensor_luc))
|
|
||||||
pixel_size_x = sensor_size_x / float(width)
|
pixel_size_x = sensor_size_x / float(width)
|
||||||
pixel_size_y = sensor_size_y / float(height)
|
pixel_size_y = sensor_size_y / float(height)
|
||||||
LOGGER.debug(
|
|
||||||
"Pixel sizes (x): %s, (y) %s.", str(pixel_size_x), str(pixel_size_y)
|
|
||||||
)
|
|
||||||
pixel_vec_x: torch.Tensor = sensor_dir_x * pixel_size_x
|
pixel_vec_x: torch.Tensor = sensor_dir_x * pixel_size_x
|
||||||
pixel_vec_y: torch.Tensor = sensor_dir_y * pixel_size_y
|
pixel_vec_y: torch.Tensor = sensor_dir_y * pixel_size_y
|
||||||
pixel_0_0_center = sensor_luc + 0.5 * pixel_vec_x + 0.5 * pixel_vec_y
|
pixel_0_0_center = sensor_luc + 0.5 * pixel_vec_x + 0.5 * pixel_vec_y
|
||||||
LOGGER.debug(
|
|
||||||
"Pixel 0 centers: %s, vec x: %s, vec y: %s.",
|
|
||||||
str(pixel_0_0_center),
|
|
||||||
str(pixel_vec_x),
|
|
||||||
str(pixel_vec_y),
|
|
||||||
)
|
|
||||||
if not orthogonal:
|
|
||||||
LOGGER.debug(
|
|
||||||
"Camera horizontal fovs: %s deg.",
|
|
||||||
str(
|
|
||||||
2.0
|
|
||||||
* torch.atan(0.5 * sensor_size_x / focal_length)
|
|
||||||
/ math.pi
|
|
||||||
* 180.0
|
|
||||||
),
|
|
||||||
)
|
|
||||||
LOGGER.debug(
|
|
||||||
"Camera vertical fovs: %s deg.",
|
|
||||||
str(
|
|
||||||
2.0
|
|
||||||
* torch.atan(0.5 * sensor_size_y / focal_length)
|
|
||||||
/ math.pi
|
|
||||||
* 180.0
|
|
||||||
),
|
|
||||||
)
|
|
||||||
# Reduce dimension.
|
# Reduce dimension.
|
||||||
focal_length: torch.Tensor = focal_length[:, 0]
|
focal_length: torch.Tensor = focal_length[:, 0]
|
||||||
if batch_processing:
|
if batch_processing:
|
||||||
|
|||||||
@@ -323,6 +323,7 @@ def random_quaternions(
|
|||||||
"""
|
"""
|
||||||
if isinstance(device, str):
|
if isinstance(device, str):
|
||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
# pyre-fixme[6]: For 2nd param expected `dtype` but got `Optional[dtype]`.
|
||||||
o = torch.randn((n, 4), dtype=dtype, device=device)
|
o = torch.randn((n, 4), dtype=dtype, device=device)
|
||||||
s = (o * o).sum(1)
|
s = (o * o).sum(1)
|
||||||
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
||||||
|
|||||||
21
setup.py
21
setup.py
@@ -8,7 +8,6 @@
|
|||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
import runpy
|
import runpy
|
||||||
import sys
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
@@ -36,13 +35,6 @@ def get_existing_ccbin(nvcc_args: List[str]) -> Optional[str]:
|
|||||||
|
|
||||||
|
|
||||||
def get_extensions():
|
def get_extensions():
|
||||||
no_extension = os.getenv("PYTORCH3D_NO_EXTENSION", "0") == "1"
|
|
||||||
if no_extension:
|
|
||||||
msg = "SKIPPING EXTENSION BUILD. PYTORCH3D WILL NOT WORK!"
|
|
||||||
print(msg, file=sys.stderr)
|
|
||||||
warnings.warn(msg)
|
|
||||||
return []
|
|
||||||
|
|
||||||
this_dir = os.path.dirname(os.path.abspath(__file__))
|
this_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
extensions_dir = os.path.join(this_dir, "pytorch3d", "csrc")
|
extensions_dir = os.path.join(this_dir, "pytorch3d", "csrc")
|
||||||
sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)
|
sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"), recursive=True)
|
||||||
@@ -54,10 +46,7 @@ def get_extensions():
|
|||||||
include_dirs = [extensions_dir]
|
include_dirs = [extensions_dir]
|
||||||
|
|
||||||
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
|
force_cuda = os.getenv("FORCE_CUDA", "0") == "1"
|
||||||
force_no_cuda = os.getenv("PYTORCH3D_FORCE_NO_CUDA", "0") == "1"
|
if (torch.cuda.is_available() and CUDA_HOME is not None) or force_cuda:
|
||||||
if (
|
|
||||||
not force_no_cuda and torch.cuda.is_available() and CUDA_HOME is not None
|
|
||||||
) or force_cuda:
|
|
||||||
extension = CUDAExtension
|
extension = CUDAExtension
|
||||||
sources += source_cuda
|
sources += source_cuda
|
||||||
define_macros += [("WITH_CUDA", None)]
|
define_macros += [("WITH_CUDA", None)]
|
||||||
@@ -139,7 +128,7 @@ if os.getenv("PYTORCH3D_NO_NINJA", "0") == "1":
|
|||||||
else:
|
else:
|
||||||
BuildExtension = torch.utils.cpp_extension.BuildExtension
|
BuildExtension = torch.utils.cpp_extension.BuildExtension
|
||||||
|
|
||||||
trainer = "pytorch3d.implicitron_trainer"
|
trainer = "projects.implicitron_trainer"
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="pytorch3d",
|
name="pytorch3d",
|
||||||
@@ -149,10 +138,8 @@ setup(
|
|||||||
description="PyTorch3D is FAIR's library of reusable components "
|
description="PyTorch3D is FAIR's library of reusable components "
|
||||||
"for deep Learning with 3D data.",
|
"for deep Learning with 3D data.",
|
||||||
packages=find_packages(
|
packages=find_packages(
|
||||||
exclude=("configs", "tests", "tests.*", "docs.*", "projects.*")
|
exclude=("configs", "tests", "tests.*", "docs.*", "projects.nerf.*")
|
||||||
)
|
),
|
||||||
+ [trainer],
|
|
||||||
package_dir={trainer: "projects/implicitron_trainer"},
|
|
||||||
install_requires=["fvcore", "iopath"],
|
install_requires=["fvcore", "iopath"],
|
||||||
extras_require={
|
extras_require={
|
||||||
"all": ["matplotlib", "tqdm>4.29.0", "imageio", "ipywidgets"],
|
"all": ["matplotlib", "tqdm>4.29.0", "imageio", "ipywidgets"],
|
||||||
|
|||||||
@@ -2,6 +2,6 @@
|
|||||||
|
|
||||||
This is copied version of docs/tutorials/data/cow_mesh with removed line 6159 (usemtl material_1) to test behavior without usemtl material_1 declaration.
|
This is copied version of docs/tutorials/data/cow_mesh with removed line 6159 (usemtl material_1) to test behavior without usemtl material_1 declaration.
|
||||||
|
|
||||||
Thank you to Keenan Crane for allowing the cow mesh model to be used freely in the public domain.
|
Thank you to Keenen Crane for allowing the cow mesh model to be used freely in the public domain.
|
||||||
|
|
||||||
###### Source: http://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/
|
###### Source: http://www.cs.cmu.edu/~kmcrane/Projects/ModelRepository/
|
||||||
|
|||||||
@@ -52,11 +52,17 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
|
|||||||
dataset_class_type: JsonIndexDataset
|
dataset_class_type: JsonIndexDataset
|
||||||
path_manager_factory_class_type: PathManagerFactory
|
path_manager_factory_class_type: PathManagerFactory
|
||||||
dataset_JsonIndexDataset_args:
|
dataset_JsonIndexDataset_args:
|
||||||
|
path_manager: null
|
||||||
|
frame_annotations_file: ''
|
||||||
|
sequence_annotations_file: ''
|
||||||
|
subset_lists_file: ''
|
||||||
|
subsets: null
|
||||||
limit_to: 0
|
limit_to: 0
|
||||||
limit_sequences_to: 0
|
limit_sequences_to: 0
|
||||||
pick_sequence: []
|
pick_sequence: []
|
||||||
exclude_sequence: []
|
exclude_sequence: []
|
||||||
limit_category_to: []
|
limit_category_to: []
|
||||||
|
dataset_root: ''
|
||||||
load_images: true
|
load_images: true
|
||||||
load_depths: true
|
load_depths: true
|
||||||
load_depth_masks: true
|
load_depth_masks: true
|
||||||
@@ -74,6 +80,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
|
|||||||
n_frames_per_sequence: -1
|
n_frames_per_sequence: -1
|
||||||
seed: 0
|
seed: 0
|
||||||
sort_frames: false
|
sort_frames: false
|
||||||
|
eval_batches: null
|
||||||
path_manager_factory_PathManagerFactory_args:
|
path_manager_factory_PathManagerFactory_args:
|
||||||
silence_logs: true
|
silence_logs: true
|
||||||
dataset_map_provider_LlffDatasetMapProvider_args:
|
dataset_map_provider_LlffDatasetMapProvider_args:
|
||||||
@@ -83,16 +90,6 @@ dataset_map_provider_LlffDatasetMapProvider_args:
|
|||||||
n_known_frames_for_test: null
|
n_known_frames_for_test: null
|
||||||
path_manager_factory_PathManagerFactory_args:
|
path_manager_factory_PathManagerFactory_args:
|
||||||
silence_logs: true
|
silence_logs: true
|
||||||
downscale_factor: 4
|
|
||||||
dataset_map_provider_RenderedMeshDatasetMapProvider_args:
|
|
||||||
num_views: 40
|
|
||||||
data_file: null
|
|
||||||
azimuth_range: 180.0
|
|
||||||
resolution: 128
|
|
||||||
use_point_light: true
|
|
||||||
path_manager_factory_class_type: PathManagerFactory
|
|
||||||
path_manager_factory_PathManagerFactory_args:
|
|
||||||
silence_logs: true
|
|
||||||
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||||
batch_size: 1
|
batch_size: 1
|
||||||
num_workers: 0
|
num_workers: 0
|
||||||
@@ -106,9 +103,3 @@ data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
|||||||
sample_consecutive_frames: false
|
sample_consecutive_frames: false
|
||||||
consecutive_frames_max_gap: 0
|
consecutive_frames_max_gap: 0
|
||||||
consecutive_frames_max_gap_seconds: 0.1
|
consecutive_frames_max_gap_seconds: 0.1
|
||||||
data_loader_map_provider_SimpleDataLoaderMapProvider_args:
|
|
||||||
batch_size: 1
|
|
||||||
num_workers: 0
|
|
||||||
dataset_length_train: 0
|
|
||||||
dataset_length_val: 0
|
|
||||||
dataset_length_test: 0
|
|
||||||
|
|||||||
@@ -1,24 +1,3 @@
|
|||||||
log_vars:
|
|
||||||
- loss_rgb_psnr_fg
|
|
||||||
- loss_rgb_psnr
|
|
||||||
- loss_rgb_mse
|
|
||||||
- loss_rgb_huber
|
|
||||||
- loss_depth_abs
|
|
||||||
- loss_depth_abs_fg
|
|
||||||
- loss_mask_neg_iou
|
|
||||||
- loss_mask_bce
|
|
||||||
- loss_mask_beta_prior
|
|
||||||
- loss_eikonal
|
|
||||||
- loss_density_tv
|
|
||||||
- loss_depth_neg_penalty
|
|
||||||
- loss_autodecoder_norm
|
|
||||||
- loss_prev_stage_rgb_mse
|
|
||||||
- loss_prev_stage_rgb_psnr_fg
|
|
||||||
- loss_prev_stage_rgb_psnr
|
|
||||||
- loss_prev_stage_mask_bce
|
|
||||||
- objective
|
|
||||||
- epoch
|
|
||||||
- sec/it
|
|
||||||
mask_images: true
|
mask_images: true
|
||||||
mask_depths: true
|
mask_depths: true
|
||||||
render_image_width: 400
|
render_image_width: 400
|
||||||
@@ -49,13 +28,38 @@ loss_weights:
|
|||||||
loss_prev_stage_rgb_mse: 1.0
|
loss_prev_stage_rgb_mse: 1.0
|
||||||
loss_mask_bce: 0.0
|
loss_mask_bce: 0.0
|
||||||
loss_prev_stage_mask_bce: 0.0
|
loss_prev_stage_mask_bce: 0.0
|
||||||
|
log_vars:
|
||||||
|
- loss_rgb_psnr_fg
|
||||||
|
- loss_rgb_psnr
|
||||||
|
- loss_rgb_mse
|
||||||
|
- loss_rgb_huber
|
||||||
|
- loss_depth_abs
|
||||||
|
- loss_depth_abs_fg
|
||||||
|
- loss_mask_neg_iou
|
||||||
|
- loss_mask_bce
|
||||||
|
- loss_mask_beta_prior
|
||||||
|
- loss_eikonal
|
||||||
|
- loss_density_tv
|
||||||
|
- loss_depth_neg_penalty
|
||||||
|
- loss_autodecoder_norm
|
||||||
|
- loss_prev_stage_rgb_mse
|
||||||
|
- loss_prev_stage_rgb_psnr_fg
|
||||||
|
- loss_prev_stage_rgb_psnr
|
||||||
|
- loss_prev_stage_mask_bce
|
||||||
|
- objective
|
||||||
|
- epoch
|
||||||
|
- sec/it
|
||||||
global_encoder_SequenceAutodecoder_args:
|
global_encoder_SequenceAutodecoder_args:
|
||||||
autodecoder_args:
|
autodecoder_args:
|
||||||
encoding_dim: 0
|
encoding_dim: 0
|
||||||
n_instances: 1
|
n_instances: 0
|
||||||
init_scale: 1.0
|
init_scale: 1.0
|
||||||
ignore_input: false
|
ignore_input: false
|
||||||
raysampler_AdaptiveRaySampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
|
image_width: 400
|
||||||
|
image_height: 400
|
||||||
|
sampling_mode_training: mask_sample
|
||||||
|
sampling_mode_evaluation: full_grid
|
||||||
n_pts_per_ray_training: 64
|
n_pts_per_ray_training: 64
|
||||||
n_pts_per_ray_evaluation: 64
|
n_pts_per_ray_evaluation: 64
|
||||||
n_rays_per_image_sampled_from_mask: 1024
|
n_rays_per_image_sampled_from_mask: 1024
|
||||||
@@ -103,6 +107,7 @@ view_pooler_args:
|
|||||||
weight_by_ray_angle_gamma: 1.0
|
weight_by_ray_angle_gamma: 1.0
|
||||||
min_ray_angle_weight: 0.1
|
min_ray_angle_weight: 0.1
|
||||||
implicit_function_IdrFeatureField_args:
|
implicit_function_IdrFeatureField_args:
|
||||||
|
feature_vector_size: 3
|
||||||
d_in: 3
|
d_in: 3
|
||||||
d_out: 1
|
d_out: 1
|
||||||
dims:
|
dims:
|
||||||
@@ -120,5 +125,6 @@ implicit_function_IdrFeatureField_args:
|
|||||||
weight_norm: true
|
weight_norm: true
|
||||||
n_harmonic_functions_xyz: 1729
|
n_harmonic_functions_xyz: 1729
|
||||||
pooled_feature_dim: 0
|
pooled_feature_dim: 0
|
||||||
|
encoding_dim: 0
|
||||||
view_metrics_ViewMetrics_args: {}
|
view_metrics_ViewMetrics_args: {}
|
||||||
regularization_metrics_RegularizationMetrics_args: {}
|
regularization_metrics_RegularizationMetrics_args: {}
|
||||||
|
|||||||
@@ -378,20 +378,14 @@ class TestConfig(unittest.TestCase):
|
|||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning, "New implementation of Grape is being chosen."
|
UserWarning, "New implementation of Grape is being chosen."
|
||||||
):
|
):
|
||||||
defaulted_bowl = FruitBowl()
|
bowl = FruitBowl(**bowl_args)
|
||||||
self.assertIsInstance(defaulted_bowl.main_fruit, Grape)
|
self.assertIsInstance(bowl.main_fruit, Grape)
|
||||||
self.assertEqual(defaulted_bowl.main_fruit.large, True)
|
|
||||||
self.assertEqual(defaulted_bowl.main_fruit.get_color(), "green")
|
|
||||||
|
|
||||||
with self.assertWarnsRegex(
|
|
||||||
UserWarning, "New implementation of Grape is being chosen."
|
|
||||||
):
|
|
||||||
args_bowl = FruitBowl(**bowl_args)
|
|
||||||
self.assertIsInstance(args_bowl.main_fruit, Grape)
|
|
||||||
# Redefining the same class won't help with defaults because encoded in args
|
# Redefining the same class won't help with defaults because encoded in args
|
||||||
self.assertEqual(args_bowl.main_fruit.large, False)
|
self.assertEqual(bowl.main_fruit.large, False)
|
||||||
|
|
||||||
# But the override worked.
|
# But the override worked.
|
||||||
self.assertEqual(args_bowl.main_fruit.get_color(), "green")
|
self.assertEqual(bowl.main_fruit.get_color(), "green")
|
||||||
|
|
||||||
# 2. Try redefining without the dataclass modifier
|
# 2. Try redefining without the dataclass modifier
|
||||||
# This relies on the fact that default creation processes the class.
|
# This relies on the fact that default creation processes the class.
|
||||||
@@ -403,7 +397,7 @@ class TestConfig(unittest.TestCase):
|
|||||||
with self.assertWarnsRegex(
|
with self.assertWarnsRegex(
|
||||||
UserWarning, "New implementation of Grape is being chosen."
|
UserWarning, "New implementation of Grape is being chosen."
|
||||||
):
|
):
|
||||||
FruitBowl(**bowl_args)
|
bowl = FruitBowl(**bowl_args)
|
||||||
|
|
||||||
# 3. Adding a new class doesn't get picked up, because the first
|
# 3. Adding a new class doesn't get picked up, because the first
|
||||||
# get_default_args call has frozen FruitBowl. This is intrinsic to
|
# get_default_args call has frozen FruitBowl. This is intrinsic to
|
||||||
@@ -691,17 +685,12 @@ class TestConfig(unittest.TestCase):
|
|||||||
fruit2_class_type: str = "Pear"
|
fruit2_class_type: str = "Pear"
|
||||||
a: A
|
a: A
|
||||||
a2: A
|
a2: A
|
||||||
a3: A
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def a_tweak_args(cls, type, args):
|
def a_tweak_args(cls, type, args):
|
||||||
assert type == A
|
assert type == A
|
||||||
args.n = 993
|
args.n = 993
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def a3_tweak_args(cls, type, args):
|
|
||||||
del args["n"]
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def fruit_tweak_args(cls, type, args):
|
def fruit_tweak_args(cls, type, args):
|
||||||
assert issubclass(type, Fruit)
|
assert issubclass(type, Fruit)
|
||||||
@@ -712,7 +701,6 @@ class TestConfig(unittest.TestCase):
|
|||||||
args = get_default_args(Wrapper)
|
args = get_default_args(Wrapper)
|
||||||
self.assertEqual(args.a_args.n, 993)
|
self.assertEqual(args.a_args.n, 993)
|
||||||
self.assertEqual(args.a2_args.n, 9)
|
self.assertEqual(args.a2_args.n, 9)
|
||||||
self.assertEqual(args.a3_args, {})
|
|
||||||
self.assertEqual(args.fruit_Pear_args.n_pips, 19)
|
self.assertEqual(args.fruit_Pear_args.n_pips, 19)
|
||||||
self.assertEqual(args.fruit2_Pear_args.n_pips, 13)
|
self.assertEqual(args.fruit2_Pear_args.n_pips, 13)
|
||||||
|
|
||||||
|
|||||||
@@ -90,5 +90,5 @@ class TestGenericModel(unittest.TestCase):
|
|||||||
remove_unused_components(instance_args)
|
remove_unused_components(instance_args)
|
||||||
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
(DATA_DIR / "overrides_.yaml").write_text(yaml)
|
(DATA_DIR / "overrides.yaml_").write_text(yaml)
|
||||||
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())
|
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())
|
||||||
|
|||||||
@@ -1,57 +0,0 @@
|
|||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the BSD-style license found in the
|
|
||||||
# LICENSE file in the root directory of this source tree.
|
|
||||||
|
|
||||||
import os
|
|
||||||
import unittest
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
|
||||||
from pytorch3d.implicitron.dataset.rendered_mesh_dataset_map_provider import (
|
|
||||||
RenderedMeshDatasetMapProvider,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
|
||||||
from pytorch3d.renderer import FoVPerspectiveCameras
|
|
||||||
from tests.common_testing import TestCaseMixin
|
|
||||||
|
|
||||||
|
|
||||||
inside_re_worker = os.environ.get("INSIDE_RE_WORKER", False)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDataCow(TestCaseMixin, unittest.TestCase):
|
|
||||||
def test_simple(self):
|
|
||||||
if inside_re_worker:
|
|
||||||
return
|
|
||||||
expand_args_fields(RenderedMeshDatasetMapProvider)
|
|
||||||
self._runtest(use_point_light=True, num_views=4)
|
|
||||||
self._runtest(use_point_light=False, num_views=4)
|
|
||||||
|
|
||||||
def _runtest(self, **kwargs):
|
|
||||||
provider = RenderedMeshDatasetMapProvider(**kwargs)
|
|
||||||
dataset_map = provider.get_dataset_map()
|
|
||||||
known_matrix = torch.zeros(1, 4, 4)
|
|
||||||
known_matrix[0, 0, 0] = 1.7321
|
|
||||||
known_matrix[0, 1, 1] = 1.7321
|
|
||||||
known_matrix[0, 2, 2] = 1.0101
|
|
||||||
known_matrix[0, 3, 2] = -1.0101
|
|
||||||
known_matrix[0, 2, 3] = 1
|
|
||||||
|
|
||||||
self.assertIsNone(dataset_map.val)
|
|
||||||
self.assertIsNone(dataset_map.test)
|
|
||||||
self.assertEqual(len(dataset_map.train), provider.num_views)
|
|
||||||
|
|
||||||
value = dataset_map.train[0]
|
|
||||||
self.assertIsInstance(value, FrameData)
|
|
||||||
|
|
||||||
self.assertEqual(value.image_rgb.shape, (3, 128, 128))
|
|
||||||
self.assertEqual(value.fg_probability.shape, (1, 128, 128))
|
|
||||||
# corner of image is background
|
|
||||||
self.assertEqual(value.fg_probability[0, 0, 0], 0)
|
|
||||||
self.assertEqual(value.fg_probability.max(), 1.0)
|
|
||||||
self.assertIsInstance(value.camera, FoVPerspectiveCameras)
|
|
||||||
self.assertEqual(len(value.camera), 1)
|
|
||||||
self.assertIsNone(value.camera.K)
|
|
||||||
matrix = value.camera.get_projection_transform().get_matrix()
|
|
||||||
self.assertClose(matrix, known_matrix, atol=1e-4)
|
|
||||||
@@ -69,7 +69,6 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
|||||||
provider = LlffDatasetMapProvider(
|
provider = LlffDatasetMapProvider(
|
||||||
base_dir="manifold://co3d/tree/nerf_data/nerf_llff_data/fern",
|
base_dir="manifold://co3d/tree/nerf_data/nerf_llff_data/fern",
|
||||||
object_name="fern",
|
object_name="fern",
|
||||||
downscale_factor=8,
|
|
||||||
)
|
)
|
||||||
dataset_map = provider.get_dataset_map()
|
dataset_map = provider.get_dataset_map()
|
||||||
known_matrix = torch.zeros(1, 4, 4)
|
known_matrix = torch.zeros(1, 4, 4)
|
||||||
|
|||||||
@@ -10,10 +10,6 @@ import unittest.mock
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
|
||||||
SequenceDataLoaderMapProvider,
|
|
||||||
SimpleDataLoaderMapProvider,
|
|
||||||
)
|
|
||||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||||
from pytorch3d.implicitron.tools.config import get_default_args
|
from pytorch3d.implicitron.tools.config import get_default_args
|
||||||
@@ -68,6 +64,7 @@ class TestDataSource(unittest.TestCase):
|
|||||||
return
|
return
|
||||||
args = get_default_args(ImplicitronDataSource)
|
args = get_default_args(ImplicitronDataSource)
|
||||||
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
|
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
|
||||||
|
args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
|
||||||
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||||
dataset_args.category = "skateboard"
|
dataset_args.category = "skateboard"
|
||||||
dataset_args.test_restrict_sequence_id = 0
|
dataset_args.test_restrict_sequence_id = 0
|
||||||
@@ -76,35 +73,8 @@ class TestDataSource(unittest.TestCase):
|
|||||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||||
|
|
||||||
data_source = ImplicitronDataSource(**args)
|
data_source = ImplicitronDataSource(**args)
|
||||||
self.assertIsInstance(
|
|
||||||
data_source.data_loader_map_provider, SequenceDataLoaderMapProvider
|
|
||||||
)
|
|
||||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||||
self.assertEqual(len(data_loaders.train), 81)
|
self.assertEqual(len(data_loaders.train), 81)
|
||||||
for i in data_loaders.train:
|
for i in data_loaders.train:
|
||||||
self.assertEqual(i.frame_type, ["test_known"])
|
self.assertEqual(i.frame_type, ["test_known"])
|
||||||
break
|
break
|
||||||
|
|
||||||
def test_simple(self):
|
|
||||||
if os.environ.get("INSIDE_RE_WORKER") is not None:
|
|
||||||
return
|
|
||||||
args = get_default_args(ImplicitronDataSource)
|
|
||||||
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
|
|
||||||
args.data_loader_map_provider_class_type = "SimpleDataLoaderMapProvider"
|
|
||||||
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
|
||||||
dataset_args.category = "skateboard"
|
|
||||||
dataset_args.test_restrict_sequence_id = 0
|
|
||||||
dataset_args.n_frames_per_sequence = -1
|
|
||||||
|
|
||||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
|
||||||
|
|
||||||
data_source = ImplicitronDataSource(**args)
|
|
||||||
self.assertIsInstance(
|
|
||||||
data_source.data_loader_map_provider, SimpleDataLoaderMapProvider
|
|
||||||
)
|
|
||||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
|
||||||
|
|
||||||
self.assertEqual(len(data_loaders.train), 81)
|
|
||||||
for i in data_loaders.train:
|
|
||||||
self.assertEqual(i.frame_type, ["test_known"])
|
|
||||||
break
|
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user