mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-04-30 02:28:56 +08:00
visualize_reconstruction fixes
Summary: Various fixes to get visualize_reconstruction running, and an interactive test for it. Reviewed By: kjchalup Differential Revision: D39286691 fbshipit-source-id: 88735034cc01736b24735bcb024577e6ab7ed336
This commit is contained in:
committed by
Facebook GitHub Bot
parent
34ad77b841
commit
6e25fe8cb3
@@ -9,7 +9,7 @@
|
||||
# provide data for a single scene.
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Iterable, List, Optional
|
||||
from typing import Iterable, Iterator, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -46,6 +46,12 @@ class SingleSceneDataset(DatasetBase, Configurable):
|
||||
def __len__(self) -> int:
|
||||
return len(self.poses)
|
||||
|
||||
def sequence_frames_in_order(
|
||||
self, seq_name: str
|
||||
) -> Iterator[Tuple[float, int, int]]:
|
||||
for i in range(len(self)):
|
||||
yield (0.0, i, i)
|
||||
|
||||
def __getitem__(self, index) -> FrameData:
|
||||
if index >= len(self):
|
||||
raise IndexError(f"index {index} out of range {len(self)}")
|
||||
|
||||
@@ -61,7 +61,7 @@ def render_flyaround(
|
||||
"depths_render",
|
||||
"_all_source_images",
|
||||
),
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Uses `model` to generate a video consisting of renders of a scene imaged from
|
||||
a camera flying around the scene. The scene is specified with the `dataset` object and
|
||||
@@ -133,6 +133,7 @@ def render_flyaround(
|
||||
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
|
||||
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
|
||||
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
|
||||
# pyre-ignore[6]
|
||||
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
|
||||
logger.info(f"Sequence set = {sequence_set_name}.")
|
||||
train_cameras = train_data.camera
|
||||
@@ -209,7 +210,7 @@ def render_flyaround(
|
||||
|
||||
def _load_whole_dataset(
|
||||
dataset: torch.utils.data.Dataset, idx: Sequence[int], num_workers: int = 10
|
||||
):
|
||||
) -> FrameData:
|
||||
load_all_dataloader = torch.utils.data.DataLoader(
|
||||
torch.utils.data.Subset(dataset, idx),
|
||||
batch_size=len(idx),
|
||||
@@ -220,7 +221,7 @@ def _load_whole_dataset(
|
||||
return next(iter(load_all_dataloader))
|
||||
|
||||
|
||||
def _images_from_preds(preds: Dict[str, Any]):
|
||||
def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
|
||||
imout = {}
|
||||
for k in (
|
||||
"image_rgb",
|
||||
@@ -253,7 +254,7 @@ def _images_from_preds(preds: Dict[str, Any]):
|
||||
return imout
|
||||
|
||||
|
||||
def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]):
|
||||
def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.Tensor:
|
||||
ba = ims.shape[0]
|
||||
H = int(np.ceil(np.sqrt(ba)))
|
||||
W = H
|
||||
@@ -281,7 +282,7 @@ def _show_predictions(
|
||||
),
|
||||
n_samples=10,
|
||||
one_image_width=200,
|
||||
):
|
||||
) -> None:
|
||||
"""Given a list of predictions visualize them into a single image using visdom."""
|
||||
assert isinstance(preds, list)
|
||||
|
||||
@@ -329,7 +330,7 @@ def _generate_prediction_videos(
|
||||
video_path: str = "/tmp/video",
|
||||
video_frames_dir: Optional[str] = None,
|
||||
resize: Optional[Tuple[int, int]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Given a list of predictions create and visualize rotating videos of the
|
||||
objects using visdom.
|
||||
"""
|
||||
@@ -359,7 +360,7 @@ def _generate_prediction_videos(
|
||||
)
|
||||
|
||||
for k in predicted_keys:
|
||||
vws[k].get_video(quiet=True)
|
||||
vws[k].get_video()
|
||||
logger.info(f"Generated {vws[k].out_path}.")
|
||||
if viz is not None:
|
||||
viz.video(
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union
|
||||
@@ -15,6 +16,7 @@ import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
|
||||
|
||||
matplotlib.use("Agg")
|
||||
|
||||
@@ -27,13 +29,13 @@ class VideoWriter:
|
||||
def __init__(
|
||||
self,
|
||||
cache_dir: Optional[str] = None,
|
||||
ffmpeg_bin: str = "ffmpeg",
|
||||
ffmpeg_bin: str = _DEFAULT_FFMPEG,
|
||||
out_path: str = "/tmp/video.mp4",
|
||||
fps: int = 20,
|
||||
output_format: str = "visdom",
|
||||
rmdir_allowed: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
cache_dir: A directory for storing the video frames. If `None`,
|
||||
@@ -74,7 +76,7 @@ class VideoWriter:
|
||||
self,
|
||||
frame: Union[matplotlib.figure.Figure, np.ndarray, Image.Image, str],
|
||||
resize: Optional[Union[float, Tuple[int, int]]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Write a frame to the video.
|
||||
|
||||
@@ -114,7 +116,7 @@ class VideoWriter:
|
||||
self.frames.append(outfile)
|
||||
self.frame_num += 1
|
||||
|
||||
def get_video(self, quiet: bool = True):
|
||||
def get_video(self) -> str:
|
||||
"""
|
||||
Generate the video from the written frames.
|
||||
|
||||
@@ -127,23 +129,39 @@ class VideoWriter:
|
||||
|
||||
regexp = os.path.join(self.cache_dir, self.regexp)
|
||||
|
||||
if self.output_format == "visdom": # works for ppt too
|
||||
ffmcmd_ = (
|
||||
"%s -r %d -i %s -vcodec h264 -f mp4 \
|
||||
-y -crf 18 -b 2000k -pix_fmt yuv420p '%s'"
|
||||
% (self.ffmpeg_bin, self.fps, regexp, self.out_path)
|
||||
if shutil.which(self.ffmpeg_bin) is None:
|
||||
raise ValueError(
|
||||
f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
|
||||
+ "Please set FFMPEG in the environment or ffmpeg_bin on this class."
|
||||
)
|
||||
|
||||
if self.output_format == "visdom": # works for ppt too
|
||||
args = [
|
||||
self.ffmpeg_bin,
|
||||
"-r",
|
||||
str(self.fps),
|
||||
"-i",
|
||||
regexp,
|
||||
"-vcodec",
|
||||
"h264",
|
||||
"-f",
|
||||
"mp4",
|
||||
"-y",
|
||||
"-crf",
|
||||
"18",
|
||||
"-b",
|
||||
"2000k",
|
||||
"-pix_fmt",
|
||||
"yuv420p",
|
||||
self.out_path,
|
||||
]
|
||||
|
||||
subprocess.check_call(args)
|
||||
else:
|
||||
raise ValueError("no such output type %s" % str(self.output_format))
|
||||
|
||||
if quiet:
|
||||
ffmcmd_ += " > /dev/null 2>&1"
|
||||
else:
|
||||
print(ffmcmd_)
|
||||
os.system(ffmcmd_)
|
||||
|
||||
return self.out_path
|
||||
|
||||
def __del__(self):
|
||||
def __del__(self) -> None:
|
||||
if self.tmp_dir is not None:
|
||||
self.tmp_dir.cleanup()
|
||||
|
||||
Reference in New Issue
Block a user