mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Rename and move render_flyaround into core implicitron
Summary: Move the flyaround rendering function into core implicitron. This unblocks an example in the facebookresearch/co3d repo. Reviewed By: bottler Differential Revision: D39257801 fbshipit-source-id: 6841a88a43d4aa364dd86ba83ca2d4c3cf0435a4
This commit is contained in:
parent
438c194ec6
commit
c79c954dea
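
For orientation, a minimal usage sketch of the relocated function (not part of this commit's diff): after the move the renderer is importable from core implicitron, which is what lets the facebookresearch/co3d example call it. The dataset, model and output path below are placeholders.

import torch

from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround

my_dataset = ...  # placeholder: any Implicitron DatasetBase containing the sequence "my_sequence"
my_model = ...    # placeholder: a trained Implicitron model (torch.nn.Module)

with torch.no_grad():
    render_flyaround(
        dataset=my_dataset,
        sequence_name="my_sequence",
        model=my_model,
        output_video_path="/tmp/flyaround/video",
        n_flyaround_poses=40,
        visdom_show_preds=False,
    )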
@@ -12,311 +12,60 @@
     n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
 """
 
-import math
 import os
-import random
 import sys
 from typing import Optional, Tuple
 
 import numpy as np
 import torch
-import torch.nn.functional as Fu
 from omegaconf import OmegaConf
-from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
-from pytorch3d.implicitron.dataset.utils import is_train_frame
-from pytorch3d.implicitron.models.base_model import EvaluationMode
+from pytorch3d.implicitron.models.visualization import render_flyaround
 from pytorch3d.implicitron.tools.configurable import get_default_args
-from pytorch3d.implicitron.tools.eval_video_trajectory import (
-    generate_eval_video_cameras,
-)
-from pytorch3d.implicitron.tools.video_writer import VideoWriter
-from pytorch3d.implicitron.tools.vis_utils import (
-    get_visdom_connection,
-    make_depth_image,
-)
-from tqdm import tqdm
 
 from .experiment import Experiment
 
 
-def render_sequence(
-    dataset: DatasetBase,
-    sequence_name: str,
-    model: torch.nn.Module,
-    video_path,
-    n_eval_cameras=40,
-    fps=20,
-    max_angle=2 * math.pi,
-    trajectory_type="circular_lsq_fit",
-    trajectory_scale=1.1,
-    scene_center=(0.0, 0.0, 0.0),
-    up=(0.0, -1.0, 0.0),
-    traj_offset=0.0,
-    n_source_views=9,
-    viz_env="debug",
-    visdom_show_preds=False,
-    visdom_server="http://127.0.0.1",
-    visdom_port=8097,
-    num_workers=10,
-    seed=None,
-    video_resize=None,
-):
-    ...
(The removed body of `render_sequence`, together with the helpers `_load_whole_dataset`,
`images_from_preds`, `_stack_images`, `show_predictions` and `generate_prediction_videos`,
moves essentially unchanged into the new module
`pytorch3d/implicitron/models/visualization/render_flyaround.py` shown below.)
-
-
-def export_scenes(
+def visualize_reconstruction(
     exp_dir: str = "",
     restrict_sequence_name: Optional[str] = None,
     output_directory: Optional[str] = None,
     render_size: Tuple[int, int] = (512, 512),
     video_size: Optional[Tuple[int, int]] = None,
-    split: str = "train",  # train | val | test
+    split: str = "train",
     n_source_views: int = 9,
     n_eval_cameras: int = 40,
-    visdom_server="http://127.0.0.1",
-    visdom_port=8097,
     visdom_show_preds: bool = False,
+    visdom_server: str = "http://127.0.0.1",
+    visdom_port: int = 8097,
     visdom_env: Optional[str] = None,
-    gpu_idx: int = 0,
 ):
+    """
+    Given an `exp_dir` containing a trained Implicitron model, generates videos consisting
+    of renders of sequences from the dataset used to train and evaluate the trained
+    Implicitron model.
+
+    Args:
+        exp_dir: Implicitron experiment directory.
+        restrict_sequence_name: If set, defines the list of sequences to visualize.
+        output_directory: If set, defines a custom directory to output visualizations to.
+        render_size: The size (HxW) of the generated renders.
+        video_size: The size (HxW) of the output video.
+        split: The dataset split to use for visualization.
+            Can be "train" / "val" / "test".
+        n_source_views: The number of source views added to each rendered batch. These
+            views are required inputs for models such as NeRFormer / NeRF-WCE.
+        n_eval_cameras: The number of cameras in each fly-around trajectory.
+        visdom_show_preds: If `True`, outputs visualizations to visdom.
+        visdom_server: The address of the visdom server.
+        visdom_port: The port of the visdom server.
+        visdom_env: If set, defines a custom name for the visdom environment.
+    """
     # In case an output directory is specified use it. If no output_directory
     # is specified create a vis folder inside the experiment directory
     if output_directory is None:
         output_directory = os.path.join(exp_dir, "vis")
-    else:
-        output_directory = output_directory
-    if not os.path.exists(output_directory):
-        os.makedirs(output_directory)
+    os.makedirs(output_directory, exist_ok=True)
 
     # Set the random seeds
     torch.manual_seed(0)
@@ -325,7 +74,6 @@ def export_scenes(
     # Get the config from the experiment_directory,
     # and overwrite relevant fields
     config = _get_config_from_experiment_directory(exp_dir)
-    config.gpu_idx = gpu_idx
     config.exp_dir = exp_dir
     # important so that the CO3D dataset gets loaded in full
     dataset_args = (
@@ -340,10 +88,6 @@ def export_scenes(
     if restrict_sequence_name is not None:
         dataset_args.restrict_sequence_name = restrict_sequence_name
 
-    # Set up the CUDA env for the visualization
-    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-    os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
-
     # Load the previously trained model
     experiment = Experiment(config)
     model = experiment.model_factory(force_resume=True)
@@ -360,17 +104,17 @@ def export_scenes(
     # iterate over the sequences in the dataset
     for sequence_name in dataset.sequence_names():
         with torch.no_grad():
-            render_sequence(
-                dataset,
-                sequence_name,
-                model,
-                video_path="{}/video".format(output_directory),
+            render_flyaround(
+                dataset=dataset,
+                sequence_name=sequence_name,
+                model=model,
+                output_video_path=os.path.join(output_directory, "video"),
                 n_source_views=n_source_views,
                 visdom_show_preds=visdom_show_preds,
-                n_eval_cameras=n_eval_cameras,
+                n_flyaround_poses=n_eval_cameras,
                 visdom_server=visdom_server,
                 visdom_port=visdom_port,
-                viz_env=f"visualizer_{config.visdom_env}"
+                visdom_environment=f"visualizer_{config.visdom_env}"
                 if visdom_env is None
                 else visdom_env,
                 video_resize=video_size,
@@ -384,11 +128,11 @@ def _get_config_from_experiment_directory(experiment_directory):
 
 
 def main(argv):
-    # automatically parses arguments of export_scenes
-    cfg = OmegaConf.create(get_default_args(export_scenes))
+    # automatically parses arguments of visualize_reconstruction
+    cfg = OmegaConf.create(get_default_args(visualize_reconstruction))
     cfg.update(OmegaConf.from_cli())
     with torch.no_grad():
-        export_scenes(**cfg)
+        visualize_reconstruction(**cfg)
 
 
 if __name__ == "__main__":
pytorch3d/implicitron/models/visualization/__init__.py  (new file, 6 lines)
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
pytorch3d/implicitron/models/visualization/render_flyaround.py  (new file, 363 lines)
@@ -0,0 +1,363 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging
import math
import os
import random
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as Fu
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import EvaluationMode
from pytorch3d.implicitron.tools.eval_video_trajectory import (
    generate_eval_video_cameras,
)
from pytorch3d.implicitron.tools.video_writer import VideoWriter
from pytorch3d.implicitron.tools.vis_utils import (
    get_visdom_connection,
    make_depth_image,
)
from tqdm import tqdm
from visdom import Visdom

logger = logging.getLogger(__name__)


def render_flyaround(
    dataset: DatasetBase,
    sequence_name: str,
    model: torch.nn.Module,
    output_video_path: str,
    n_flyaround_poses: int = 40,
    fps: int = 20,
    trajectory_type: str = "circular_lsq_fit",
    max_angle: float = 2 * math.pi,
    trajectory_scale: float = 1.1,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    up: Tuple[float, float, float] = (0.0, -1.0, 0.0),
    traj_offset: float = 0.0,
    n_source_views: int = 9,
    visdom_show_preds: bool = False,
    visdom_environment: str = "render_flyaround",
    visdom_server: str = "http://127.0.0.1",
    visdom_port: int = 8097,
    num_workers: int = 10,
    device: Union[str, torch.device] = "cuda",
    seed: Optional[int] = None,
    video_resize: Optional[Tuple[int, int]] = None,
    output_video_frames_dir: Optional[str] = None,
    visualize_preds_keys: Sequence[str] = (
        "images_render",
        "masks_render",
        "depths_render",
        "_all_source_images",
    ),
):
    """
    Uses `model` to generate a video consisting of renders of a scene imaged from
    a camera flying around the scene. The scene is specified with the `dataset` object and
    `sequence_name` which denotes the name of the scene whose frames are in `dataset`.

    Args:
        dataset: The dataset object containing frames from a sequence in `sequence_name`.
        sequence_name: Name of a sequence from `dataset`.
        model: The model whose predictions are going to be visualized.
        output_video_path: The path to the video output by this script.
        n_flyaround_poses: The number of camera poses of the flyaround trajectory.
        fps: Framerate of the output video.
        trajectory_type: The type of the camera trajectory. Can be one of:
            circular_lsq_fit: Camera centers follow a trajectory obtained
                by fitting a 3D circle to train_cameras centers.
                All cameras are looking towards scene_center.
            figure_eight: Figure-of-8 trajectory around the center of the
                central camera of the training dataset.
            trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
                of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
            figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
                of a figure-eight knot
                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
        max_angle: Defines the total length of the generated camera trajectory.
            All possible trajectories (set with the `trajectory_type` argument) are
            periodic with the period of `time==2pi`.
            E.g. setting `trajectory_type=circular_lsq_fit` and `time=4pi` will generate
            a trajectory of camera poses rotating the total of 720 deg around the object.
        trajectory_scale: The extent of the trajectory.
        scene_center: The center of the scene in world coordinates which all
            the cameras from the generated trajectory look at.
        up: The "up" vector of the scene (=the normal of the scene floor).
            Active for the `trajectory_type="circular"`.
        traj_offset: 3D offset vector added to each point of the trajectory.
        n_source_views: The number of source views sampled from the known views of the
            training sequence added to each evaluation batch.
        visdom_show_preds: If `True`, exports the visualizations to visdom.
        visdom_environment: The name of the visdom environment.
        visdom_server: The address of the visdom server.
        visdom_port: The visdom port.
        num_workers: The number of workers used to load the training data.
        seed: The random seed used for reproducible sampling of the source views.
        video_resize: Optionally, defines the size of the output video.
        output_video_frames_dir: If specified, the frames of the output video are going
            to be permanently stored in this directory.
        visualize_preds_keys: The names of the model predictions to visualize.
    """

    if seed is None:
        seed = hash(sequence_name)

    if visdom_show_preds:
        viz = get_visdom_connection(server=visdom_server, port=visdom_port)
    else:
        viz = None

    logger.info(f"Loading all data of sequence '{sequence_name}'.")
    seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
    train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
    assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
    sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
    logger.info(f"Sequence set = {sequence_set_name}.")
    train_cameras = train_data.camera
    time = torch.linspace(0, max_angle, n_flyaround_poses + 1)[:n_flyaround_poses]
    test_cameras = generate_eval_video_cameras(
        train_cameras,
        time=time,
        n_eval_cams=n_flyaround_poses,
        trajectory_type=trajectory_type,
        trajectory_scale=trajectory_scale,
        scene_center=scene_center,
        up=up,
        focal_length=None,
        principal_point=torch.zeros(n_flyaround_poses, 2),
        traj_offset_canonical=(0.0, 0.0, traj_offset),
    )

    # sample the source views reproducibly
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        source_views_i = torch.randperm(len(seq_idx))[:n_source_views]

    # add the first dummy view that will get replaced with the target camera
    source_views_i = Fu.pad(source_views_i, [1, 0])
    source_views = [seq_idx[i] for i in source_views_i.tolist()]
    batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
    assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)

    preds_total = []
    for n in tqdm(range(n_flyaround_poses), total=n_flyaround_poses):
        # set the first batch camera to the target camera
        for k in ("R", "T", "focal_length", "principal_point"):
            getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)

        # Move to cuda
        net_input = batch.to(device)
        with torch.no_grad():
            preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})

        # make sure we dont overwrite something
        assert all(k not in preds for k in net_input.keys())
        preds.update(net_input)  # merge everything into one big dict

        # Render the predictions to images
        rendered_pred = _images_from_preds(preds)
        preds_total.append(rendered_pred)

        # show the preds every 5% of the export iterations
        if visdom_show_preds and (
            n % max(n_flyaround_poses // 20, 1) == 0 or n == n_flyaround_poses - 1
        ):
            assert viz is not None
            _show_predictions(
                preds_total,
                sequence_name=batch.sequence_name[0],
                viz=viz,
                viz_env=visdom_environment,
            )

    logger.info(f"Exporting videos for sequence {sequence_name} ...")
    _generate_prediction_videos(
        preds_total,
        sequence_name=batch.sequence_name[0],
        viz=viz,
        viz_env=visdom_environment,
        fps=fps,
        video_path=output_video_path,
        resize=video_resize,
        video_frames_dir=output_video_frames_dir,
    )


def _load_whole_dataset(
    dataset: torch.utils.data.Dataset, idx: Sequence[int], num_workers: int = 10
):
    load_all_dataloader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(dataset, idx),
        batch_size=len(idx),
        num_workers=num_workers,
        shuffle=False,
        collate_fn=FrameData.collate,
    )
    return next(iter(load_all_dataloader))


def _images_from_preds(preds: Dict[str, Any]):
    imout = {}
    for k in (
        "image_rgb",
        "images_render",
        "fg_probability",
        "masks_render",
        "depths_render",
        "depth_map",
        "_all_source_images",
    ):
        if k == "_all_source_images" and "image_rgb" in preds:
            src_ims = preds["image_rgb"][1:].cpu().detach().clone()
            v = _stack_images(src_ims, None)[None]
        else:
            if k not in preds or preds[k] is None:
                print(f"cant show {k}")
                continue
            v = preds[k].cpu().detach().clone()
            if k.startswith("depth"):
                mask_resize = Fu.interpolate(
                    preds["masks_render"],
                    size=preds[k].shape[2:],
                    mode="nearest",
                )
                v = make_depth_image(preds[k], mask_resize)
        if v.shape[1] == 1:
            v = v.repeat(1, 3, 1, 1)
        imout[k] = v.detach().cpu()

    return imout


def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]):
    ba = ims.shape[0]
    H = int(np.ceil(np.sqrt(ba)))
    W = H
    n_add = H * W - ba
    if n_add > 0:
        ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))

    ims = ims.view(H, W, *ims.shape[1:])
    cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
    if size is not None:
        cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
    return cated.clamp(0.0, 1.0)


def _show_predictions(
    preds: List[Dict[str, Any]],
    sequence_name: str,
    viz: Visdom,
    viz_env: str = "visualizer",
    predicted_keys: Sequence[str] = (
        "images_render",
        "masks_render",
        "depths_render",
        "_all_source_images",
    ),
    n_samples=10,
    one_image_width=200,
):
    """Given a list of predictions visualize them into a single image using visdom."""
    assert isinstance(preds, list)

    pred_all = []
    # Randomly choose a subset of the rendered images, sort by order in the sequence
    n_samples = min(n_samples, len(preds))
    pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
    for predi in pred_idx:
        # Make the concatenation for the same camera vertically
        pred_all.append(
            torch.cat(
                [
                    torch.nn.functional.interpolate(
                        preds[predi][k].cpu(),
                        scale_factor=one_image_width / preds[predi][k].shape[3],
                        mode="bilinear",
                    ).clamp(0.0, 1.0)
                    for k in predicted_keys
                ],
                dim=2,
            )
        )
    # Concatenate the images horizontally
    pred_all_cat = torch.cat(pred_all, dim=3)[0]
    viz.image(
        pred_all_cat,
        win="show_predictions",
        env=viz_env,
        opts={"title": f"pred_{sequence_name}"},
    )


def _generate_prediction_videos(
    preds: List[Dict[str, Any]],
    sequence_name: str,
    viz: Optional[Visdom] = None,
    viz_env: str = "visualizer",
    predicted_keys: Sequence[str] = (
        "images_render",
        "masks_render",
        "depths_render",
        "_all_source_images",
    ),
    fps: int = 20,
    video_path: str = "/tmp/video",
    video_frames_dir: Optional[str] = None,
    resize: Optional[Tuple[int, int]] = None,
):
    """Given a list of predictions create and visualize rotating videos of the
    objects using visdom.
    """

    # make sure the target video directory exists
    os.makedirs(os.path.dirname(video_path), exist_ok=True)

    # init a video writer for each predicted key
    vws = {}
    for k in predicted_keys:
        vws[k] = VideoWriter(
            fps=fps,
            out_path=f"{video_path}_{sequence_name}_{k}.mp4",
            cache_dir=os.path.join(video_frames_dir, f"{sequence_name}_{k}"),
        )

    for rendered_pred in tqdm(preds):
        for k in predicted_keys:
            vws[k].write_frame(
                rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
                resize=resize,
            )

    for k in predicted_keys:
        vws[k].get_video(quiet=True)
        logger.info(f"Generated {vws[k].out_path}.")
        if viz is not None:
            viz.video(
                videofile=vws[k].out_path,
                env=viz_env,
                win=k,  # we reuse the same window otherwise visdom dies
                opts={"title": sequence_name + " " + k},
            )
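
As a side note, a minimal sketch of the VideoWriter usage that `_generate_prediction_videos` relies on, using only the calls that appear above; the frame arrays are placeholders mirroring `rendered_pred[k][0]` (channel-first float images in [0, 1]) and the output paths are illustrative.

import numpy as np

from pytorch3d.implicitron.tools.video_writer import VideoWriter

# placeholder frames: ten blank 3x256x256 images
frames = [np.zeros((3, 256, 256), dtype=np.float32) for _ in range(10)]

vw = VideoWriter(fps=20, out_path="/tmp/demo_video.mp4", cache_dir="/tmp/demo_frames")
for frame in frames:
    vw.write_frame(frame, resize=None)
vw.get_video(quiet=True)
print(f"Generated {vw.out_path}.")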
@@ -37,7 +37,7 @@ def generate_eval_video_cameras(
     Generate a camera trajectory rendering a scene from multiple viewpoints.
 
     Args:
-        train_dataset: The training dataset object.
+        train_cameras: The set of cameras from the training dataset object.
         n_eval_cams: Number of cameras in the trajectory.
         trajectory_type: The type of the camera trajectory. Can be one of:
            circular_lsq_fit: Camera centers follow a trajectory obtained
@@ -51,16 +51,30 @@ def generate_eval_video_cameras(
                of a figure-eight knot
                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
         trajectory_scale: The extent of the trajectory.
-        up: The "up" vector of the scene (=the normal of the scene floor).
-            Active for the `trajectory_type="circular"`.
         scene_center: The center of the scene in world coordinates which all
            the cameras from the generated trajectory look at.
+        up: The "up" vector of the scene (=the normal of the scene floor).
+            Active for the `trajectory_type="circular"`.
+        focal_length: The focal length of the output cameras. If `None`, an average
+            focal length of the train_cameras is used.
+        principal_point: The principal point of the output cameras. If `None`, an average
+            principal point of all train_cameras is used.
+        time: Defines the total length of the generated camera trajectory. All possible
+            trajectories (set with the `trajectory_type` argument) are periodic with
+            the period of `time=2pi`.
+            E.g. setting `trajectory_type=circular_lsq_fit` and `time=4pi`, will generate
+            a trajectory of camera poses rotating the total of 720 deg around the object.
+        infer_up_as_plane_normal: Infer the camera `up` vector automatically as the normal
+            of the plane fit to the optical centers of `train_cameras`.
+        traj_offset: 3D offset vector added to each point of the trajectory.
+        traj_offset_canonical: 3D offset vector expressed in the local coordinates of
+            the estimated trajectory which is added to each point of the trajectory.
         remove_outliers_rate: the number between 0 and 1; if > 0,
            some outlier train_cameras will be removed from trajectory estimation;
            the filtering is based on camera center coordinates; top and
            bottom `remove_outliers_rate` cameras on each dimension are removed.
     Returns:
-        Dictionary of camera instances which can be used as the test dataset
+        Batch of camera instances which can be used as the test dataset
     """
     if remove_outliers_rate > 0.0:
         train_cameras = _remove_outlier_cameras(train_cameras, remove_outliers_rate)
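
As a concrete reference for the newly documented arguments, the call pattern below is the one used by `render_flyaround` above; the free variables are defined here as placeholders matching its defaults.

import math

import torch
from pytorch3d.implicitron.tools.eval_video_trajectory import generate_eval_video_cameras

n_flyaround_poses = 40
max_angle = 2 * math.pi
train_cameras = ...  # placeholder: cameras of the training frames of one sequence

# `time` spans [0, max_angle); the trajectory is periodic with period 2*pi
time = torch.linspace(0, max_angle, n_flyaround_poses + 1)[:n_flyaround_poses]
test_cameras = generate_eval_video_cameras(
    train_cameras,
    time=time,
    n_eval_cams=n_flyaround_poses,
    trajectory_type="circular_lsq_fit",
    trajectory_scale=1.1,
    scene_center=(0.0, 0.0, 0.0),
    up=(0.0, -1.0, 0.0),
    focal_length=None,  # None -> average focal length of train_cameras (per the docstring)
    principal_point=torch.zeros(n_flyaround_poses, 2),
    traj_offset_canonical=(0.0, 0.0, 0.0),
)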
@@ -68,7 +68,7 @@ def get_skateboard_data(
     if not os.environ.get("FB_TEST", False):
         if os.getenv("FAIR_ENV_CLUSTER", "") == "":
             raise unittest.SkipTest("Unknown environment. Data not available.")
-        yield "/checkpoint/dnovotny/datasets/co3d/download_aws_22_02_18", PathManager()
+        yield "/datasets01/co3d/081922", PathManager()
 
     elif avoid_manifold or os.environ.get("INSIDE_RE_WORKER", False):
         from libfb.py.parutil import get_file_path
tests/implicitron/test_model_visualize.py  (new file, 154 lines)
@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import math
import os
import unittest
from typing import Tuple

import torch
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud

from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
from pytorch3d.renderer.cameras import CamerasBase
from tests.common_testing import interactive_testing_requested
from visdom import Visdom

from .common_resources import get_skateboard_data


class TestModelVisualize(unittest.TestCase):
    def test_flyaround_one_sequence(
        self,
        image_size: int = 256,
    ):
        if not interactive_testing_requested():
            return
        category = "skateboard"
        stack = contextlib.ExitStack()
        dataset_root, path_manager = stack.enter_context(get_skateboard_data())
        self.addCleanup(stack.close)
        frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
        sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
        subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
        expand_args_fields(JsonIndexDataset)
        train_dataset = JsonIndexDataset(
            frame_annotations_file=frame_file,
            sequence_annotations_file=sequence_file,
            subset_lists_file=subset_lists_file,
            dataset_root=dataset_root,
            image_height=image_size,
            image_width=image_size,
            box_crop=True,
            load_point_clouds=True,
            path_manager=path_manager,
            subsets=[
                "train_known",
            ],
        )

        # select few sequences to visualize
        sequence_names = list(train_dataset.seq_annots.keys())

        # select the first sequence name
        show_sequence_name = sequence_names[0]

        output_dir = os.path.split(os.path.abspath(__file__))[0]

        visdom_show_preds = Visdom().check_connection()

        for load_dataset_pointcloud in [True, False]:

            model = _PointcloudRenderingModel(
                train_dataset,
                show_sequence_name,
                device="cuda:0",
                load_dataset_pointcloud=load_dataset_pointcloud,
            )

            video_path = os.path.join(
                output_dir,
                f"load_pcl_{load_dataset_pointcloud}",
            )

            os.makedirs(output_dir, exist_ok=True)

            render_flyaround(
                train_dataset,
                show_sequence_name,
                model,
                video_path,
                n_flyaround_poses=40,
                fps=20,
                max_angle=2 * math.pi,
                trajectory_type="circular_lsq_fit",
                trajectory_scale=1.1,
                scene_center=(0.0, 0.0, 0.0),
                up=(0.0, 1.0, 0.0),
                traj_offset=1.0,
                n_source_views=1,
                visdom_show_preds=visdom_show_preds,
                visdom_environment="test_model_visualize",
                visdom_server="http://127.0.0.1",
                visdom_port=8097,
                num_workers=10,
                seed=None,
                video_resize=None,
                visualize_preds_keys=[
                    "images_render",
                    "depths_render",
                    "masks_render",
                    "_all_source_images",
                ],
                output_video_frames_dir=video_path,
            )


class _PointcloudRenderingModel(torch.nn.Module):
    def __init__(
        self,
        train_dataset: JsonIndexDataset,
        sequence_name: str,
        render_size: Tuple[int, int] = (400, 400),
        device=None,
        load_dataset_pointcloud: bool = False,
        max_frames: int = 30,
        num_workers: int = 10,
    ):
        super().__init__()
        self._render_size = render_size
        point_cloud, _ = get_implicitron_sequence_pointcloud(
            train_dataset,
            sequence_name=sequence_name,
            mask_points=True,
            max_frames=max_frames,
            num_workers=num_workers,
            load_dataset_point_cloud=load_dataset_pointcloud,
        )
        self._point_cloud = point_cloud.to(device)

    def forward(
        self,
        camera: CamerasBase,
        **kwargs,
    ):
        image_render, mask_render, depth_render = render_point_cloud_pytorch3d(
            camera[0],
            self._point_cloud,
            render_size=self._render_size,
            point_radius=1e-2,
            topk=10,
            bg_color=0.0,
        )
        return {
            "images_render": image_render.clamp(0.0, 1.0),
            "masks_render": mask_render,
            "depths_render": depth_render,
        }