visualize_reconstruction fixes

Summary: Various fixes to get visualize_reconstruction running, and an interactive test for it.

Reviewed By: kjchalup

Differential Revision: D39286691

fbshipit-source-id: 88735034cc01736b24735bcb024577e6ab7ed336
Jeremy Reizenstein 2022-09-07 20:10:07 -07:00 committed by Facebook GitHub Bot
parent 34ad77b841
commit 6e25fe8cb3
8 changed files with 125 additions and 58 deletions

View File

@@ -66,7 +66,7 @@ If you have a custom `experiment.py` script (as in Option 2 above), replace
To run training, pass a yaml config file, followed by a list of overridden arguments.
For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
```shell
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
dataset_args=data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf \
$dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' \
$dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
@@ -87,7 +87,7 @@ To run evaluation on the latest checkpoint after (or during) training, simply add
E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
```shell
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
dataset_args=data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf \
$dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' \
$dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True

View File

@@ -13,16 +13,7 @@ from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf
from .. import experiment
from .utils import intercept_logs
def interactive_testing_requested() -> bool:
"""
Certain tests are only useful when run interactively, and so are not regularly run.
These are activated by this function returning True, which the user requests by
setting the environment variable `PYTORCH3D_INTERACTIVE_TESTING` to 1.
"""
return os.environ.get("PYTORCH3D_INTERACTIVE_TESTING", "") == "1"
from .utils import interactive_testing_requested, intercept_logs
internal = os.environ.get("FB_TEST", False)

View File

@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from .. import visualize_reconstruction
from .utils import interactive_testing_requested
internal = os.environ.get("FB_TEST", False)
class TestVisualize(unittest.TestCase):
def test_from_defaults(self):
if not interactive_testing_requested():
return
checkpoint_dir = os.environ["exp_dir"]
argv = [
f"exp_dir={checkpoint_dir}",
"n_eval_cameras=40",
"render_size=[64,64]",
"video_size=[256,256]",
]
visualize_reconstruction.main(argv)
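For reference, a minimal sketch of driving the same visualization directly, outside the test harness; the import path and checkpoint directory here are assumptions, not taken from this commit:
```python
from projects.implicitron_trainer import visualize_reconstruction  # assumed import path

# The test above reads `exp_dir` from the environment and only runs when
# PYTORCH3D_INTERACTIVE_TESTING=1; calling main() directly needs neither.
visualize_reconstruction.main(
    [
        "exp_dir=./exps/checkpoint_dir",  # assumed path to a trained model
        "n_eval_cameras=40",
        "render_size=[64,64]",
        "video_size=[256,256]",
    ]
)
```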

View File

@@ -6,6 +6,7 @@
import contextlib
import logging
import os
import re
@@ -28,3 +29,12 @@ def intercept_logs(logger_name: str, regexp: str):
yield intercepted_messages
finally:
logger.removeFilter(interceptor)
def interactive_testing_requested() -> bool:
"""
Certain tests are only useful when run interactively, and so are not regularly run.
These are activated by this function returning True, which the user requests by
setting the environment variable `PYTORCH3D_INTERACTIVE_TESTING` to 1.
"""
return os.environ.get("PYTORCH3D_INTERACTIVE_TESTING", "") == "1"
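A quick sketch of the guard in action; the import path is an assumption based on where this helper lives in the tree:
```python
import os

# assumed import path for the helper added above
from projects.implicitron_trainer.tests.utils import interactive_testing_requested

os.environ["PYTORCH3D_INTERACTIVE_TESTING"] = "1"
assert interactive_testing_requested()
del os.environ["PYTORCH3D_INTERACTIVE_TESTING"]
assert not interactive_testing_requested()
```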

View File

@@ -5,10 +5,11 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Script to visualize a previously trained model. Example call:
"""
Script to visualize a previously trained model. Example call:
projects/implicitron_trainer/visualize_reconstruction.py
exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097
pytorch3d_implicitron_visualizer \
exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097 \
n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
"""
@@ -18,9 +19,9 @@ from typing import Optional, Tuple
import numpy as np
import torch
from omegaconf import OmegaConf
from pytorch3d.implicitron.models.visualization import render_flyaround
from pytorch3d.implicitron.tools.configurable import get_default_args
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
from pytorch3d.implicitron.tools.config import enable_get_default_args, get_default_args
from .experiment import Experiment
@@ -38,7 +39,7 @@ def visualize_reconstruction(
visdom_server: str = "http://127.0.0.1",
visdom_port: int = 8097,
visdom_env: Optional[str] = None,
):
) -> None:
"""
Given an `exp_dir` containing a trained Implicitron model, generates videos consisting
of renders of sequences from the dataset used to train and evaluate the trained
@@ -76,22 +77,27 @@ def visualize_reconstruction(
config = _get_config_from_experiment_directory(exp_dir)
config.exp_dir = exp_dir
# important so that the CO3D dataset gets loaded in full
data_source_args = config.data_source_ImplicitronDataSource_args
if "dataset_map_provider_JsonIndexDatasetMapProvider_args" in data_source_args:
dataset_args = (
config.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
)
dataset_args.test_on_train = False
# Set the rendering image size
model_factory_args = config.model_factory_ImplicitronModelFactory_args
model_args = model_factory_args.model_GenericModel_args
model_args.render_image_width = render_size[0]
model_args.render_image_height = render_size[1]
if restrict_sequence_name is not None:
dataset_args.restrict_sequence_name = restrict_sequence_name
# Set the rendering image size
model_factory_args = config.model_factory_ImplicitronModelFactory_args
model_factory_args.force_resume = True
model_args = model_factory_args.model_GenericModel_args
model_args.render_image_width = render_size[0]
model_args.render_image_height = render_size[1]
# Load the previously trained model
experiment = Experiment(config)
model = experiment.model_factory(force_resume=True)
model.cuda()
experiment = Experiment(**config)
model = experiment.model_factory(exp_dir=exp_dir)
device = torch.device("cuda")
model.to(device)
model.eval()
# Setup the dataset
@@ -101,6 +107,11 @@ def visualize_reconstruction(
if dataset is None:
raise ValueError(f"{split} dataset not provided")
if visdom_env is None:
visdom_env = (
"visualizer_" + config.training_loop_ImplicitronTrainingLoop_args.visdom_env
)
# iterate over the sequences in the dataset
for sequence_name in dataset.sequence_names():
with torch.no_grad():
@@ -114,23 +125,26 @@ def visualize_reconstruction(
n_flyaround_poses=n_eval_cameras,
visdom_server=visdom_server,
visdom_port=visdom_port,
visdom_environment=f"visualizer_{config.visdom_env}"
if visdom_env is None
else visdom_env,
visdom_environment=visdom_env,
video_resize=video_size,
device=device,
)
def _get_config_from_experiment_directory(experiment_directory):
enable_get_default_args(visualize_reconstruction)
def _get_config_from_experiment_directory(experiment_directory) -> DictConfig:
cfg_file = os.path.join(experiment_directory, "expconfig.yaml")
config = OmegaConf.load(cfg_file)
# pyre-ignore[7]
return config
def main(argv):
def main(argv) -> None:
# automatically parses arguments of visualize_reconstruction
cfg = OmegaConf.create(get_default_args(visualize_reconstruction))
cfg.update(OmegaConf.from_cli())
cfg.update(OmegaConf.from_cli(argv))
with torch.no_grad():
visualize_reconstruction(**cfg)
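As a rough illustration of that parsing flow, with invented values and a literal dict standing in for `get_default_args(visualize_reconstruction)`:
```python
from omegaconf import OmegaConf

# Stand-in for the defaults harvested from visualize_reconstruction's signature.
defaults = OmegaConf.create({"n_eval_cameras": 40, "render_size": [512, 512]})
# CLI-style "key=value" overrides, as passed on argv.
overrides = OmegaConf.from_cli(["n_eval_cameras=80", "render_size=[64,64]"])
cfg = OmegaConf.merge(defaults, overrides)  # merge plays the role of cfg.update() above
assert cfg.n_eval_cameras == 80
assert list(cfg.render_size) == [64, 64]
```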

View File

@@ -9,7 +9,7 @@
# provide data for a single scene.
from dataclasses import field
from typing import Iterable, List, Optional
from typing import Iterable, Iterator, List, Optional, Tuple
import numpy as np
import torch
@@ -46,6 +46,12 @@ class SingleSceneDataset(DatasetBase, Configurable):
def __len__(self) -> int:
return len(self.poses)
def sequence_frames_in_order(
self, seq_name: str
) -> Iterator[Tuple[float, int, int]]:
for i in range(len(self)):
yield (0.0, i, i)
def __getitem__(self, index) -> FrameData:
if index >= len(self):
raise IndexError(f"index {index} out of range {len(self)}")

View File

@@ -61,7 +61,7 @@ def render_flyaround(
"depths_render",
"_all_source_images",
),
):
) -> None:
"""
Uses `model` to generate a video consisting of renders of a scene imaged from
a camera flying around the scene. The scene is specified with the `dataset` object and
@@ -133,6 +133,7 @@ def render_flyaround(
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
# pyre-ignore[6]
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
logger.info(f"Sequence set = {sequence_set_name}.")
train_cameras = train_data.camera
@@ -209,7 +210,7 @@
def _load_whole_dataset(
dataset: torch.utils.data.Dataset, idx: Sequence[int], num_workers: int = 10
):
) -> FrameData:
load_all_dataloader = torch.utils.data.DataLoader(
torch.utils.data.Subset(dataset, idx),
batch_size=len(idx),
@@ -220,7 +221,7 @@ def _load_whole_dataset(
return next(iter(load_all_dataloader))
def _images_from_preds(preds: Dict[str, Any]):
def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
imout = {}
for k in (
"image_rgb",
@@ -253,7 +254,7 @@ def _images_from_preds(preds: Dict[str, Any]):
return imout
def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]):
def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.Tensor:
ba = ims.shape[0]
H = int(np.ceil(np.sqrt(ba)))
W = H
@@ -281,7 +282,7 @@ def _show_predictions(
),
n_samples=10,
one_image_width=200,
):
) -> None:
"""Given a list of predictions visualize them into a single image using visdom."""
assert isinstance(preds, list)
@@ -329,7 +330,7 @@ def _generate_prediction_videos(
video_path: str = "/tmp/video",
video_frames_dir: Optional[str] = None,
resize: Optional[Tuple[int, int]] = None,
):
) -> None:
"""Given a list of predictions create and visualize rotating videos of the
objects using visdom.
"""
@@ -359,7 +360,7 @@
)
for k in predicted_keys:
vws[k].get_video(quiet=True)
vws[k].get_video()
logger.info(f"Generated {vws[k].out_path}.")
if viz is not None:
viz.video(

View File

@@ -6,6 +6,7 @@
import os
import shutil
import subprocess
import tempfile
import warnings
from typing import Optional, Tuple, Union
@@ -15,6 +16,7 @@ import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
matplotlib.use("Agg")
@@ -27,13 +29,13 @@ class VideoWriter:
def __init__(
self,
cache_dir: Optional[str] = None,
ffmpeg_bin: str = "ffmpeg",
ffmpeg_bin: str = _DEFAULT_FFMPEG,
out_path: str = "/tmp/video.mp4",
fps: int = 20,
output_format: str = "visdom",
rmdir_allowed: bool = False,
**kwargs,
):
) -> None:
"""
Args:
cache_dir: A directory for storing the video frames. If `None`,
@@ -74,7 +76,7 @@ class VideoWriter:
self,
frame: Union[matplotlib.figure.Figure, np.ndarray, Image.Image, str],
resize: Optional[Union[float, Tuple[int, int]]] = None,
):
) -> None:
"""
Write a frame to the video.
@@ -114,7 +116,7 @@ class VideoWriter:
self.frames.append(outfile)
self.frame_num += 1
def get_video(self, quiet: bool = True):
def get_video(self) -> str:
"""
Generate the video from the written frames.
@@ -127,23 +129,39 @@ class VideoWriter:
regexp = os.path.join(self.cache_dir, self.regexp)
if self.output_format == "visdom": # works for ppt too
ffmcmd_ = (
"%s -r %d -i %s -vcodec h264 -f mp4 \
-y -crf 18 -b 2000k -pix_fmt yuv420p '%s'"
% (self.ffmpeg_bin, self.fps, regexp, self.out_path)
if shutil.which(self.ffmpeg_bin) is None:
raise ValueError(
f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
+ "Please set FFMPEG in the environment or ffmpeg_bin on this class."
)
if self.output_format == "visdom": # works for ppt too
args = [
self.ffmpeg_bin,
"-r",
str(self.fps),
"-i",
regexp,
"-vcodec",
"h264",
"-f",
"mp4",
"-y",
"-crf",
"18",
"-b",
"2000k",
"-pix_fmt",
"yuv420p",
self.out_path,
]
subprocess.check_call(args)
else:
raise ValueError("no such output type %s" % str(self.output_format))
if quiet:
ffmcmd_ += " > /dev/null 2>&1"
else:
print(ffmcmd_)
os.system(ffmcmd_)
return self.out_path
def __del__(self):
def __del__(self) -> None:
if self.tmp_dir is not None:
self.tmp_dir.cleanup()
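Finally, a small usage sketch of the writer as modified in this commit; the module path, output path, and frame contents are illustrative, and `get_video` now raises a ValueError when the ffmpeg binary cannot be found:
```python
from PIL import Image

from pytorch3d.implicitron.tools.video_writer import VideoWriter  # assumed module path

writer = VideoWriter(out_path="/tmp/video.mp4", fps=20)
for shade in (0, 128, 255):
    # dummy grayscale frames; write_frame also accepts figures, arrays, and file paths
    writer.write_frame(Image.new("RGB", (64, 64), color=(shade, shade, shade)))
path = writer.get_video()  # encodes via subprocess.check_call and returns out_path
print(path)
```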