mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Rename and move render_flyaround into core implicitron
Summary: Move the flyaround rendering function into core implicitron. The unblocks an example in the facebookresearch/co3d repo. Reviewed By: bottler Differential Revision: D39257801 fbshipit-source-id: 6841a88a43d4aa364dd86ba83ca2d4c3cf0435a4
This commit is contained in:
parent
438c194ec6
commit
c79c954dea
@ -12,311 +12,60 @@
|
||||
n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as Fu
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
||||
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
||||
from pytorch3d.implicitron.models.base_model import EvaluationMode
|
||||
from pytorch3d.implicitron.models.visualization import render_flyaround
|
||||
from pytorch3d.implicitron.tools.configurable import get_default_args
|
||||
from pytorch3d.implicitron.tools.eval_video_trajectory import (
|
||||
generate_eval_video_cameras,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.video_writer import VideoWriter
|
||||
from pytorch3d.implicitron.tools.vis_utils import (
|
||||
get_visdom_connection,
|
||||
make_depth_image,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
from .experiment import Experiment
|
||||
|
||||
|
||||
def render_sequence(
|
||||
dataset: DatasetBase,
|
||||
sequence_name: str,
|
||||
model: torch.nn.Module,
|
||||
video_path,
|
||||
n_eval_cameras=40,
|
||||
fps=20,
|
||||
max_angle=2 * math.pi,
|
||||
trajectory_type="circular_lsq_fit",
|
||||
trajectory_scale=1.1,
|
||||
scene_center=(0.0, 0.0, 0.0),
|
||||
up=(0.0, -1.0, 0.0),
|
||||
traj_offset=0.0,
|
||||
n_source_views=9,
|
||||
viz_env="debug",
|
||||
visdom_show_preds=False,
|
||||
visdom_server="http://127.0.0.1",
|
||||
visdom_port=8097,
|
||||
num_workers=10,
|
||||
seed=None,
|
||||
video_resize=None,
|
||||
):
|
||||
if seed is None:
|
||||
seed = hash(sequence_name)
|
||||
|
||||
if visdom_show_preds:
|
||||
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
|
||||
else:
|
||||
viz = None
|
||||
|
||||
print(f"Loading all data of sequence '{sequence_name}'.")
|
||||
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
|
||||
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
|
||||
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
|
||||
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
|
||||
print(f"Sequence set = {sequence_set_name}.")
|
||||
train_cameras = train_data.camera
|
||||
time = torch.linspace(0, max_angle, n_eval_cameras + 1)[:n_eval_cameras]
|
||||
test_cameras = generate_eval_video_cameras(
|
||||
train_cameras,
|
||||
time=time,
|
||||
n_eval_cams=n_eval_cameras,
|
||||
trajectory_type=trajectory_type,
|
||||
trajectory_scale=trajectory_scale,
|
||||
scene_center=scene_center,
|
||||
up=up,
|
||||
focal_length=None,
|
||||
principal_point=torch.zeros(n_eval_cameras, 2),
|
||||
traj_offset_canonical=(0.0, 0.0, traj_offset),
|
||||
)
|
||||
|
||||
# sample the source views reproducibly
|
||||
with torch.random.fork_rng():
|
||||
torch.manual_seed(seed)
|
||||
source_views_i = torch.randperm(len(seq_idx))[:n_source_views]
|
||||
# add the first dummy view that will get replaced with the target camera
|
||||
source_views_i = Fu.pad(source_views_i, [1, 0])
|
||||
source_views = [seq_idx[i] for i in source_views_i.tolist()]
|
||||
batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
|
||||
assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)
|
||||
|
||||
preds_total = []
|
||||
for n in tqdm(range(n_eval_cameras), total=n_eval_cameras):
|
||||
# set the first batch camera to the target camera
|
||||
for k in ("R", "T", "focal_length", "principal_point"):
|
||||
getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)
|
||||
|
||||
# Move to cuda
|
||||
net_input = batch.cuda()
|
||||
with torch.no_grad():
|
||||
preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})
|
||||
|
||||
# make sure we dont overwrite something
|
||||
assert all(k not in preds for k in net_input.keys())
|
||||
preds.update(net_input) # merge everything into one big dict
|
||||
|
||||
# Render the predictions to images
|
||||
rendered_pred = images_from_preds(preds)
|
||||
preds_total.append(rendered_pred)
|
||||
|
||||
# show the preds every 5% of the export iterations
|
||||
if visdom_show_preds and (
|
||||
n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
|
||||
):
|
||||
show_predictions(
|
||||
preds_total,
|
||||
sequence_name=batch.sequence_name[0],
|
||||
viz=viz,
|
||||
viz_env=viz_env,
|
||||
)
|
||||
|
||||
print(f"Exporting videos for sequence {sequence_name} ...")
|
||||
generate_prediction_videos(
|
||||
preds_total,
|
||||
sequence_name=batch.sequence_name[0],
|
||||
viz=viz,
|
||||
viz_env=viz_env,
|
||||
fps=fps,
|
||||
video_path=video_path,
|
||||
resize=video_resize,
|
||||
)
|
||||
|
||||
|
||||
def _load_whole_dataset(dataset, idx, num_workers=10):
|
||||
load_all_dataloader = torch.utils.data.DataLoader(
|
||||
torch.utils.data.Subset(dataset, idx),
|
||||
batch_size=len(idx),
|
||||
num_workers=num_workers,
|
||||
shuffle=False,
|
||||
collate_fn=FrameData.collate,
|
||||
)
|
||||
return next(iter(load_all_dataloader))
|
||||
|
||||
|
||||
def images_from_preds(preds):
|
||||
imout = {}
|
||||
for k in (
|
||||
"image_rgb",
|
||||
"images_render",
|
||||
"fg_probability",
|
||||
"masks_render",
|
||||
"depths_render",
|
||||
"depth_map",
|
||||
"_all_source_images",
|
||||
):
|
||||
if k == "_all_source_images" and "image_rgb" in preds:
|
||||
src_ims = preds["image_rgb"][1:].cpu().detach().clone()
|
||||
v = _stack_images(src_ims, None)[None]
|
||||
else:
|
||||
if k not in preds or preds[k] is None:
|
||||
print(f"cant show {k}")
|
||||
continue
|
||||
v = preds[k].cpu().detach().clone()
|
||||
if k.startswith("depth"):
|
||||
mask_resize = Fu.interpolate(
|
||||
preds["masks_render"],
|
||||
size=preds[k].shape[2:],
|
||||
mode="nearest",
|
||||
)
|
||||
v = make_depth_image(preds[k], mask_resize)
|
||||
if v.shape[1] == 1:
|
||||
v = v.repeat(1, 3, 1, 1)
|
||||
imout[k] = v.detach().cpu()
|
||||
|
||||
return imout
|
||||
|
||||
|
||||
def _stack_images(ims, size):
|
||||
ba = ims.shape[0]
|
||||
H = int(np.ceil(np.sqrt(ba)))
|
||||
W = H
|
||||
n_add = H * W - ba
|
||||
if n_add > 0:
|
||||
ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))
|
||||
|
||||
ims = ims.view(H, W, *ims.shape[1:])
|
||||
cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
|
||||
if size is not None:
|
||||
cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
|
||||
return cated.clamp(0.0, 1.0)
|
||||
|
||||
|
||||
def show_predictions(
|
||||
preds,
|
||||
sequence_name,
|
||||
viz,
|
||||
viz_env="visualizer",
|
||||
predicted_keys=(
|
||||
"images_render",
|
||||
"masks_render",
|
||||
"depths_render",
|
||||
"_all_source_images",
|
||||
),
|
||||
n_samples=10,
|
||||
one_image_width=200,
|
||||
):
|
||||
"""Given a list of predictions visualize them into a single image using visdom."""
|
||||
assert isinstance(preds, list)
|
||||
|
||||
pred_all = []
|
||||
# Randomly choose a subset of the rendered images, sort by ordr in the sequence
|
||||
n_samples = min(n_samples, len(preds))
|
||||
pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
|
||||
for predi in pred_idx:
|
||||
# Make the concatentation for the same camera vertically
|
||||
pred_all.append(
|
||||
torch.cat(
|
||||
[
|
||||
torch.nn.functional.interpolate(
|
||||
preds[predi][k].cpu(),
|
||||
scale_factor=one_image_width / preds[predi][k].shape[3],
|
||||
mode="bilinear",
|
||||
).clamp(0.0, 1.0)
|
||||
for k in predicted_keys
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
)
|
||||
# Concatenate the images horizontally
|
||||
pred_all_cat = torch.cat(pred_all, dim=3)[0]
|
||||
viz.image(
|
||||
pred_all_cat,
|
||||
win="show_predictions",
|
||||
env=viz_env,
|
||||
opts={"title": f"pred_{sequence_name}"},
|
||||
)
|
||||
|
||||
|
||||
def generate_prediction_videos(
|
||||
preds,
|
||||
sequence_name,
|
||||
viz=None,
|
||||
viz_env="visualizer",
|
||||
predicted_keys=(
|
||||
"images_render",
|
||||
"masks_render",
|
||||
"depths_render",
|
||||
"_all_source_images",
|
||||
),
|
||||
fps=20,
|
||||
video_path="/tmp/video",
|
||||
resize=None,
|
||||
):
|
||||
"""Given a list of predictions create and visualize rotating videos of the
|
||||
objects using visdom.
|
||||
"""
|
||||
assert isinstance(preds, list)
|
||||
|
||||
# make sure the target video directory exists
|
||||
os.makedirs(os.path.dirname(video_path), exist_ok=True)
|
||||
|
||||
# init a video writer for each predicted key
|
||||
vws = {}
|
||||
for k in predicted_keys:
|
||||
vws[k] = VideoWriter(out_path=f"{video_path}_{sequence_name}_{k}.mp4", fps=fps)
|
||||
|
||||
for rendered_pred in tqdm(preds):
|
||||
for k in predicted_keys:
|
||||
vws[k].write_frame(
|
||||
rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
|
||||
resize=resize,
|
||||
)
|
||||
|
||||
for k in predicted_keys:
|
||||
vws[k].get_video(quiet=True)
|
||||
print(f"Generated {vws[k].out_path}.")
|
||||
if viz is not None:
|
||||
viz.video(
|
||||
videofile=vws[k].out_path,
|
||||
env=viz_env,
|
||||
win=k, # we reuse the same window otherwise visdom dies
|
||||
opts={"title": sequence_name + " " + k},
|
||||
)
|
||||
|
||||
|
||||
def export_scenes(
|
||||
def visualize_reconstruction(
|
||||
exp_dir: str = "",
|
||||
restrict_sequence_name: Optional[str] = None,
|
||||
output_directory: Optional[str] = None,
|
||||
render_size: Tuple[int, int] = (512, 512),
|
||||
video_size: Optional[Tuple[int, int]] = None,
|
||||
split: str = "train", # train | val | test
|
||||
split: str = "train",
|
||||
n_source_views: int = 9,
|
||||
n_eval_cameras: int = 40,
|
||||
visdom_server="http://127.0.0.1",
|
||||
visdom_port=8097,
|
||||
visdom_show_preds: bool = False,
|
||||
visdom_server: str = "http://127.0.0.1",
|
||||
visdom_port: int = 8097,
|
||||
visdom_env: Optional[str] = None,
|
||||
gpu_idx: int = 0,
|
||||
):
|
||||
"""
|
||||
Given an `exp_dir` containing a trained Implicitron model, generates videos consisting
|
||||
of renderes of sequences from the dataset used to train and evaluate the trained
|
||||
Implicitron model.
|
||||
|
||||
Args:
|
||||
exp_dir: Implicitron experiment directory.
|
||||
restrict_sequence_name: If set, defines the list of sequences to visualize.
|
||||
output_directory: If set, defines a custom directory to output visualizations to.
|
||||
render_size: The size (HxW) of the generated renders.
|
||||
video_size: The size (HxW) of the output video.
|
||||
split: The dataset split to use for visualization.
|
||||
Can be "train" / "val" / "test".
|
||||
n_source_views: The number of source views added to each rendered batch. These
|
||||
views are required inputs for models such as NeRFormer / NeRF-WCE.
|
||||
n_eval_cameras: The number of cameras each fly-around trajectory.
|
||||
visdom_show_preds: If `True`, outputs visualizations to visdom.
|
||||
visdom_server: The address of the visdom server.
|
||||
visdom_port: The port of the visdom server.
|
||||
visdom_env: If set, defines a custom name for the visdom environment.
|
||||
"""
|
||||
|
||||
# In case an output directory is specified use it. If no output_directory
|
||||
# is specified create a vis folder inside the experiment directory
|
||||
if output_directory is None:
|
||||
output_directory = os.path.join(exp_dir, "vis")
|
||||
else:
|
||||
output_directory = output_directory
|
||||
if not os.path.exists(output_directory):
|
||||
os.makedirs(output_directory)
|
||||
os.makedirs(output_directory, exist_ok=True)
|
||||
|
||||
# Set the random seeds
|
||||
torch.manual_seed(0)
|
||||
@ -325,7 +74,6 @@ def export_scenes(
|
||||
# Get the config from the experiment_directory,
|
||||
# and overwrite relevant fields
|
||||
config = _get_config_from_experiment_directory(exp_dir)
|
||||
config.gpu_idx = gpu_idx
|
||||
config.exp_dir = exp_dir
|
||||
# important so that the CO3D dataset gets loaded in full
|
||||
dataset_args = (
|
||||
@ -340,10 +88,6 @@ def export_scenes(
|
||||
if restrict_sequence_name is not None:
|
||||
dataset_args.restrict_sequence_name = restrict_sequence_name
|
||||
|
||||
# Set up the CUDA env for the visualization
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
|
||||
|
||||
# Load the previously trained model
|
||||
experiment = Experiment(config)
|
||||
model = experiment.model_factory(force_resume=True)
|
||||
@ -360,17 +104,17 @@ def export_scenes(
|
||||
# iterate over the sequences in the dataset
|
||||
for sequence_name in dataset.sequence_names():
|
||||
with torch.no_grad():
|
||||
render_sequence(
|
||||
dataset,
|
||||
sequence_name,
|
||||
model,
|
||||
video_path="{}/video".format(output_directory),
|
||||
render_flyaround(
|
||||
dataset=dataset,
|
||||
sequence_name=sequence_name,
|
||||
model=model,
|
||||
output_video_path=os.path.join(output_directory, "video"),
|
||||
n_source_views=n_source_views,
|
||||
visdom_show_preds=visdom_show_preds,
|
||||
n_eval_cameras=n_eval_cameras,
|
||||
n_flyaround_poses=n_eval_cameras,
|
||||
visdom_server=visdom_server,
|
||||
visdom_port=visdom_port,
|
||||
viz_env=f"visualizer_{config.visdom_env}"
|
||||
visdom_environment=f"visualizer_{config.visdom_env}"
|
||||
if visdom_env is None
|
||||
else visdom_env,
|
||||
video_resize=video_size,
|
||||
@ -384,11 +128,11 @@ def _get_config_from_experiment_directory(experiment_directory):
|
||||
|
||||
|
||||
def main(argv):
|
||||
# automatically parses arguments of export_scenes
|
||||
cfg = OmegaConf.create(get_default_args(export_scenes))
|
||||
# automatically parses arguments of visualize_reconstruction
|
||||
cfg = OmegaConf.create(get_default_args(visualize_reconstruction))
|
||||
cfg.update(OmegaConf.from_cli())
|
||||
with torch.no_grad():
|
||||
export_scenes(**cfg)
|
||||
visualize_reconstruction(**cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
6
pytorch3d/implicitron/models/visualization/__init__.py
Normal file
6
pytorch3d/implicitron/models/visualization/__init__.py
Normal file
@ -0,0 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
363
pytorch3d/implicitron/models/visualization/render_flyaround.py
Normal file
363
pytorch3d/implicitron/models/visualization/render_flyaround.py
Normal file
@ -0,0 +1,363 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as Fu
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
||||
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
||||
from pytorch3d.implicitron.models.base_model import EvaluationMode
|
||||
from pytorch3d.implicitron.tools.eval_video_trajectory import (
|
||||
generate_eval_video_cameras,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.video_writer import VideoWriter
|
||||
from pytorch3d.implicitron.tools.vis_utils import (
|
||||
get_visdom_connection,
|
||||
make_depth_image,
|
||||
)
|
||||
from tqdm import tqdm
|
||||
from visdom import Visdom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def render_flyaround(
|
||||
dataset: DatasetBase,
|
||||
sequence_name: str,
|
||||
model: torch.nn.Module,
|
||||
output_video_path: str,
|
||||
n_flyaround_poses: int = 40,
|
||||
fps: int = 20,
|
||||
trajectory_type: str = "circular_lsq_fit",
|
||||
max_angle: float = 2 * math.pi,
|
||||
trajectory_scale: float = 1.1,
|
||||
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
|
||||
up: Tuple[float, float, float] = (0.0, -1.0, 0.0),
|
||||
traj_offset: float = 0.0,
|
||||
n_source_views: int = 9,
|
||||
visdom_show_preds: bool = False,
|
||||
visdom_environment: str = "render_flyaround",
|
||||
visdom_server: str = "http://127.0.0.1",
|
||||
visdom_port: int = 8097,
|
||||
num_workers: int = 10,
|
||||
device: Union[str, torch.device] = "cuda",
|
||||
seed: Optional[int] = None,
|
||||
video_resize: Optional[Tuple[int, int]] = None,
|
||||
output_video_frames_dir: Optional[str] = None,
|
||||
visualize_preds_keys: Sequence[str] = (
|
||||
"images_render",
|
||||
"masks_render",
|
||||
"depths_render",
|
||||
"_all_source_images",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Uses `model` to generate a video consisting of renders of a scene imaged from
|
||||
a camera flying around the scene. The scene is specified with the `dataset` object and
|
||||
`sequence_name` which denotes the name of the scene whose frames are in `dataset`.
|
||||
|
||||
Args:
|
||||
dataset: The dataset object containing frames from a sequence in `sequence_name`.
|
||||
sequence_name: Name of a sequence from `dataset`.
|
||||
model: The model whose predictions are going to be visualized.
|
||||
output_video_path: The path to the video output by this script.
|
||||
n_flyaround_poses: The number of camera poses of the flyaround trajectory.
|
||||
fps: Framerate of the output video.
|
||||
trajectory_type: The type of the camera trajectory. Can be one of:
|
||||
circular_lsq_fit: Camera centers follow a trajectory obtained
|
||||
by fitting a 3D circle to train_cameras centers.
|
||||
All cameras are looking towards scene_center.
|
||||
figure_eight: Figure-of-8 trajectory around the center of the
|
||||
central camera of the training dataset.
|
||||
trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
|
||||
of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
|
||||
figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
|
||||
of a figure-eight knot
|
||||
(https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
|
||||
trajectory_type: The type of the camera trajectory. Can be one of:
|
||||
circular_lsq_fit: Camera centers follow a trajectory obtained
|
||||
by fitting a 3D circle to train_cameras centers.
|
||||
All cameras are looking towards scene_center.
|
||||
figure_eight: Figure-of-8 trajectory around the center of the
|
||||
central camera of the training dataset.
|
||||
trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
|
||||
of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
|
||||
figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
|
||||
of a figure-eight knot
|
||||
(https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
|
||||
max_angle: Defines the total length of the generated camera trajectory.
|
||||
All possible trajectories (set with the `trajectory_type` argument) are
|
||||
periodic with the period of `time==2pi`.
|
||||
E.g. setting `trajectory_type=circular_lsq_fit` and `time=4pi` will generate
|
||||
a trajectory of camera poses rotating the total of 720 deg around the object.
|
||||
trajectory_scale: The extent of the trajectory.
|
||||
scene_center: The center of the scene in world coordinates which all
|
||||
the cameras from the generated trajectory look at.
|
||||
up: The "up" vector of the scene (=the normal of the scene floor).
|
||||
Active for the `trajectory_type="circular"`.
|
||||
traj_offset: 3D offset vector added to each point of the trajectory.
|
||||
n_source_views: The number of source views sampled from the known views of the
|
||||
training sequence added to each evaluation batch.
|
||||
visdom_show_preds: If `True`, exports the visualizations to visdom.
|
||||
visdom_environment: The name of the visdom environment.
|
||||
visdom_server: The address of the visdom server.
|
||||
visdom_port: The visdom port.
|
||||
num_workers: The number of workers used to load the training data.
|
||||
seed: The random seed used for reproducible sampling of the source views.
|
||||
video_resize: Optionally, defines the size of the output video.
|
||||
output_video_frames_dir: If specified, the frames of the output video are going
|
||||
to be permanently stored in this directory.
|
||||
visualize_preds_keys: The names of the model predictions to visualize.
|
||||
"""
|
||||
|
||||
if seed is None:
|
||||
seed = hash(sequence_name)
|
||||
|
||||
if visdom_show_preds:
|
||||
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
|
||||
else:
|
||||
viz = None
|
||||
|
||||
logger.info(f"Loading all data of sequence '{sequence_name}'.")
|
||||
seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
|
||||
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
|
||||
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
|
||||
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
|
||||
logger.info(f"Sequence set = {sequence_set_name}.")
|
||||
train_cameras = train_data.camera
|
||||
time = torch.linspace(0, max_angle, n_flyaround_poses + 1)[:n_flyaround_poses]
|
||||
test_cameras = generate_eval_video_cameras(
|
||||
train_cameras,
|
||||
time=time,
|
||||
n_eval_cams=n_flyaround_poses,
|
||||
trajectory_type=trajectory_type,
|
||||
trajectory_scale=trajectory_scale,
|
||||
scene_center=scene_center,
|
||||
up=up,
|
||||
focal_length=None,
|
||||
principal_point=torch.zeros(n_flyaround_poses, 2),
|
||||
traj_offset_canonical=(0.0, 0.0, traj_offset),
|
||||
)
|
||||
|
||||
# sample the source views reproducibly
|
||||
with torch.random.fork_rng():
|
||||
torch.manual_seed(seed)
|
||||
source_views_i = torch.randperm(len(seq_idx))[:n_source_views]
|
||||
|
||||
# add the first dummy view that will get replaced with the target camera
|
||||
source_views_i = Fu.pad(source_views_i, [1, 0])
|
||||
source_views = [seq_idx[i] for i in source_views_i.tolist()]
|
||||
batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
|
||||
assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)
|
||||
|
||||
preds_total = []
|
||||
for n in tqdm(range(n_flyaround_poses), total=n_flyaround_poses):
|
||||
# set the first batch camera to the target camera
|
||||
for k in ("R", "T", "focal_length", "principal_point"):
|
||||
getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)
|
||||
|
||||
# Move to cuda
|
||||
net_input = batch.to(device)
|
||||
with torch.no_grad():
|
||||
preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})
|
||||
|
||||
# make sure we dont overwrite something
|
||||
assert all(k not in preds for k in net_input.keys())
|
||||
preds.update(net_input) # merge everything into one big dict
|
||||
|
||||
# Render the predictions to images
|
||||
rendered_pred = _images_from_preds(preds)
|
||||
preds_total.append(rendered_pred)
|
||||
|
||||
# show the preds every 5% of the export iterations
|
||||
if visdom_show_preds and (
|
||||
n % max(n_flyaround_poses // 20, 1) == 0 or n == n_flyaround_poses - 1
|
||||
):
|
||||
assert viz is not None
|
||||
_show_predictions(
|
||||
preds_total,
|
||||
sequence_name=batch.sequence_name[0],
|
||||
viz=viz,
|
||||
viz_env=visdom_environment,
|
||||
)
|
||||
|
||||
logger.info(f"Exporting videos for sequence {sequence_name} ...")
|
||||
_generate_prediction_videos(
|
||||
preds_total,
|
||||
sequence_name=batch.sequence_name[0],
|
||||
viz=viz,
|
||||
viz_env=visdom_environment,
|
||||
fps=fps,
|
||||
video_path=output_video_path,
|
||||
resize=video_resize,
|
||||
video_frames_dir=output_video_frames_dir,
|
||||
)
|
||||
|
||||
|
||||
def _load_whole_dataset(
|
||||
dataset: torch.utils.data.Dataset, idx: Sequence[int], num_workers: int = 10
|
||||
):
|
||||
load_all_dataloader = torch.utils.data.DataLoader(
|
||||
torch.utils.data.Subset(dataset, idx),
|
||||
batch_size=len(idx),
|
||||
num_workers=num_workers,
|
||||
shuffle=False,
|
||||
collate_fn=FrameData.collate,
|
||||
)
|
||||
return next(iter(load_all_dataloader))
|
||||
|
||||
|
||||
def _images_from_preds(preds: Dict[str, Any]):
|
||||
imout = {}
|
||||
for k in (
|
||||
"image_rgb",
|
||||
"images_render",
|
||||
"fg_probability",
|
||||
"masks_render",
|
||||
"depths_render",
|
||||
"depth_map",
|
||||
"_all_source_images",
|
||||
):
|
||||
if k == "_all_source_images" and "image_rgb" in preds:
|
||||
src_ims = preds["image_rgb"][1:].cpu().detach().clone()
|
||||
v = _stack_images(src_ims, None)[None]
|
||||
else:
|
||||
if k not in preds or preds[k] is None:
|
||||
print(f"cant show {k}")
|
||||
continue
|
||||
v = preds[k].cpu().detach().clone()
|
||||
if k.startswith("depth"):
|
||||
mask_resize = Fu.interpolate(
|
||||
preds["masks_render"],
|
||||
size=preds[k].shape[2:],
|
||||
mode="nearest",
|
||||
)
|
||||
v = make_depth_image(preds[k], mask_resize)
|
||||
if v.shape[1] == 1:
|
||||
v = v.repeat(1, 3, 1, 1)
|
||||
imout[k] = v.detach().cpu()
|
||||
|
||||
return imout
|
||||
|
||||
|
||||
def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]):
|
||||
ba = ims.shape[0]
|
||||
H = int(np.ceil(np.sqrt(ba)))
|
||||
W = H
|
||||
n_add = H * W - ba
|
||||
if n_add > 0:
|
||||
ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))
|
||||
|
||||
ims = ims.view(H, W, *ims.shape[1:])
|
||||
cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
|
||||
if size is not None:
|
||||
cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
|
||||
return cated.clamp(0.0, 1.0)
|
||||
|
||||
|
||||
def _show_predictions(
|
||||
preds: List[Dict[str, Any]],
|
||||
sequence_name: str,
|
||||
viz: Visdom,
|
||||
viz_env: str = "visualizer",
|
||||
predicted_keys: Sequence[str] = (
|
||||
"images_render",
|
||||
"masks_render",
|
||||
"depths_render",
|
||||
"_all_source_images",
|
||||
),
|
||||
n_samples=10,
|
||||
one_image_width=200,
|
||||
):
|
||||
"""Given a list of predictions visualize them into a single image using visdom."""
|
||||
assert isinstance(preds, list)
|
||||
|
||||
pred_all = []
|
||||
# Randomly choose a subset of the rendered images, sort by ordr in the sequence
|
||||
n_samples = min(n_samples, len(preds))
|
||||
pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
|
||||
for predi in pred_idx:
|
||||
# Make the concatentation for the same camera vertically
|
||||
pred_all.append(
|
||||
torch.cat(
|
||||
[
|
||||
torch.nn.functional.interpolate(
|
||||
preds[predi][k].cpu(),
|
||||
scale_factor=one_image_width / preds[predi][k].shape[3],
|
||||
mode="bilinear",
|
||||
).clamp(0.0, 1.0)
|
||||
for k in predicted_keys
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
)
|
||||
# Concatenate the images horizontally
|
||||
pred_all_cat = torch.cat(pred_all, dim=3)[0]
|
||||
viz.image(
|
||||
pred_all_cat,
|
||||
win="show_predictions",
|
||||
env=viz_env,
|
||||
opts={"title": f"pred_{sequence_name}"},
|
||||
)
|
||||
|
||||
|
||||
def _generate_prediction_videos(
|
||||
preds: List[Dict[str, Any]],
|
||||
sequence_name: str,
|
||||
viz: Optional[Visdom] = None,
|
||||
viz_env: str = "visualizer",
|
||||
predicted_keys: Sequence[str] = (
|
||||
"images_render",
|
||||
"masks_render",
|
||||
"depths_render",
|
||||
"_all_source_images",
|
||||
),
|
||||
fps: int = 20,
|
||||
video_path: str = "/tmp/video",
|
||||
video_frames_dir: Optional[str] = None,
|
||||
resize: Optional[Tuple[int, int]] = None,
|
||||
):
|
||||
"""Given a list of predictions create and visualize rotating videos of the
|
||||
objects using visdom.
|
||||
"""
|
||||
|
||||
# make sure the target video directory exists
|
||||
os.makedirs(os.path.dirname(video_path), exist_ok=True)
|
||||
|
||||
# init a video writer for each predicted key
|
||||
vws = {}
|
||||
for k in predicted_keys:
|
||||
vws[k] = VideoWriter(
|
||||
fps=fps,
|
||||
out_path=f"{video_path}_{sequence_name}_{k}.mp4",
|
||||
cache_dir=os.path.join(video_frames_dir, f"{sequence_name}_{k}"),
|
||||
)
|
||||
|
||||
for rendered_pred in tqdm(preds):
|
||||
for k in predicted_keys:
|
||||
vws[k].write_frame(
|
||||
rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
|
||||
resize=resize,
|
||||
)
|
||||
|
||||
for k in predicted_keys:
|
||||
vws[k].get_video(quiet=True)
|
||||
logger.info(f"Generated {vws[k].out_path}.")
|
||||
if viz is not None:
|
||||
viz.video(
|
||||
videofile=vws[k].out_path,
|
||||
env=viz_env,
|
||||
win=k, # we reuse the same window otherwise visdom dies
|
||||
opts={"title": sequence_name + " " + k},
|
||||
)
|
@ -37,7 +37,7 @@ def generate_eval_video_cameras(
|
||||
Generate a camera trajectory rendering a scene from multiple viewpoints.
|
||||
|
||||
Args:
|
||||
train_dataset: The training dataset object.
|
||||
train_cameras: The set of cameras from the training dataset object.
|
||||
n_eval_cams: Number of cameras in the trajectory.
|
||||
trajectory_type: The type of the camera trajectory. Can be one of:
|
||||
circular_lsq_fit: Camera centers follow a trajectory obtained
|
||||
@ -51,16 +51,30 @@ def generate_eval_video_cameras(
|
||||
of a figure-eight knot
|
||||
(https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
|
||||
trajectory_scale: The extent of the trajectory.
|
||||
up: The "up" vector of the scene (=the normal of the scene floor).
|
||||
Active for the `trajectory_type="circular"`.
|
||||
scene_center: The center of the scene in world coordinates which all
|
||||
the cameras from the generated trajectory look at.
|
||||
up: The "circular_lsq_fit" vector of the scene (=the normal of the scene floor).
|
||||
Active for the `trajectory_type="circular"`.
|
||||
focal_length: The focal length of the output cameras. If `None`, an average
|
||||
focal length of the train_cameras is used.
|
||||
principal_point: The principal point of the output cameras. If `None`, an average
|
||||
principal point of all train_cameras is used.
|
||||
time: Defines the total length of the generated camera trajectory. All possible
|
||||
trajectories (set with the `trajectory_type` argument) are periodic with
|
||||
the period of `time=2pi`.
|
||||
E.g. setting `trajectory_type=circular_lsq_fit` and `time=4pi`, will generate
|
||||
a trajectory of camera poses rotating the total of 720 deg around the object.
|
||||
infer_up_as_plane_normal: Infer the camera `up` vector automatically as the normal
|
||||
of the plane fit to the optical centers of `train_cameras`.
|
||||
traj_offset: 3D offset vector added to each point of the trajectory.
|
||||
traj_offset_canonical: 3D offset vector expressed in the local coordinates of
|
||||
the estimated trajectory which is added to each point of the trajectory.
|
||||
remove_outliers_rate: the number between 0 and 1; if > 0,
|
||||
some outlier train_cameras will be removed from trajectory estimation;
|
||||
the filtering is based on camera center coordinates; top and
|
||||
bottom `remove_outliers_rate` cameras on each dimension are removed.
|
||||
Returns:
|
||||
Dictionary of camera instances which can be used as the test dataset
|
||||
Batch of camera instances which can be used as the test dataset
|
||||
"""
|
||||
if remove_outliers_rate > 0.0:
|
||||
train_cameras = _remove_outlier_cameras(train_cameras, remove_outliers_rate)
|
||||
|
@ -68,7 +68,7 @@ def get_skateboard_data(
|
||||
if not os.environ.get("FB_TEST", False):
|
||||
if os.getenv("FAIR_ENV_CLUSTER", "") == "":
|
||||
raise unittest.SkipTest("Unknown environment. Data not available.")
|
||||
yield "/checkpoint/dnovotny/datasets/co3d/download_aws_22_02_18", PathManager()
|
||||
yield "/datasets01/co3d/081922", PathManager()
|
||||
|
||||
elif avoid_manifold or os.environ.get("INSIDE_RE_WORKER", False):
|
||||
from libfb.py.parutil import get_file_path
|
||||
|
154
tests/implicitron/test_model_visualize.py
Normal file
154
tests/implicitron/test_model_visualize.py
Normal file
@ -0,0 +1,154 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import math
|
||||
import os
|
||||
import unittest
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
|
||||
|
||||
from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from tests.common_testing import interactive_testing_requested
|
||||
from visdom import Visdom
|
||||
|
||||
from .common_resources import get_skateboard_data
|
||||
|
||||
|
||||
class TestModelVisualize(unittest.TestCase):
|
||||
def test_flyaround_one_sequence(
|
||||
self,
|
||||
image_size: int = 256,
|
||||
):
|
||||
if not interactive_testing_requested():
|
||||
return
|
||||
category = "skateboard"
|
||||
stack = contextlib.ExitStack()
|
||||
dataset_root, path_manager = stack.enter_context(get_skateboard_data())
|
||||
self.addCleanup(stack.close)
|
||||
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
|
||||
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
|
||||
subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
|
||||
expand_args_fields(JsonIndexDataset)
|
||||
train_dataset = JsonIndexDataset(
|
||||
frame_annotations_file=frame_file,
|
||||
sequence_annotations_file=sequence_file,
|
||||
subset_lists_file=subset_lists_file,
|
||||
dataset_root=dataset_root,
|
||||
image_height=image_size,
|
||||
image_width=image_size,
|
||||
box_crop=True,
|
||||
load_point_clouds=True,
|
||||
path_manager=path_manager,
|
||||
subsets=[
|
||||
"train_known",
|
||||
],
|
||||
)
|
||||
|
||||
# select few sequences to visualize
|
||||
sequence_names = list(train_dataset.seq_annots.keys())
|
||||
|
||||
# select the first sequence name
|
||||
show_sequence_name = sequence_names[0]
|
||||
|
||||
output_dir = os.path.split(os.path.abspath(__file__))[0]
|
||||
|
||||
visdom_show_preds = Visdom().check_connection()
|
||||
|
||||
for load_dataset_pointcloud in [True, False]:
|
||||
|
||||
model = _PointcloudRenderingModel(
|
||||
train_dataset,
|
||||
show_sequence_name,
|
||||
device="cuda:0",
|
||||
load_dataset_pointcloud=load_dataset_pointcloud,
|
||||
)
|
||||
|
||||
video_path = os.path.join(
|
||||
output_dir,
|
||||
f"load_pcl_{load_dataset_pointcloud}",
|
||||
)
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
render_flyaround(
|
||||
train_dataset,
|
||||
show_sequence_name,
|
||||
model,
|
||||
video_path,
|
||||
n_flyaround_poses=40,
|
||||
fps=20,
|
||||
max_angle=2 * math.pi,
|
||||
trajectory_type="circular_lsq_fit",
|
||||
trajectory_scale=1.1,
|
||||
scene_center=(0.0, 0.0, 0.0),
|
||||
up=(0.0, 1.0, 0.0),
|
||||
traj_offset=1.0,
|
||||
n_source_views=1,
|
||||
visdom_show_preds=visdom_show_preds,
|
||||
visdom_environment="test_model_visalize",
|
||||
visdom_server="http://127.0.0.1",
|
||||
visdom_port=8097,
|
||||
num_workers=10,
|
||||
seed=None,
|
||||
video_resize=None,
|
||||
visualize_preds_keys=[
|
||||
"images_render",
|
||||
"depths_render",
|
||||
"masks_render",
|
||||
"_all_source_images",
|
||||
],
|
||||
output_video_frames_dir=video_path,
|
||||
)
|
||||
|
||||
|
||||
class _PointcloudRenderingModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
train_dataset: JsonIndexDataset,
|
||||
sequence_name: str,
|
||||
render_size: Tuple[int, int] = (400, 400),
|
||||
device=None,
|
||||
load_dataset_pointcloud: bool = False,
|
||||
max_frames: int = 30,
|
||||
num_workers: int = 10,
|
||||
):
|
||||
super().__init__()
|
||||
self._render_size = render_size
|
||||
point_cloud, _ = get_implicitron_sequence_pointcloud(
|
||||
train_dataset,
|
||||
sequence_name=sequence_name,
|
||||
mask_points=True,
|
||||
max_frames=max_frames,
|
||||
num_workers=num_workers,
|
||||
load_dataset_point_cloud=load_dataset_pointcloud,
|
||||
)
|
||||
self._point_cloud = point_cloud.to(device)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
camera: CamerasBase,
|
||||
**kwargs,
|
||||
):
|
||||
image_render, mask_render, depth_render = render_point_cloud_pytorch3d(
|
||||
camera[0],
|
||||
self._point_cloud,
|
||||
render_size=self._render_size,
|
||||
point_radius=1e-2,
|
||||
topk=10,
|
||||
bg_color=0.0,
|
||||
)
|
||||
return {
|
||||
"images_render": image_render.clamp(0.0, 1.0),
|
||||
"masks_render": mask_render,
|
||||
"depths_render": depth_render,
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user