visualize_reconstruction fixes

Summary: Various fixes to get visualize_reconstruction running, plus a new interactive test for it.

Reviewed By: kjchalup

Differential Revision: D39286691

fbshipit-source-id: 88735034cc01736b24735bcb024577e6ab7ed336
This commit is contained in:
Jeremy Reizenstein
2022-09-07 20:10:07 -07:00
committed by Facebook GitHub Bot
parent 34ad77b841
commit 6e25fe8cb3
8 changed files with 125 additions and 58 deletions

View File

@@ -9,7 +9,7 @@
# provide data for a single scene.
from dataclasses import field
from typing import Iterable, List, Optional
from typing import Iterable, Iterator, List, Optional, Tuple
import numpy as np
import torch
@@ -46,6 +46,12 @@ class SingleSceneDataset(DatasetBase, Configurable):
def __len__(self) -> int:
return len(self.poses)
def sequence_frames_in_order(
self, seq_name: str
) -> Iterator[Tuple[float, int, int]]:
for i in range(len(self)):
yield (0.0, i, i)
def __getitem__(self, index) -> FrameData:
if index >= len(self):
raise IndexError(f"index {index} out of range {len(self)}")

View File

@@ -61,7 +61,7 @@ def render_flyaround(
"depths_render",
"_all_source_images",
),
):
) -> None:
"""
Uses `model` to generate a video consisting of renders of a scene imaged from
a camera flying around the scene. The scene is specified with the `dataset` object and
@@ -133,6 +133,7 @@ def render_flyaround(
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
# pyre-ignore[6]
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
logger.info(f"Sequence set = {sequence_set_name}.")
train_cameras = train_data.camera
@@ -209,7 +210,7 @@ def render_flyaround(
def _load_whole_dataset(
dataset: torch.utils.data.Dataset, idx: Sequence[int], num_workers: int = 10
):
) -> FrameData:
load_all_dataloader = torch.utils.data.DataLoader(
torch.utils.data.Subset(dataset, idx),
batch_size=len(idx),
@@ -220,7 +221,7 @@ def _load_whole_dataset(
return next(iter(load_all_dataloader))
def _images_from_preds(preds: Dict[str, Any]):
def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
imout = {}
for k in (
"image_rgb",
@@ -253,7 +254,7 @@ def _images_from_preds(preds: Dict[str, Any]):
return imout
def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]):
def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.Tensor:
ba = ims.shape[0]
H = int(np.ceil(np.sqrt(ba)))
W = H
@@ -281,7 +282,7 @@ def _show_predictions(
),
n_samples=10,
one_image_width=200,
):
) -> None:
"""Given a list of predictions visualize them into a single image using visdom."""
assert isinstance(preds, list)
@@ -329,7 +330,7 @@ def _generate_prediction_videos(
video_path: str = "/tmp/video",
video_frames_dir: Optional[str] = None,
resize: Optional[Tuple[int, int]] = None,
):
) -> None:
"""Given a list of predictions create and visualize rotating videos of the
objects using visdom.
"""
@@ -359,7 +360,7 @@ def _generate_prediction_videos(
)
for k in predicted_keys:
vws[k].get_video(quiet=True)
vws[k].get_video()
logger.info(f"Generated {vws[k].out_path}.")
if viz is not None:
viz.video(

View File

@@ -6,6 +6,7 @@
import os
import shutil
import subprocess
import tempfile
import warnings
from typing import Optional, Tuple, Union
@@ -15,6 +16,7 @@ import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
matplotlib.use("Agg")
@@ -27,13 +29,13 @@ class VideoWriter:
def __init__(
self,
cache_dir: Optional[str] = None,
ffmpeg_bin: str = "ffmpeg",
ffmpeg_bin: str = _DEFAULT_FFMPEG,
out_path: str = "/tmp/video.mp4",
fps: int = 20,
output_format: str = "visdom",
rmdir_allowed: bool = False,
**kwargs,
):
) -> None:
"""
Args:
cache_dir: A directory for storing the video frames. If `None`,
@@ -74,7 +76,7 @@ class VideoWriter:
self,
frame: Union[matplotlib.figure.Figure, np.ndarray, Image.Image, str],
resize: Optional[Union[float, Tuple[int, int]]] = None,
):
) -> None:
"""
Write a frame to the video.
@@ -114,7 +116,7 @@ class VideoWriter:
self.frames.append(outfile)
self.frame_num += 1
def get_video(self, quiet: bool = True):
def get_video(self) -> str:
"""
Generate the video from the written frames.
@@ -127,23 +129,39 @@ class VideoWriter:
regexp = os.path.join(self.cache_dir, self.regexp)
if self.output_format == "visdom": # works for ppt too
ffmcmd_ = (
"%s -r %d -i %s -vcodec h264 -f mp4 \
-y -crf 18 -b 2000k -pix_fmt yuv420p '%s'"
% (self.ffmpeg_bin, self.fps, regexp, self.out_path)
if shutil.which(self.ffmpeg_bin) is None:
raise ValueError(
f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
+ "Please set FFMPEG in the environment or ffmpeg_bin on this class."
)
if self.output_format == "visdom": # works for ppt too
args = [
self.ffmpeg_bin,
"-r",
str(self.fps),
"-i",
regexp,
"-vcodec",
"h264",
"-f",
"mp4",
"-y",
"-crf",
"18",
"-b",
"2000k",
"-pix_fmt",
"yuv420p",
self.out_path,
]
subprocess.check_call(args)
else:
raise ValueError("no such output type %s" % str(self.output_format))
if quiet:
ffmcmd_ += " > /dev/null 2>&1"
else:
print(ffmcmd_)
os.system(ffmcmd_)
return self.out_path
def __del__(self):
def __del__(self) -> None:
if self.tmp_dir is not None:
self.tmp_dir.cleanup()