implicitron v0 (#1133)

Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
This commit is contained in:
Jeremy Reizenstein
2022-03-21 20:20:10 +00:00
committed by GitHub
parent 0e377c6850
commit cdd2142dd5
90 changed files with 17075 additions and 0 deletions

View File

@@ -0,0 +1,276 @@
# Introduction
Implicitron is a PyTorch3D-based framework for new-view synthesis via modelling neural-network-based representations.
# License
Implicitron is distributed as part of PyTorch3D under the [BSD license](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE).
It includes code from [SRN](http://github.com/vsitzmann/scene-representation-networks) and [IDR](http://github.com/lioryariv/idr) repos.
See [LICENSE-3RD-PARTY](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE-3RD-PARTY) for their licenses.
# Installation
There are three ways to set up Implicitron, depending on the flexibility level required.
If you only want to train or evaluate models as they are implemented, changing only the parameters, you can just install the package.
Implicitron also provides a flexible API that supports user-defined plug-ins;
if you want to re-implement some of the components without changing the high-level pipeline, you need to create a custom launcher script.
The most flexible option, though, is to clone the PyTorch3D repo and build it from source, which allows changing the code in arbitrary ways.
Below, we describe all three options in more detail.
## [Option 1] Running an executable from the package
This option allows you to use the code as is without changing the implementations.
Only configuration can be changed (see [Configuration system](#configuration-system)).
For this setup, install the dependencies and PyTorch3D from conda following [the guide](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#1-install-with-cuda-support-from-anaconda-cloud-on-linux-only). Then, install the Implicitron-specific dependencies:
```shell
pip install "hydra-core>=1.1" visdom lpips matplotlib
```
The runner executable is available as the `pytorch3d_implicitron_runner` shell command.
See the [Running](#running) section below for examples of training and evaluation commands.
## [Option 2] Supporting custom implementations
To plug in custom implementations, for example, of renderer or implicit-function protocols, you need to create your own runner script and import the plug-in implementations there.
First, install PyTorch3D and Implicitron dependencies as described in the previous section.
Then, implement the custom script; copying `pytorch3d/projects/implicitron_trainer/experiment.py` is a good starting point.
See [Custom plugins](#custom-plugins) for more information on how to import implementations and enable them in the configs.
## [Option 3] Cloning PyTorch3D repo
This is the most flexible way to set up Implicitron as it allows changing the code directly.
It allows modifying the high-level rendering pipeline or implementing yet-unsupported loss functions.
Please follow the instructions to [install PyTorch3D from a local clone](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-from-a-local-clone).
Then, install Implicitron-specific dependencies:
```shell
pip install "hydra-core>=1.1" visdom lpips matplotlib
```
You are still encouraged to implement custom plugins as above where possible as it makes reusing the code easier.
The executable is located in `pytorch3d/projects/implicitron_trainer`.
# Running
This section assumes that you use the executable provided by the installed package.
If you have a custom `experiment.py` script (as in Option 2 above), replace the executable with the path to your script.
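For example, assuming your custom script keeps the same hydra entry point as `experiment.py`, the training command from the next section would become:
```shell
python ./experiment.py --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<DATASET_ROOT> dataset_args.category='skateboard' exp_dir=<CHECKPOINT_DIR>
```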
## Training
To run training, pass a yaml config file, followed by a list of argument overrides.
For example, to train NeRF on the first skateboard sequence from the CO3D dataset, you can run:
```shell
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
```
Here, `--config-path` points to the config folder relative to the `pytorch3d_implicitron_runner` location;
`--config-name` picks the config (in this case, `repro_singleseq_nerf.yaml`);
`test_when_finished` will launch the evaluation script once training is finished.
Replace `<DATASET_ROOT>` with the location where the dataset in Implicitron format is stored
and `<CHECKPOINT_DIR>` with a directory where checkpoints will be dumped during training.
Other configuration parameters can be overridden in the same way.
See [Configuration system](#configuration-system) section for more information on this.
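For instance, appending the following overrides (illustrative values) to the training command above changes the rendered image width and the initial learning rate:
```shell
generic_model_args.render_image_width=512 solver_args.lr=0.0001
```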
## Evaluation
To run evaluation on the latest checkpoint after (or during) training, simply add `eval_only=True` to your training command.
For example, to execute the evaluation on the NeRF skateboard sequence, you can run:
```shell
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<CO3D_DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
```
Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`.
## Visualisation
The visualisation script produces a video of renders by a trained model along a pre-defined camera trajectory.
In order for it to work, `ffmpeg` needs to be installed:
```shell
conda install ffmpeg
```
Here is an example of calling the script:
```shell
projects/implicitron_trainer/visualize_reconstruction.py exp_dir=<CHECKPOINT_DIR> visdom_show_preds=True n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
```
The argument `n_eval_cameras` sets the number of rendering viewpoints sampled on the trajectory, which defaults to a circular fly-around;
`render_size` sets the size of the render passed to the model, which can be resized to `video_size` before writing.
Rendered videos of images, masks, and depth maps will be saved to `<CHECKPOINT_DIR>/vis`.
# Configuration system
We use Hydra and OmegaConf to parse the configs.
The config schema and default values are defined by the dataclasses implementing the modules.
More specifically, if a class derives from `Configurable`, its fields can be set in config yaml files or overridden on the CLI.
For example, `GenericModel` has a field `render_image_width` with the default value 400.
If a different value is specified in the yaml config file or on the CLI, it will be used instead.
Configurables can form hierarchies.
For example, `GenericModel` has a field `raysampler: RaySampler`, which is also a Configurable.
In the config, inner parameters are propagated using the `_args` postfix, e.g. to change `raysampler.n_pts_per_ray_training` (the number of sampled points per ray), the node `raysampler_args.n_pts_per_ray_training` should be specified.
The root of the hierarchy is defined by the `ExperimentConfig` dataclass.
It has top-level fields like `eval_only`, which was used above to run evaluation via a CLI override.
Additionally, it has non-leaf nodes like `generic_model_args`, which dispatches the config parameters to `GenericModel`. Thus, changing the model parameters may be achieved in two ways: either by editing the config file, e.g.
```yaml
generic_model_args:
render_image_width: 800
raysampler_args:
n_pts_per_ray_training: 128
```
or, equivalently, by adding the following to `pytorch3d_implicitron_runner` arguments:
```shell
generic_model_args.render_image_width=800 generic_model_args.raysampler_args.n_pts_per_ray_training=128
```
See the documentation in `pytorch3d/implicitron/tools/config.py` for more details.
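To make the mapping concrete, here is a minimal, self-contained sketch; the class and field names are invented for illustration, and it assumes `Configurable`, `get_default_args` and `run_auto_creation` from `pytorch3d/implicitron/tools/config.py`:
```python
from pytorch3d.implicitron.tools.config import (
    Configurable,
    get_default_args,
    run_auto_creation,
)


class MyRaySampler(Configurable):  # hypothetical inner module
    n_pts_per_ray_training: int = 64


class MyModel(Configurable):  # hypothetical outer module
    render_image_width: int = 400
    raysampler: MyRaySampler

    def __post_init__(self):
        # creates `self.raysampler` from the `raysampler_args` config node
        run_auto_creation(self)


cfg = get_default_args(MyModel)
cfg.render_image_width = 800                      # same effect as a yaml / CLI override
cfg.raysampler_args.n_pts_per_ray_training = 128  # inner node uses the `_args` postfix
model = MyModel(**cfg)
assert model.raysampler.n_pts_per_ray_training == 128
```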
## Replaceable implementations
Sometimes changing the model parameters does not provide enough flexibility, and you want to provide a new implementation for a building block.
The configuration system supports this too!
Abstract classes like `BaseRenderer` derive from `ReplaceableBase` instead of `Configurable`.
This means that other Configurables can refer to them using the base type, while the specific implementation is chosen in the config using a `_class_type`-postfixed node.
In that case, the name of the `_args` node has to include the implementation type.
More specifically, to change renderer settings, the config will look like this:
```yaml
generic_model_args:
renderer_class_type: LSTMRenderer
renderer_LSTMRenderer_args:
num_raymarch_steps: 10
hidden_size: 16
```
See the documentation in `pytorch3d/implicitron/tools/config.py` for more details on the configuration system.
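As a minimal illustration (all names below are invented; the sketch assumes `ReplaceableBase`, `registry` and `run_auto_creation` from `pytorch3d/implicitron/tools/config.py`), a replaceable component with two registered implementations could look like this:
```python
from pytorch3d.implicitron.tools.config import (
    Configurable,
    ReplaceableBase,
    get_default_args,
    registry,
    run_auto_creation,
)


class MyShader(ReplaceableBase):  # hypothetical pluggable base class
    pass


@registry.register
class FlatShader(MyShader):
    albedo: float = 0.5


@registry.register
class FancyShader(MyShader):
    n_layers: int = 3


class MyRenderer(Configurable):  # hypothetical module owning the pluggable member
    shader: MyShader
    shader_class_type: str = "FlatShader"

    def __post_init__(self):
        # instantiates `self.shader` of the class picked by `shader_class_type`
        run_auto_creation(self)


cfg = get_default_args(MyRenderer)
cfg.shader_class_type = "FancyShader"      # pick the implementation
cfg.shader_FancyShader_args.n_layers = 5   # and set its parameters
renderer = MyRenderer(**cfg)
```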
## Custom plugins
If you have an idea for another implementation of a replaceable component, it can be plugged in without changing the core code.
For that, you need to set up Implicitron through option 2 or 3 above.
Let's say you want to implement a renderer that accumulates opacities similarly to an X-ray machine.
First, create a module `x_ray_renderer.py` with a class deriving from `BaseRenderer`:
```python
import torch

# NB: adjust these import paths if your Implicitron version lays the modules out differently.
from pytorch3d.implicitron.models.base import EvaluationMode
from pytorch3d.implicitron.models.renderer.base import BaseRenderer, RendererOutput
from pytorch3d.implicitron.tools.config import registry


@registry.register
class XRayRenderer(BaseRenderer, torch.nn.Module):
    n_pts_per_ray: int = 64

    # if there are other base classes, make sure to call `super().__init__()` explicitly
    def __post_init__(self):
        super().__init__()
        # custom initialization goes here

    def forward(
        self,
        ray_bundle,
        implicit_functions=[],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        **kwargs,
    ) -> RendererOutput:
        ...
```
Please note the `@registry.register` decorator, which registers the plug-in as an implementation of `BaseRenderer`.
IMPORTANT: In order for it to run, the class (or its enclosing module) has to be imported in your launch script. Additionally, this has to be done before parsing the root configuration class `ExperimentConfig`.
Simply add `import x_ray_renderer` (adjusting to a relative import if needed) at the beginning of `experiment.py`.
After that, you should be able to change the config with:
```yaml
generic_model_args:
renderer_class_type: XRayRenderer
renderer_XRayRenderer_args:
n_pts_per_ray: 128
```
to replace the implementation and potentially override the parameters.
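Equivalently, assuming the plug-in has been imported before the config is parsed as described above, the same selection can be made on the command line:
```shell
generic_model_args.renderer_class_type=XRayRenderer generic_model_args.renderer_XRayRenderer_args.n_pts_per_ray=128
```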
# Code and config structure
As described above, the config structure is parsed automatically from the module hierarchy.
In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `dataset_args` node.
Here is the class structure (single-line edges show aggregation, while double lines show available implementations):
```
generic_model_args: GenericModel
└-- sequence_autodecoder_args: Autodecoder
└-- raysampler_args: RaySampler
└-- renderer_*_args: BaseRenderer
╘== MultiPassEmissionAbsorptionRenderer
╘== LSTMRenderer
╘== SignedDistanceFunctionRenderer
└-- ray_tracer_args: RayTracing
└-- ray_normal_coloring_network_args: RayNormalColoringNetwork
└-- implicit_function_*_args: ImplicitFunctionBase
╘== NeuralRadianceFieldImplicitFunction
╘== SRNImplicitFunction
└-- raymarch_function_args: SRNRaymarchFunction
└-- pixel_generator_args: SRNPixelGenerator
╘== SRNHyperNetImplicitFunction
└-- hypernet_args: SRNRaymarchHyperNet
└-- pixel_generator_args: SRNPixelGenerator
╘== IdrFeatureField
└-- image_feature_extractor_args: ResNetFeatureExtractor
└-- view_sampler_args: ViewSampler
└-- feature_aggregator_*_args: FeatureAggregatorBase
╘== IdentityFeatureAggregator
╘== AngleWeightedIdentityFeatureAggregator
╘== AngleWeightedReductionFeatureAggregator
╘== ReductionFeatureAggregator
solver_args: init_optimizer
dataset_args: dataset_zoo
dataloader_args: dataloader_zoo
```
Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
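If you prefer to inspect all hyperparameters at once, the full default config can be dumped as yaml; a short sketch, assuming it is run from `projects/implicitron_trainer` so that `experiment.py` is importable:
```python
from omegaconf import OmegaConf

# ExperimentConfig is defined in projects/implicitron_trainer/experiment.py
from experiment import ExperimentConfig

# print the whole default config tree, including all *_args nodes
print(OmegaConf.to_yaml(OmegaConf.structured(ExperimentConfig)))
```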
# Reproducing CO3D experiments
Common Objects in 3D (CO3D) is a large-scale dataset of videos of rigid objects grouped into 50 common categories.
Implicitron provides implementations and config files to reproduce the results from [the paper](https://arxiv.org/abs/2109.00512).
Please follow [the link](https://github.com/facebookresearch/co3d#automatic-batch-download) for the instructions to download the dataset.
In training and evaluation scripts, use the download location as `<DATASET_ROOT>`.
It is also possible to define the environment variable `CO3D_DATASET_ROOT` instead of specifying the dataset root explicitly.
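For example, with the environment variable set, a multi-sequence NerFormer run could be launched as follows (a sketch using the packaged executable from Option 1):
```shell
export CO3D_DATASET_ROOT=<DATASET_ROOT>
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_multiseq_nerformer exp_dir=<CHECKPOINT_DIR>
```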
To reproduce the experiments from the paper, use the following configs. For single-sequence experiments:
| Method | config file |
|-----------------|-------------------------------------|
| NeRF | repro_singleseq_nerf.yaml |
| NeRF + WCE | repro_singleseq_nerf_wce.yaml |
| NerFormer | repro_singleseq_nerformer.yaml |
| IDR | repro_singleseq_idr.yaml |
| SRN | repro_singleseq_srn_noharm.yaml |
| SRN + γ | repro_singleseq_srn.yaml |
| SRN + WCE | repro_singleseq_srn_wce_noharm.yaml |
| SRN + WCE + γ   | repro_singleseq_srn_wce.yaml        |
For multi-sequence experiments (without generalisation to new sequences):
| Method | config file |
|-----------------|--------------------------------------------|
| NeRF + AD | repro_multiseq_nerf_ad.yaml |
| SRN + AD | repro_multiseq_srn_ad_hypernet_noharm.yaml |
| SRN + γ + AD | repro_multiseq_srn_ad_hypernet.yaml |
For multi-sequence experiments (with generalisation to new sequences):
| Method | config file |
|-----------------|--------------------------------------|
| NeRF + WCE | repro_multiseq_nerf_wce.yaml |
| NerFormer | repro_multiseq_nerformer.yaml |
| SRN + WCE | repro_multiseq_srn_wce_noharm.yaml |
| SRN + WCE + γ | repro_multiseq_srn_wce.yaml |

View File

@@ -0,0 +1,83 @@
defaults:
- default_config
- _self_
exp_dir: ./data/exps/base/
architecture: generic
visualize_interval: 0
visdom_port: 8097
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
dataset_args:
  dataset_root: ${oc.env:CO3D_DATASET_ROOT}
load_point_clouds: false
mask_depths: false
mask_images: false
n_frames_per_sequence: -1
test_on_train: true
test_restrict_sequence_id: 0
generic_model_args:
loss_weights:
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 1.0
loss_autodecoder_norm: 0.01
loss_rgb_mse: 1.0
loss_prev_stage_rgb_mse: 1.0
output_rasterized_mc: false
chunk_size_grid: 102400
render_image_height: 400
render_image_width: 400
num_passes: 2
implicit_function_NeuralRadianceFieldImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_xyz: 256
n_hidden_neurons_dir: 128
n_layers_xyz: 8
append_xyz:
- 5
latent_dim: 0
raysampler_args:
n_rays_per_image_sampled_from_mask: 1024
min_depth: 0.0
max_depth: 0.0
scene_extent: 8.0
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 64
n_pts_per_ray_fine_evaluation: 64
append_coarse_samples_to_fine: true
density_noise_std_train: 1.0
view_sampler_args:
masked_sampling: false
image_feature_extractor_args:
stages:
- 1
- 2
- 3
- 4
proj_dim: 16
image_rescale: 0.32
first_max_pool: false
solver_args:
breed: adam
lr: 0.0005
lr_policy: multistep
max_epochs: 2000
momentum: 0.9
weight_decay: 0.0

View File

@@ -0,0 +1,16 @@
generic_model_args:
image_feature_extractor_args:
add_images: true
add_masks: true
first_max_pool: true
image_rescale: 0.375
l2_norm: true
name: resnet34
normalize_image: true
pretrained: true
stages:
- 1
- 2
- 3
- 4
proj_dim: 32

View File

@@ -0,0 +1,16 @@
generic_model_args:
image_feature_extractor_args:
add_images: true
add_masks: true
first_max_pool: false
image_rescale: 0.375
l2_norm: true
name: resnet34
normalize_image: true
pretrained: true
stages:
- 1
- 2
- 3
- 4
proj_dim: 16

View File

@@ -0,0 +1,16 @@
generic_model_args:
image_feature_extractor_args:
stages:
- 1
- 2
- 3
first_max_pool: false
proj_dim: -1
l2_norm: false
image_rescale: 0.375
name: resnet34
normalize_image: true
pretrained: true
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
reduction_functions:
- AVG

View File

@@ -0,0 +1,31 @@
defaults:
- repro_base.yaml
- _self_
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
dataset_args:
assert_single_seq: false
dataset_name: co3d_multisequence
load_point_clouds: false
mask_depths: false
mask_images: false
n_frames_per_sequence: -1
test_on_train: true
test_restrict_sequence_id: 0
solver_args:
max_epochs: 3000
milestones:
- 1000

View File

@@ -0,0 +1,64 @@
defaults:
- repro_multiseq_base.yaml
- _self_
generic_model_args:
loss_weights:
loss_mask_bce: 100.0
loss_kl: 0.0
loss_rgb_mse: 1.0
loss_eikonal: 0.1
chunk_size_grid: 65536
num_passes: 1
output_rasterized_mc: true
sampling_mode_training: mask_sample
view_pool: false
sequence_autodecoder_args:
n_instances: 20000
init_scale: 1.0
encoding_dim: 256
implicit_function_IdrFeatureField_args:
n_harmonic_functions_xyz: 6
bias: 0.6
d_in: 3
d_out: 1
dims:
- 512
- 512
- 512
- 512
- 512
- 512
- 512
- 512
geometric_init: true
pooled_feature_dim: 0
skip_in:
- 6
weight_norm: true
renderer_SignedDistanceFunctionRenderer_args:
ray_tracer_args:
line_search_step: 0.5
line_step_iters: 3
n_secant_steps: 8
n_steps: 100
object_bounding_sphere: 8.0
sdf_threshold: 5.0e-05
ray_normal_coloring_network_args:
d_in: 9
d_out: 3
dims:
- 512
- 512
- 512
- 512
mode: idr
n_harmonic_functions_dir: 4
pooled_feature_dim: 0
weight_norm: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 1024
n_pts_per_ray_training: 0
n_pts_per_ray_evaluation: 0
scene_extent: 8.0
renderer_class_type: SignedDistanceFunctionRenderer
implicit_function_class_type: IdrFeatureField

View File

@@ -0,0 +1,9 @@
defaults:
- repro_multiseq_base.yaml
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: false
sequence_autodecoder_args:
n_instances: 20000
encoding_dim: 256

View File

@@ -0,0 +1,10 @@
defaults:
- repro_multiseq_base.yaml
- repro_feat_extractor_unnormed.yaml
- _self_
clip_grad: 1.0
generic_model_args:
chunk_size_grid: 16000
view_pool: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 850

View File

@@ -0,0 +1,16 @@
defaults:
- repro_multiseq_base.yaml
- repro_feat_extractor_transformer.yaml
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32
n_pts_per_ray_evaluation: 32
renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16
implicit_function_class_type: NeRFormerImplicitFunction
feature_aggregator_class_type: IdentityFeatureAggregator

View File

@@ -0,0 +1,16 @@
defaults:
- repro_multiseq_base.yaml
- repro_feat_extractor_transformer.yaml
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32
n_pts_per_ray_evaluation: 32
renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16
implicit_function_class_type: NeRFormerImplicitFunction
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator

View File

@@ -0,0 +1,32 @@
defaults:
- repro_multiseq_base.yaml
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: false
n_train_target_views: -1
num_passes: 1
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.001
depth_neg_penalty: 10000.0
sequence_autodecoder_args:
encoding_dim: 256
n_instances: 20000
raysampler_args:
n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05
max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNHyperNetImplicitFunction
solver_args:
breed: adam
lr: 5.0e-05

View File

@@ -0,0 +1,10 @@
defaults:
- repro_multiseq_srn_ad_hypernet.yaml
- _self_
generic_model_args:
num_passes: 1
implicit_function_SRNHyperNetImplicitFunction_args:
pixel_generator_args:
n_harmonic_functions: 0
hypernet_args:
n_harmonic_functions: 0

View File

@@ -0,0 +1,30 @@
defaults:
- repro_multiseq_base.yaml
- repro_feat_extractor_normed.yaml
- _self_
generic_model_args:
chunk_size_grid: 32000
view_pool: true
num_passes: 1
n_train_target_views: -1
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.0
depth_neg_penalty: 10000.0
raysampler_args:
n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05
max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNImplicitFunction
solver_args:
breed: adam
lr: 5.0e-05

View File

@@ -0,0 +1,10 @@
defaults:
- repro_multiseq_srn_wce.yaml
- _self_
generic_model_args:
num_passes: 1
implicit_function_SRNImplicitFunction_args:
pixel_generator_args:
n_harmonic_functions: 0
raymarch_function_args:
n_harmonic_functions: 0

View File

@@ -0,0 +1,41 @@
defaults:
- repro_base
- _self_
dataloader_args:
batch_size: 1
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
dataset_args:
dataset_name: co3d_singlesequence
assert_single_seq: true
n_frames_per_sequence: -1
test_restrict_sequence_id: 0
test_on_train: false
generic_model_args:
render_image_height: 800
render_image_width: 800
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_eikonal
- loss_prev_stage_rgb_psnr
- loss_mask_bce
- loss_prev_stage_mask_bce
- loss_rgb_mse
- loss_prev_stage_rgb_mse
- loss_depth_abs
- loss_depth_abs_fg
- loss_kl
- loss_mask_neg_iou
- objective
- epoch
- sec/it
solver_args:
lr: 0.0005
max_epochs: 400
milestones:
- 200
- 300

View File

@@ -0,0 +1,57 @@
defaults:
- repro_singleseq_base
- _self_
generic_model_args:
loss_weights:
loss_mask_bce: 100.0
loss_kl: 0.0
loss_rgb_mse: 1.0
loss_eikonal: 0.1
chunk_size_grid: 65536
num_passes: 1
view_pool: false
implicit_function_IdrFeatureField_args:
n_harmonic_functions_xyz: 6
bias: 0.6
d_in: 3
d_out: 1
dims:
- 512
- 512
- 512
- 512
- 512
- 512
- 512
- 512
geometric_init: true
pooled_feature_dim: 0
skip_in:
- 6
weight_norm: true
renderer_SignedDistanceFunctionRenderer_args:
ray_tracer_args:
line_search_step: 0.5
line_step_iters: 3
n_secant_steps: 8
n_steps: 100
object_bounding_sphere: 8.0
sdf_threshold: 5.0e-05
ray_normal_coloring_network_args:
d_in: 9
d_out: 3
dims:
- 512
- 512
- 512
- 512
mode: idr
n_harmonic_functions_dir: 4
pooled_feature_dim: 0
weight_norm: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 1024
n_pts_per_ray_training: 0
n_pts_per_ray_evaluation: 0
renderer_class_type: SignedDistanceFunctionRenderer
implicit_function_class_type: IdrFeatureField

View File

@@ -0,0 +1,4 @@
defaults:
- repro_singleseq_base
- _self_
exp_dir: ./data/nerf_single_apple/

View File

@@ -0,0 +1,9 @@
defaults:
- repro_singleseq_wce_base.yaml
- repro_feat_extractor_unnormed.yaml
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 850

View File

@@ -0,0 +1,16 @@
defaults:
- repro_singleseq_wce_base.yaml
- repro_feat_extractor_transformer.yaml
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
implicit_function_class_type: NeRFormerImplicitFunction
raysampler_args:
n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32
n_pts_per_ray_evaluation: 32
renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16
feature_aggregator_class_type: IdentityFeatureAggregator

View File

@@ -0,0 +1,28 @@
defaults:
- repro_singleseq_base.yaml
- _self_
generic_model_args:
num_passes: 1
chunk_size_grid: 32000
view_pool: false
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.0
depth_neg_penalty: 10000.0
raysampler_args:
n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05
max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNImplicitFunction
solver_args:
breed: adam
lr: 5.0e-05

View File

@@ -0,0 +1,10 @@
defaults:
- repro_singleseq_srn.yaml
- _self_
generic_model_args:
num_passes: 1
implicit_function_SRNImplicitFunction_args:
pixel_generator_args:
n_harmonic_functions: 0
raymarch_function_args:
n_harmonic_functions: 0

View File

@@ -0,0 +1,29 @@
defaults:
- repro_singleseq_wce_base
- repro_feat_extractor_normed.yaml
- _self_
generic_model_args:
num_passes: 1
chunk_size_grid: 32000
view_pool: true
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.0
depth_neg_penalty: 10000.0
raysampler_args:
n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05
max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNImplicitFunction
solver_args:
breed: adam
lr: 5.0e-05

View File

@@ -0,0 +1,10 @@
defaults:
- repro_singleseq_srn_wce.yaml
- _self_
generic_model_args:
num_passes: 1
implicit_function_SRNImplicitFunction_args:
pixel_generator_args:
n_harmonic_functions: 0
raymarch_function_args:
n_harmonic_functions: 0

View File

@@ -0,0 +1,18 @@
defaults:
- repro_singleseq_base
- _self_
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10

View File

@@ -0,0 +1,714 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
""""
This file is the entry point for launching experiments with Implicitron.
Main functions
---------------
- `run_training` is the wrapper for the train, val, test loops
and checkpointing
- `trainvalidate` is the inner loop which runs the model forward/backward
pass, visualizations and metric printing
Launch Training
---------------
Experiment config .yaml files are located in the
`projects/implicitron_trainer/configs` folder. To launch
an experiment, specify the name of the file. Specific config values can
also be overridden from the command line, for example:
```
./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84
```
To run an experiment on a specific GPU, specify the `gpu_idx` key
in the config file / CLI. To run on a different device, specify the
device in `run_training`.
Outputs
--------
The outputs of the experiment are saved and logged in multiple ways:
- Checkpoints:
Model, optimizer and stats are stored in the directory
named by the `exp_dir` key from the config file / CLI parameters.
- Stats
Stats are logged and plotted to the file "train_stats.pdf" in the
same directory. The stats are also saved as part of the checkpoint file.
- Visualizations
    Predictions are plotted to a visdom server running at the
port specified by the `visdom_server` and `visdom_port` keys in the
config file.
"""
import copy
import json
import logging
import os
import random
import time
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
import hydra
import lpips
import numpy as np
import torch
import tqdm
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
ImplicitronDataset,
FrameData,
)
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
get_default_args_field,
remove_unused_components,
)
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase
logger = logging.getLogger(__name__)
if version.parse(hydra.__version__) < version.Version("1.1"):
raise ValueError(
f"Hydra version {hydra.__version__} is too old."
" (Implicitron requires version 1.1 or later.)"
)
try:
# only makes sense in FAIR cluster
import pytorch3d.implicitron.fair_cluster.slurm # noqa: F401
except ModuleNotFoundError:
pass
def init_model(
cfg: DictConfig,
force_load: bool = False,
clear_stats: bool = False,
load_model_only: bool = False,
) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
"""
Returns an instance of `GenericModel`.
If `cfg.resume` is set or `force_load` is true,
attempts to load the last checkpoint from `cfg.exp_dir`. Failure to do so
will return the model with initial weights, unless `force_load` is passed,
in which case a FileNotFoundError is raised.
Args:
force_load: If true, force load model from checkpoint even if
cfg.resume is false.
clear_stats: If true, clear the stats object loaded from checkpoint
load_model_only: If true, load only the model weights from checkpoint
and do not load the state of the optimizer and stats.
Returns:
model: The model with optionally loaded weights from checkpoint
stats: The stats structure (optionally loaded from checkpoint)
optimizer_state: The optimizer state dict containing
`state` and `param_groups` keys (optionally loaded from checkpoint)
Raise:
FileNotFoundError if `force_load` is passed but checkpoint is not found.
"""
# Initialize the model
if cfg.architecture == "generic":
model = GenericModel(**cfg.generic_model_args)
else:
raise ValueError(f"No such arch {cfg.architecture}.")
# Determine the network outputs that should be logged
if hasattr(model, "log_vars"):
log_vars = copy.deepcopy(list(model.log_vars))
else:
log_vars = ["objective"]
visdom_env_charts = vis_utils.get_visdom_env(cfg) + "_charts"
# Init the stats struct
stats = Stats(
log_vars,
visdom_env=visdom_env_charts,
verbose=False,
visdom_server=cfg.visdom_server,
visdom_port=cfg.visdom_port,
)
# Retrieve the last checkpoint
if cfg.resume_epoch > 0:
model_path = model_io.get_checkpoint(cfg.exp_dir, cfg.resume_epoch)
else:
model_path = model_io.find_last_checkpoint(cfg.exp_dir)
optimizer_state = None
if model_path is not None:
logger.info("found previous model %s" % model_path)
if force_load or cfg.resume:
logger.info(" -> resuming")
if load_model_only:
model_state_dict = torch.load(model_io.get_model_path(model_path))
stats_load, optimizer_state = None, None
else:
model_state_dict, stats_load, optimizer_state = model_io.load_model(
model_path
)
# Determine if stats should be reset
if not clear_stats:
if stats_load is None:
logger.info("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
last_epoch = model_io.parse_epoch_from_model_path(model_path)
logger.info(f"Estimated resume epoch = {last_epoch}")
# Reset the stats struct
for _ in range(last_epoch + 1):
stats.new_epoch()
assert last_epoch == stats.epoch
else:
stats = stats_load
                    # Update stats properties in case it was reset on load
stats.visdom_env = visdom_env_charts
stats.visdom_server = cfg.visdom_server
stats.visdom_port = cfg.visdom_port
stats.plot_file = os.path.join(cfg.exp_dir, "train_stats.pdf")
stats.synchronize_logged_vars(log_vars)
else:
logger.info(" -> clearing stats")
try:
# TODO: fix on creation of the buffers
# after the hack above, this will not pass in most cases
# ... but this is fine for now
model.load_state_dict(model_state_dict, strict=True)
except RuntimeError as e:
logger.error(e)
logger.info("Cant load state dict in strict mode! -> trying non-strict")
model.load_state_dict(model_state_dict, strict=False)
model.log_vars = log_vars
else:
logger.info(" -> but not resuming -> starting from scratch")
elif force_load:
raise FileNotFoundError(f"Cannot find a checkpoint in {cfg.exp_dir}!")
return model, stats, optimizer_state
def init_optimizer(
model: GenericModel,
optimizer_state: Optional[Dict[str, Any]],
last_epoch: int,
    breed: str = "adam",
weight_decay: float = 0.0,
lr_policy: str = "multistep",
lr: float = 0.0005,
gamma: float = 0.1,
momentum: float = 0.9,
    betas: Tuple[float, float] = (0.9, 0.999),
milestones: tuple = (),
max_epochs: int = 1000,
):
"""
Initialize the optimizer (optionally from checkpoint state)
and the learning rate scheduler.
Args:
model: The model with optionally loaded weights
optimizer_state: The state dict for the optimizer. If None
it has not been loaded from checkpoint
last_epoch: If the model was loaded from checkpoint this will be the
number of the last epoch that was saved
breed: The type of optimizer to use e.g. adam
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
        lr_policy: The policy to use for learning rate. Currently, only "multistep"
is supported.
lr: The value for the initial learning rate
gamma: Multiplicative factor of learning rate decay
momentum: Momentum factor for SGD optimizer
betas: Coefficients used for computing running averages of gradient and its square
in the Adam optimizer
milestones: List of increasing epoch indices at which the learning rate is
modified
max_epochs: The maximum number of epochs to run the optimizer for
Returns:
optimizer: Optimizer module, optionally loaded from checkpoint
scheduler: Learning rate scheduler module
Raise:
ValueError if `breed` or `lr_policy` are not supported.
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
p_groups = model._get_param_groups(lr, wd=weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
p_groups = [{"params": allprm, "lr": lr}]
    # Initialize the optimizer
if breed == "sgd":
optimizer = torch.optim.SGD(
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
)
elif breed == "adagrad":
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
elif breed == "adam":
optimizer = torch.optim.Adam(
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
)
else:
raise ValueError("no such solver type %s" % breed)
logger.info(" -> solver type = %s" % breed)
# Load state from checkpoint
if optimizer_state is not None:
logger.info(" -> setting loaded optimizer state")
optimizer.load_state_dict(optimizer_state)
# Initialize the learning rate scheduler
if lr_policy == "multistep":
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=milestones,
gamma=gamma,
)
else:
raise ValueError("no such lr policy %s" % lr_policy)
# When loading from checkpoint, this will make sure that the
# lr is correctly set even after returning
for _ in range(last_epoch):
scheduler.step()
# Add the max epochs here
scheduler.max_epochs = max_epochs
optimizer.zero_grad()
return optimizer, scheduler
def trainvalidate(
model,
stats,
epoch,
loader,
optimizer,
validation,
bp_var: str = "objective",
metric_print_interval: int = 5,
visualize_interval: int = 100,
visdom_env_root: str = "trainvalidate",
clip_grad: float = 0.0,
device: str = "cuda:0",
**kwargs,
) -> None:
"""
This is the main loop for training and evaluation including:
model forward pass, loss computation, backward pass and visualization.
Args:
model: The model module optionally loaded from checkpoint
stats: The stats struct, also optionally loaded from checkpoint
epoch: The index of the current epoch
loader: The dataloader to use for the loop
optimizer: The optimizer module optionally loaded from checkpoint
validation: If true, run the loop with the model in eval mode
and skip the backward pass
bp_var: The name of the key in the model output `preds` dict which
should be used as the loss for the backward pass.
metric_print_interval: The batch interval at which the stats should be
logged.
visualize_interval: The batch interval at which the visualizations
should be plotted
visdom_env_root: The name of the visdom environment to use for plotting
clip_grad: Optionally clip the gradient norms.
If set to a value <=0.0, no clipping
device: The device on which to run the model.
Returns:
None
"""
if validation:
model.eval()
trainmode = "val"
else:
model.train()
trainmode = "train"
t_start = time.time()
# get the visdom env name
visdom_env_imgs = visdom_env_root + "_images_" + trainmode
viz = vis_utils.get_visdom_connection(
server=stats.visdom_server,
port=stats.visdom_port,
)
# Iterate through the batches
n_batches = len(loader)
for it, batch in enumerate(loader):
last_iter = it == n_batches - 1
# move to gpu where possible (in place)
net_input = batch.to(device)
# run the forward pass
if not validation:
optimizer.zero_grad()
preds = model(**{**net_input, "evaluation_mode": EvaluationMode.TRAINING})
else:
with torch.no_grad():
preds = model(
**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
)
# make sure we dont overwrite something
assert all(k not in preds for k in net_input.keys())
# merge everything into one big dict
preds.update(net_input)
# update the stats logger
stats.update(preds, time_start=t_start, stat_set=trainmode)
assert stats.it[trainmode] == it, "inconsistent stat iteration number!"
# print textual status update
if it % metric_print_interval == 0 or last_iter:
stats.print(stat_set=trainmode, max_it=n_batches)
# visualize results
if visualize_interval > 0 and it % visualize_interval == 0:
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
model.visualize(
viz,
visdom_env_imgs,
preds,
prefix,
)
# optimizer step
if not validation:
loss = preds[bp_var]
assert torch.isfinite(loss).all(), "Non-finite loss!"
# backprop
loss.backward()
if clip_grad > 0.0:
# Optionally clip the gradient norms.
                total_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), clip_grad
)
if total_norm > clip_grad:
logger.info(
f"Clipping gradient: {total_norm}"
+ f" with coef {clip_grad / total_norm}."
)
optimizer.step()
def run_training(cfg: DictConfig, device: str = "cpu"):
"""
Entry point to run the training and validation loops
based on the specified config file.
"""
# set the debug mode
if cfg.detect_anomaly:
logger.info("Anomaly detection!")
torch.autograd.set_detect_anomaly(cfg.detect_anomaly)
# create the output folder
os.makedirs(cfg.exp_dir, exist_ok=True)
_seed_all_random_engines(cfg.seed)
remove_unused_components(cfg)
# dump the exp config to the exp dir
try:
cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml")
OmegaConf.save(config=cfg, f=cfg_filename)
except PermissionError:
warnings.warn("Cant dump config due to insufficient permissions!")
# setup datasets
datasets = dataset_zoo(**cfg.dataset_args)
cfg.dataloader_args["dataset_name"] = cfg.dataset_args["dataset_name"]
dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)
# init the model
model, stats, optimizer_state = init_model(cfg)
start_epoch = stats.epoch + 1
# move model to gpu
model.to(device)
# only run evaluation on the test dataloader
if cfg.eval_only:
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
return
# init the optimizer
optimizer, scheduler = init_optimizer(
model,
optimizer_state=optimizer_state,
last_epoch=start_epoch,
**cfg.solver_args,
)
# check the scheduler and stats have been initialized correctly
assert scheduler.last_epoch == stats.epoch + 1
assert scheduler.last_epoch == start_epoch
past_scheduler_lrs = []
# loop through epochs
for epoch in range(start_epoch, cfg.solver_args.max_epochs):
# automatic new_epoch and plotting of stats at every epoch start
with stats:
# Make sure to re-seed random generators to ensure reproducibility
# even after restart.
_seed_all_random_engines(cfg.seed + epoch)
cur_lr = float(scheduler.get_last_lr()[-1])
logger.info(f"scheduler lr = {cur_lr:1.2e}")
past_scheduler_lrs.append(cur_lr)
# train loop
trainvalidate(
model,
stats,
epoch,
dataloaders["train"],
optimizer,
False,
visdom_env_root=vis_utils.get_visdom_env(cfg),
device=device,
**cfg,
)
# val loop (optional)
if "val" in dataloaders and epoch % cfg.validation_interval == 0:
trainvalidate(
model,
stats,
epoch,
dataloaders["val"],
optimizer,
True,
visdom_env_root=vis_utils.get_visdom_env(cfg),
device=device,
**cfg,
)
# eval loop (optional)
if (
"test" in dataloaders
and cfg.test_interval > 0
and epoch % cfg.test_interval == 0
):
run_eval(cfg, model, stats, dataloaders["test"], device=device)
assert stats.epoch == epoch, "inconsistent stats!"
# delete previous models if required
# save model
if cfg.store_checkpoints:
if cfg.store_checkpoints_purge > 0:
for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
model_io.purge_epoch(cfg.exp_dir, prev_epoch)
outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
model_io.safe_save_model(model, stats, outfile, optimizer=optimizer)
scheduler.step()
new_lr = float(scheduler.get_last_lr()[-1])
if new_lr != cur_lr:
logger.info(f"LR change! {cur_lr} -> {new_lr}")
if cfg.test_when_finished:
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
def _eval_and_dump(cfg, datasets, dataloaders, model, stats, device):
"""
Run the evaluation loop with the test data loader and
save the predictions to the `exp_dir`.
"""
if "test" not in dataloaders:
raise ValueError('Dataloaders have to contain the "test" entry for eval!')
eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]
all_source_cameras = (
_get_all_source_cameras(datasets["train"])
if eval_task == "singlesequence"
else None
)
results = run_eval(
cfg, model, all_source_cameras, dataloaders["test"], eval_task, device=device
)
# add the evaluation epoch to the results
for r in results:
r["eval_epoch"] = int(stats.epoch)
logger.info("Evaluation results")
evaluate.pretty_print_nvs_metrics(results)
with open(os.path.join(cfg.exp_dir, "results_test.json"), "w") as f:
json.dump(results, f)
def _get_eval_frame_data(frame_data):
"""
Masks the unknown image data to make sure we cannot use it at model evaluation time.
"""
frame_data_for_eval = copy.deepcopy(frame_data)
is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
frame_data.image_rgb
)[:, None, None, None]
for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
value_masked = getattr(frame_data_for_eval, k).clone() * is_known
setattr(frame_data_for_eval, k, value_masked)
return frame_data_for_eval
def run_eval(cfg, model, all_source_cameras, loader, task, device):
"""
Run the evaluation loop on the test dataloader
"""
lpips_model = lpips.LPIPS(net="vgg")
lpips_model = lpips_model.to(device)
model.eval()
per_batch_eval_results = []
logger.info("Evaluating model ...")
for frame_data in tqdm.tqdm(loader):
frame_data = frame_data.to(device)
# mask out the unknown images so that the model does not see them
frame_data_for_eval = _get_eval_frame_data(frame_data)
with torch.no_grad():
preds = model(
**{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
)
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
per_batch_eval_results.append(
evaluate.eval_batch(
frame_data,
nvs_prediction,
bg_color="black",
lpips_model=lpips_model,
source_cameras=all_source_cameras,
)
)
_, category_result = evaluate.summarize_nvs_eval_results(
per_batch_eval_results, task
)
return category_result["results"]
def _get_all_source_cameras(
dataset: ImplicitronDataset,
num_workers: int = 8,
) -> CamerasBase:
"""
Load and return all the source cameras in the training dataset
"""
all_frame_data = next(
iter(
torch.utils.data.DataLoader(
dataset,
shuffle=False,
batch_size=len(dataset),
num_workers=num_workers,
collate_fn=FrameData.collate,
)
)
)
is_source = ds_utils.is_known_frame(all_frame_data.frame_type)
source_cameras = all_frame_data.camera[torch.where(is_source)[0]]
return source_cameras
def _seed_all_random_engines(seed: int):
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
@dataclass(eq=False)
class ExperimentConfig:
generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer)
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
architecture: str = "generic"
detect_anomaly: bool = False
eval_only: bool = False
exp_dir: str = "./data/default_experiment/"
exp_idx: int = 0
gpu_idx: int = 0
metric_print_interval: int = 5
resume: bool = True
resume_epoch: int = -1
seed: int = 0
store_checkpoints: bool = True
store_checkpoints_purge: int = 1
test_interval: int = -1
test_when_finished: bool = False
validation_interval: int = 1
visdom_env: str = ""
visdom_port: int = 8097
visdom_server: str = "http://127.0.0.1"
visualize_interval: int = 1000
clip_grad: float = 0.0
hydra: dict = field(
default_factory=lambda: {
"run": {"dir": "."}, # Make hydra not change the working dir.
"output_subdir": None, # disable storing the .hydra logs
}
)
cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config", node=ExperimentConfig)
@hydra.main(config_path="./configs/", config_name="default_config")
def experiment(cfg: DictConfig) -> None:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
# Set the device
device = "cpu"
if torch.cuda.is_available() and cfg.gpu_idx < torch.cuda.device_count():
device = f"cuda:{cfg.gpu_idx}"
logger.info(f"Running experiment on device: {device}")
run_training(cfg, device)
if __name__ == "__main__":
experiment()

View File

@@ -0,0 +1,382 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Script to visualize a previously trained model. Example call:
projects/implicitron_trainer/visualize_reconstruction.py
exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097
n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
"""
import math
import os
import random
import sys
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as Fu
from experiment import init_model
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
)
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base import EvaluationMode
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.implicitron.tools.eval_video_trajectory import (
generate_eval_video_cameras,
)
from pytorch3d.implicitron.tools.video_writer import VideoWriter
from pytorch3d.implicitron.tools.vis_utils import (
get_visdom_connection,
make_depth_image,
)
from tqdm import tqdm
def render_sequence(
dataset: ImplicitronDataset,
sequence_name: str,
model: torch.nn.Module,
video_path,
n_eval_cameras=40,
fps=20,
max_angle=2 * math.pi,
trajectory_type="circular_lsq_fit",
trajectory_scale=1.1,
scene_center=(0.0, 0.0, 0.0),
up=(0.0, -1.0, 0.0),
traj_offset=0.0,
n_source_views=9,
viz_env="debug",
visdom_show_preds=False,
visdom_server="http://127.0.0.1",
visdom_port=8097,
num_workers=10,
seed=None,
video_resize=None,
):
if seed is None:
seed = hash(sequence_name)
print(f"Loading all data of sequence '{sequence_name}'.")
seq_idx = dataset.seq_to_idx[sequence_name]
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
print(f"Sequence set = {sequence_set_name}.")
train_cameras = train_data.camera
time = torch.linspace(0, max_angle, n_eval_cameras + 1)[:n_eval_cameras]
test_cameras = generate_eval_video_cameras(
train_cameras,
time=time,
n_eval_cams=n_eval_cameras,
trajectory_type=trajectory_type,
trajectory_scale=trajectory_scale,
scene_center=scene_center,
up=up,
focal_length=None,
principal_point=torch.zeros(n_eval_cameras, 2),
traj_offset_canonical=[0.0, 0.0, traj_offset],
)
# sample the source views reproducibly
with torch.random.fork_rng():
torch.manual_seed(seed)
source_views_i = torch.randperm(len(seq_idx))[:n_source_views]
# add the first dummy view that will get replaced with the target camera
source_views_i = Fu.pad(source_views_i, [1, 0])
source_views = [seq_idx[i] for i in source_views_i.tolist()]
batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)
preds_total = []
for n in tqdm(range(n_eval_cameras), total=n_eval_cameras):
# set the first batch camera to the target camera
for k in ("R", "T", "focal_length", "principal_point"):
getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)
# Move to cuda
net_input = batch.cuda()
with torch.no_grad():
preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})
# make sure we dont overwrite something
assert all(k not in preds for k in net_input.keys())
preds.update(net_input) # merge everything into one big dict
# Render the predictions to images
rendered_pred = images_from_preds(preds)
preds_total.append(rendered_pred)
# show the preds every 5% of the export iterations
if visdom_show_preds and (
n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
):
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
show_predictions(
preds_total,
sequence_name=batch.sequence_name[0],
viz=viz,
viz_env=viz_env,
)
print(f"Exporting videos for sequence {sequence_name} ...")
generate_prediction_videos(
preds_total,
sequence_name=batch.sequence_name[0],
viz=viz,
viz_env=viz_env,
fps=fps,
video_path=video_path,
resize=video_resize,
)
def _load_whole_dataset(dataset, idx, num_workers=10):
load_all_dataloader = torch.utils.data.DataLoader(
torch.utils.data.Subset(dataset, idx),
batch_size=len(idx),
num_workers=num_workers,
shuffle=False,
collate_fn=FrameData.collate,
)
return next(iter(load_all_dataloader))
def images_from_preds(preds):
imout = {}
for k in (
"image_rgb",
"images_render",
"fg_probability",
"masks_render",
"depths_render",
"depth_map",
"_all_source_images",
):
if k == "_all_source_images" and "image_rgb" in preds:
src_ims = preds["image_rgb"][1:].cpu().detach().clone()
v = _stack_images(src_ims, None)[None]
else:
if k not in preds or preds[k] is None:
print(f"cant show {k}")
continue
v = preds[k].cpu().detach().clone()
if k.startswith("depth"):
mask_resize = Fu.interpolate(
preds["masks_render"],
size=preds[k].shape[2:],
mode="nearest",
)
v = make_depth_image(preds[k], mask_resize)
if v.shape[1] == 1:
v = v.repeat(1, 3, 1, 1)
imout[k] = v.detach().cpu()
return imout
def _stack_images(ims, size):
ba = ims.shape[0]
H = int(np.ceil(np.sqrt(ba)))
W = H
n_add = H * W - ba
if n_add > 0:
ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))
ims = ims.view(H, W, *ims.shape[1:])
cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
if size is not None:
cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
return cated.clamp(0.0, 1.0)
def show_predictions(
preds,
sequence_name,
viz,
viz_env="visualizer",
predicted_keys=(
"images_render",
"masks_render",
"depths_render",
"_all_source_images",
),
n_samples=10,
one_image_width=200,
):
"""Given a list of predictions visualize them into a single image using visdom."""
assert isinstance(preds, list)
pred_all = []
    # Randomly choose a subset of the rendered images, sorted by order in the sequence
n_samples = min(n_samples, len(preds))
pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
for predi in pred_idx:
        # Make the concatenation for the same camera vertically
pred_all.append(
torch.cat(
[
torch.nn.functional.interpolate(
preds[predi][k].cpu(),
scale_factor=one_image_width / preds[predi][k].shape[3],
mode="bilinear",
).clamp(0.0, 1.0)
for k in predicted_keys
],
dim=2,
)
)
# Concatenate the images horizontally
pred_all_cat = torch.cat(pred_all, dim=3)[0]
viz.image(
pred_all_cat,
win="show_predictions",
env=viz_env,
opts={"title": f"pred_{sequence_name}"},
)
def generate_prediction_videos(
preds,
sequence_name,
viz,
viz_env="visualizer",
predicted_keys=(
"images_render",
"masks_render",
"depths_render",
"_all_source_images",
),
fps=20,
video_path="/tmp/video",
resize=None,
):
"""Given a list of predictions create and visualize rotating videos of the
objects using visdom.
"""
assert isinstance(preds, list)
# make sure the target video directory exists
os.makedirs(os.path.dirname(video_path), exist_ok=True)
# init a video writer for each predicted key
vws = {}
for k in predicted_keys:
vws[k] = VideoWriter(out_path=f"{video_path}_{sequence_name}_{k}.mp4", fps=fps)
for rendered_pred in tqdm(preds):
for k in predicted_keys:
vws[k].write_frame(
rendered_pred[k][0].detach().cpu().numpy(),
resize=resize,
)
for k in predicted_keys:
vws[k].get_video(quiet=True)
print(f"Generated {vws[k].out_path}.")
viz.video(
videofile=vws[k].out_path,
env=viz_env,
win=k, # we reuse the same window otherwise visdom dies
opts={"title": sequence_name + " " + k},
)
def export_scenes(
exp_dir: str = "",
restrict_sequence_name: Optional[str] = None,
output_directory: Optional[str] = None,
render_size: Tuple[int, int] = (512, 512),
video_size: Optional[Tuple[int, int]] = None,
split: str = "train", # train | test
n_source_views: int = 9,
n_eval_cameras: int = 40,
visdom_server="http://127.0.0.1",
visdom_port=8097,
visdom_show_preds: bool = False,
visdom_env: Optional[str] = None,
gpu_idx: int = 0,
):
# In case an output directory is specified use it. If no output_directory
# is specified create a vis folder inside the experiment directory
    if output_directory is None:
        output_directory = os.path.join(exp_dir, "vis")
if not os.path.exists(output_directory):
os.makedirs(output_directory)
# Set the random seeds
torch.manual_seed(0)
np.random.seed(0)
# Get the config from the experiment_directory,
# and overwrite relevant fields
config = _get_config_from_experiment_directory(exp_dir)
config.gpu_idx = gpu_idx
config.exp_dir = exp_dir
# important so that the CO3D dataset gets loaded in full
config.dataset_args.test_on_train = False
# Set the rendering image size
config.generic_model_args.render_image_width = render_size[0]
config.generic_model_args.render_image_height = render_size[1]
if restrict_sequence_name is not None:
config.dataset_args.restrict_sequence_name = restrict_sequence_name
# Set up the CUDA env for the visualization
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
# Load the previously trained model
model, _, _ = init_model(config, force_load=True, load_model_only=True)
model.cuda()
model.eval()
# Setup the dataset
dataset = dataset_zoo(**config.dataset_args)[split]
# iterate over the sequences in the dataset
for sequence_name in dataset.seq_to_idx.keys():
with torch.no_grad():
render_sequence(
dataset,
sequence_name,
model,
video_path="{}/video".format(output_directory),
n_source_views=n_source_views,
visdom_show_preds=visdom_show_preds,
n_eval_cameras=n_eval_cameras,
visdom_server=visdom_server,
visdom_port=visdom_port,
viz_env=f"visualizer_{config.visdom_env}"
if visdom_env is None
else visdom_env,
video_resize=video_size,
)
def _get_config_from_experiment_directory(experiment_directory):
cfg_file = os.path.join(experiment_directory, "expconfig.yaml")
config = OmegaConf.load(cfg_file)
return config
def main(argv):
# automatically parses arguments of export_scenes
cfg = OmegaConf.create(get_default_args(export_scenes))
cfg.update(OmegaConf.from_cli())
with torch.no_grad():
export_scenes(**cfg)
if __name__ == "__main__":
main(sys.argv)