visualize_reconstruction fixes

Summary: Various fixes to get visualize_reconstruction running, and an interactive test for it.

Reviewed By: kjchalup

Differential Revision: D39286691

fbshipit-source-id: 88735034cc01736b24735bcb024577e6ab7ed336
This commit is contained in:
Jeremy Reizenstein
2022-09-07 20:10:07 -07:00
committed by Facebook GitHub Bot
parent 34ad77b841
commit 6e25fe8cb3
8 changed files with 125 additions and 58 deletions

View File

@@ -5,10 +5,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Script to visualize a previously trained model. Example call:
"""
Script to visualize a previously trained model. Example call:
projects/implicitron_trainer/visualize_reconstruction.py
exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097
pytorch3d_implicitron_visualizer \
exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097 \
n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
"""
@@ -18,9 +19,9 @@ from typing import Optional, Tuple
import numpy as np
import torch
from omegaconf import OmegaConf
from pytorch3d.implicitron.models.visualization import render_flyaround
from pytorch3d.implicitron.tools.configurable import get_default_args
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
from pytorch3d.implicitron.tools.config import enable_get_default_args, get_default_args
from .experiment import Experiment
@@ -38,7 +39,7 @@ def visualize_reconstruction(
visdom_server: str = "http://127.0.0.1",
visdom_port: int = 8097,
visdom_env: Optional[str] = None,
):
) -> None:
"""
Given an `exp_dir` containing a trained Implicitron model, generates videos consisting
of renderes of sequences from the dataset used to train and evaluate the trained
@@ -76,22 +77,27 @@ def visualize_reconstruction(
config = _get_config_from_experiment_directory(exp_dir)
config.exp_dir = exp_dir
# important so that the CO3D dataset gets loaded in full
dataset_args = (
config.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
)
dataset_args.test_on_train = False
data_source_args = config.data_source_ImplicitronDataSource_args
if "dataset_map_provider_JsonIndexDatasetMapProvider_args" in data_source_args:
dataset_args = (
data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
)
dataset_args.test_on_train = False
if restrict_sequence_name is not None:
dataset_args.restrict_sequence_name = restrict_sequence_name
# Set the rendering image size
model_factory_args = config.model_factory_ImplicitronModelFactory_args
model_factory_args.force_resume = True
model_args = model_factory_args.model_GenericModel_args
model_args.render_image_width = render_size[0]
model_args.render_image_height = render_size[1]
if restrict_sequence_name is not None:
dataset_args.restrict_sequence_name = restrict_sequence_name
# Load the previously trained model
experiment = Experiment(config)
model = experiment.model_factory(force_resume=True)
model.cuda()
experiment = Experiment(**config)
model = experiment.model_factory(exp_dir=exp_dir)
device = torch.device("cuda")
model.to(device)
model.eval()
# Setup the dataset
@@ -101,6 +107,11 @@ def visualize_reconstruction(
if dataset is None:
raise ValueError(f"{split} dataset not provided")
if visdom_env is None:
visdom_env = (
"visualizer_" + config.training_loop_ImplicitronTrainingLoop_args.visdom_env
)
# iterate over the sequences in the dataset
for sequence_name in dataset.sequence_names():
with torch.no_grad():
@@ -114,23 +125,26 @@ def visualize_reconstruction(
n_flyaround_poses=n_eval_cameras,
visdom_server=visdom_server,
visdom_port=visdom_port,
visdom_environment=f"visualizer_{config.visdom_env}"
if visdom_env is None
else visdom_env,
visdom_environment=visdom_env,
video_resize=video_size,
device=device,
)
def _get_config_from_experiment_directory(experiment_directory):
enable_get_default_args(visualize_reconstruction)
def _get_config_from_experiment_directory(experiment_directory) -> DictConfig:
cfg_file = os.path.join(experiment_directory, "expconfig.yaml")
config = OmegaConf.load(cfg_file)
# pyre-ignore[7]
return config
def main(argv):
def main(argv) -> None:
# automatically parses arguments of visualize_reconstruction
cfg = OmegaConf.create(get_default_args(visualize_reconstruction))
cfg.update(OmegaConf.from_cli())
cfg.update(OmegaConf.from_cli(argv))
with torch.no_grad():
visualize_reconstruction(**cfg)