diff --git a/projects/implicitron_trainer/visualize_reconstruction.py b/projects/implicitron_trainer/visualize_reconstruction.py
index 8c2f4fac..aa957e73 100644
--- a/projects/implicitron_trainer/visualize_reconstruction.py
+++ b/projects/implicitron_trainer/visualize_reconstruction.py
@@ -12,311 +12,60 @@
 n_eval_cameras=40
 render_size="[64,64]"
 video_size="[256,256]"
 """
-import math
 import os
-import random
 import sys
 from typing import Optional, Tuple
 
 import numpy as np
 import torch
-import torch.nn.functional as Fu
 from omegaconf import OmegaConf
-from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
-from pytorch3d.implicitron.dataset.utils import is_train_frame
-from pytorch3d.implicitron.models.base_model import EvaluationMode
+from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
 from pytorch3d.implicitron.tools.configurable import get_default_args
-from pytorch3d.implicitron.tools.eval_video_trajectory import (
-    generate_eval_video_cameras,
-)
-from pytorch3d.implicitron.tools.video_writer import VideoWriter
-from pytorch3d.implicitron.tools.vis_utils import (
-    get_visdom_connection,
-    make_depth_image,
-)
-from tqdm import tqdm
 
 from .experiment import Experiment
 
 
-def render_sequence(
-    dataset: DatasetBase,
-    sequence_name: str,
-    model: torch.nn.Module,
-    video_path,
-    n_eval_cameras=40,
-    fps=20,
-    max_angle=2 * math.pi,
-    trajectory_type="circular_lsq_fit",
-    trajectory_scale=1.1,
-    scene_center=(0.0, 0.0, 0.0),
-    up=(0.0, -1.0, 0.0),
-    traj_offset=0.0,
-    n_source_views=9,
-    viz_env="debug",
-    visdom_show_preds=False,
-    visdom_server="http://127.0.0.1",
-    visdom_port=8097,
-    num_workers=10,
-    seed=None,
-    video_resize=None,
-):
-    if seed is None:
-        seed = hash(sequence_name)
-
-    if visdom_show_preds:
-        viz = get_visdom_connection(server=visdom_server, port=visdom_port)
-    else:
-        viz = None
-
-    print(f"Loading all data of sequence '{sequence_name}'.")
-    seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
-    train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
-    assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
-    sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
-    print(f"Sequence set = {sequence_set_name}.")
-    train_cameras = train_data.camera
-    time = torch.linspace(0, max_angle, n_eval_cameras + 1)[:n_eval_cameras]
-    test_cameras = generate_eval_video_cameras(
-        train_cameras,
-        time=time,
-        n_eval_cams=n_eval_cameras,
-        trajectory_type=trajectory_type,
-        trajectory_scale=trajectory_scale,
-        scene_center=scene_center,
-        up=up,
-        focal_length=None,
-        principal_point=torch.zeros(n_eval_cameras, 2),
-        traj_offset_canonical=(0.0, 0.0, traj_offset),
-    )
-
-    # sample the source views reproducibly
-    with torch.random.fork_rng():
-        torch.manual_seed(seed)
-        source_views_i = torch.randperm(len(seq_idx))[:n_source_views]
-        # add the first dummy view that will get replaced with the target camera
-        source_views_i = Fu.pad(source_views_i, [1, 0])
-        source_views = [seq_idx[i] for i in source_views_i.tolist()]
-        batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
-        assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)
-
-    preds_total = []
-    for n in tqdm(range(n_eval_cameras), total=n_eval_cameras):
-        # set the first batch camera to the target camera
-        for k in ("R", "T", "focal_length", "principal_point"):
-            getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)
-
-        # Move to cuda
-        net_input = 
batch.cuda() - with torch.no_grad(): - preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}) - - # make sure we dont overwrite something - assert all(k not in preds for k in net_input.keys()) - preds.update(net_input) # merge everything into one big dict - - # Render the predictions to images - rendered_pred = images_from_preds(preds) - preds_total.append(rendered_pred) - - # show the preds every 5% of the export iterations - if visdom_show_preds and ( - n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1 - ): - show_predictions( - preds_total, - sequence_name=batch.sequence_name[0], - viz=viz, - viz_env=viz_env, - ) - - print(f"Exporting videos for sequence {sequence_name} ...") - generate_prediction_videos( - preds_total, - sequence_name=batch.sequence_name[0], - viz=viz, - viz_env=viz_env, - fps=fps, - video_path=video_path, - resize=video_resize, - ) - - -def _load_whole_dataset(dataset, idx, num_workers=10): - load_all_dataloader = torch.utils.data.DataLoader( - torch.utils.data.Subset(dataset, idx), - batch_size=len(idx), - num_workers=num_workers, - shuffle=False, - collate_fn=FrameData.collate, - ) - return next(iter(load_all_dataloader)) - - -def images_from_preds(preds): - imout = {} - for k in ( - "image_rgb", - "images_render", - "fg_probability", - "masks_render", - "depths_render", - "depth_map", - "_all_source_images", - ): - if k == "_all_source_images" and "image_rgb" in preds: - src_ims = preds["image_rgb"][1:].cpu().detach().clone() - v = _stack_images(src_ims, None)[None] - else: - if k not in preds or preds[k] is None: - print(f"cant show {k}") - continue - v = preds[k].cpu().detach().clone() - if k.startswith("depth"): - mask_resize = Fu.interpolate( - preds["masks_render"], - size=preds[k].shape[2:], - mode="nearest", - ) - v = make_depth_image(preds[k], mask_resize) - if v.shape[1] == 1: - v = v.repeat(1, 3, 1, 1) - imout[k] = v.detach().cpu() - - return imout - - -def _stack_images(ims, size): - ba = ims.shape[0] - H = int(np.ceil(np.sqrt(ba))) - W = H - n_add = H * W - ba - if n_add > 0: - ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1))) - - ims = ims.view(H, W, *ims.shape[1:]) - cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1) - if size is not None: - cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0] - return cated.clamp(0.0, 1.0) - - -def show_predictions( - preds, - sequence_name, - viz, - viz_env="visualizer", - predicted_keys=( - "images_render", - "masks_render", - "depths_render", - "_all_source_images", - ), - n_samples=10, - one_image_width=200, -): - """Given a list of predictions visualize them into a single image using visdom.""" - assert isinstance(preds, list) - - pred_all = [] - # Randomly choose a subset of the rendered images, sort by ordr in the sequence - n_samples = min(n_samples, len(preds)) - pred_idx = sorted(random.sample(list(range(len(preds))), n_samples)) - for predi in pred_idx: - # Make the concatentation for the same camera vertically - pred_all.append( - torch.cat( - [ - torch.nn.functional.interpolate( - preds[predi][k].cpu(), - scale_factor=one_image_width / preds[predi][k].shape[3], - mode="bilinear", - ).clamp(0.0, 1.0) - for k in predicted_keys - ], - dim=2, - ) - ) - # Concatenate the images horizontally - pred_all_cat = torch.cat(pred_all, dim=3)[0] - viz.image( - pred_all_cat, - win="show_predictions", - env=viz_env, - opts={"title": f"pred_{sequence_name}"}, - ) - - -def generate_prediction_videos( - preds, - sequence_name, - 
viz=None,
-    viz_env="visualizer",
-    predicted_keys=(
-        "images_render",
-        "masks_render",
-        "depths_render",
-        "_all_source_images",
-    ),
-    fps=20,
-    video_path="/tmp/video",
-    resize=None,
-):
-    """Given a list of predictions create and visualize rotating videos of the
-    objects using visdom.
-    """
-    assert isinstance(preds, list)
-
-    # make sure the target video directory exists
-    os.makedirs(os.path.dirname(video_path), exist_ok=True)
-
-    # init a video writer for each predicted key
-    vws = {}
-    for k in predicted_keys:
-        vws[k] = VideoWriter(out_path=f"{video_path}_{sequence_name}_{k}.mp4", fps=fps)
-
-    for rendered_pred in tqdm(preds):
-        for k in predicted_keys:
-            vws[k].write_frame(
-                rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
-                resize=resize,
-            )
-
-    for k in predicted_keys:
-        vws[k].get_video(quiet=True)
-        print(f"Generated {vws[k].out_path}.")
-        if viz is not None:
-            viz.video(
-                videofile=vws[k].out_path,
-                env=viz_env,
-                win=k,  # we reuse the same window otherwise visdom dies
-                opts={"title": sequence_name + " " + k},
-            )
-
-
-def export_scenes(
+def visualize_reconstruction(
     exp_dir: str = "",
     restrict_sequence_name: Optional[str] = None,
     output_directory: Optional[str] = None,
     render_size: Tuple[int, int] = (512, 512),
     video_size: Optional[Tuple[int, int]] = None,
-    split: str = "train",  # train | val | test
+    split: str = "train",
     n_source_views: int = 9,
     n_eval_cameras: int = 40,
-    visdom_server="http://127.0.0.1",
-    visdom_port=8097,
     visdom_show_preds: bool = False,
+    visdom_server: str = "http://127.0.0.1",
+    visdom_port: int = 8097,
     visdom_env: Optional[str] = None,
-    gpu_idx: int = 0,
 ):
+    """
+    Given an `exp_dir` containing a trained Implicitron model, generates videos
+    consisting of renders of sequences from the dataset that was used to train
+    and evaluate the model.
+
+    Args:
+        exp_dir: Implicitron experiment directory.
+        restrict_sequence_name: If set, defines the list of sequences to visualize.
+        output_directory: If set, defines a custom directory to output visualizations to.
+        render_size: The size (HxW) of the generated renders.
+        video_size: The size (HxW) of the output video.
+        split: The dataset split to use for visualization.
+            Can be "train" / "val" / "test".
+        n_source_views: The number of source views added to each rendered batch. These
+            views are required inputs for models such as NeRFormer / NeRF-WCE.
+        n_eval_cameras: The number of cameras in each fly-around trajectory.
+        visdom_show_preds: If `True`, outputs visualizations to visdom.
+        visdom_server: The address of the visdom server.
+        visdom_port: The port of the visdom server.
+        visdom_env: If set, defines a custom name for the visdom environment.
+    """
+    # In case an output directory is specified use it.
If no output_directory # is specified create a vis folder inside the experiment directory if output_directory is None: output_directory = os.path.join(exp_dir, "vis") - else: - output_directory = output_directory - if not os.path.exists(output_directory): - os.makedirs(output_directory) + os.makedirs(output_directory, exist_ok=True) # Set the random seeds torch.manual_seed(0) @@ -325,7 +74,6 @@ def export_scenes( # Get the config from the experiment_directory, # and overwrite relevant fields config = _get_config_from_experiment_directory(exp_dir) - config.gpu_idx = gpu_idx config.exp_dir = exp_dir # important so that the CO3D dataset gets loaded in full dataset_args = ( @@ -340,10 +88,6 @@ def export_scenes( if restrict_sequence_name is not None: dataset_args.restrict_sequence_name = restrict_sequence_name - # Set up the CUDA env for the visualization - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx) - # Load the previously trained model experiment = Experiment(config) model = experiment.model_factory(force_resume=True) @@ -360,17 +104,17 @@ def export_scenes( # iterate over the sequences in the dataset for sequence_name in dataset.sequence_names(): with torch.no_grad(): - render_sequence( - dataset, - sequence_name, - model, - video_path="{}/video".format(output_directory), + render_flyaround( + dataset=dataset, + sequence_name=sequence_name, + model=model, + output_video_path=os.path.join(output_directory, "video"), n_source_views=n_source_views, visdom_show_preds=visdom_show_preds, - n_eval_cameras=n_eval_cameras, + n_flyaround_poses=n_eval_cameras, visdom_server=visdom_server, visdom_port=visdom_port, - viz_env=f"visualizer_{config.visdom_env}" + visdom_environment=f"visualizer_{config.visdom_env}" if visdom_env is None else visdom_env, video_resize=video_size, @@ -384,11 +128,11 @@ def _get_config_from_experiment_directory(experiment_directory): def main(argv): - # automatically parses arguments of export_scenes - cfg = OmegaConf.create(get_default_args(export_scenes)) + # automatically parses arguments of visualize_reconstruction + cfg = OmegaConf.create(get_default_args(visualize_reconstruction)) cfg.update(OmegaConf.from_cli()) with torch.no_grad(): - export_scenes(**cfg) + visualize_reconstruction(**cfg) if __name__ == "__main__": diff --git a/pytorch3d/implicitron/models/visualization/__init__.py b/pytorch3d/implicitron/models/visualization/__init__.py new file mode 100644 index 00000000..a9fdb3b9 --- /dev/null +++ b/pytorch3d/implicitron/models/visualization/__init__.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/pytorch3d/implicitron/models/visualization/render_flyaround.py b/pytorch3d/implicitron/models/visualization/render_flyaround.py new file mode 100644 index 00000000..a1634616 --- /dev/null +++ b/pytorch3d/implicitron/models/visualization/render_flyaround.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
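+
+"""Utilities for rendering fly-around videos of Implicitron model predictions.
+
+The main entry point is `render_flyaround`, which renders a single dataset sequence
+with a trained model along a generated camera trajectory and exports the renders as
+videos (optionally also logging them to Visdom).
+"""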
+
+
+import logging
+import math
+import os
+import random
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as Fu
+from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
+from pytorch3d.implicitron.dataset.utils import is_train_frame
+from pytorch3d.implicitron.models.base_model import EvaluationMode
+from pytorch3d.implicitron.tools.eval_video_trajectory import (
+    generate_eval_video_cameras,
+)
+from pytorch3d.implicitron.tools.video_writer import VideoWriter
+from pytorch3d.implicitron.tools.vis_utils import (
+    get_visdom_connection,
+    make_depth_image,
+)
+from tqdm import tqdm
+from visdom import Visdom
+
+logger = logging.getLogger(__name__)
+
+
+def render_flyaround(
+    dataset: DatasetBase,
+    sequence_name: str,
+    model: torch.nn.Module,
+    output_video_path: str,
+    n_flyaround_poses: int = 40,
+    fps: int = 20,
+    trajectory_type: str = "circular_lsq_fit",
+    max_angle: float = 2 * math.pi,
+    trajectory_scale: float = 1.1,
+    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
+    up: Tuple[float, float, float] = (0.0, -1.0, 0.0),
+    traj_offset: float = 0.0,
+    n_source_views: int = 9,
+    visdom_show_preds: bool = False,
+    visdom_environment: str = "render_flyaround",
+    visdom_server: str = "http://127.0.0.1",
+    visdom_port: int = 8097,
+    num_workers: int = 10,
+    device: Union[str, torch.device] = "cuda",
+    seed: Optional[int] = None,
+    video_resize: Optional[Tuple[int, int]] = None,
+    output_video_frames_dir: Optional[str] = None,
+    visualize_preds_keys: Sequence[str] = (
+        "images_render",
+        "masks_render",
+        "depths_render",
+        "_all_source_images",
+    ),
+):
+    """
+    Uses `model` to generate a video consisting of renders of a scene imaged from
+    a camera flying around the scene. The scene is specified with the `dataset` object and
+    `sequence_name` which denotes the name of the scene whose frames are in `dataset`.
+
+    Args:
+        dataset: The dataset object containing frames from a sequence in `sequence_name`.
+        sequence_name: Name of a sequence from `dataset`.
+        model: The model whose predictions are going to be visualized.
+        output_video_path: The path to the video output by this script.
+        n_flyaround_poses: The number of camera poses of the flyaround trajectory.
+        fps: Framerate of the output video.
+        trajectory_type: The type of the camera trajectory. Can be one of:
+            circular_lsq_fit: Camera centers follow a trajectory obtained
+                by fitting a 3D circle to train_cameras centers.
+                All cameras are looking towards scene_center.
+            figure_eight: Figure-of-8 trajectory around the center of the
+                central camera of the training dataset.
+            trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
+                of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
+            figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
+                of a figure-eight knot
+                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
+        max_angle: Defines the total length of the generated camera trajectory.
+            All possible trajectories (set with the `trajectory_type` argument) are
+            periodic with a period of `max_angle==2pi`.
+            E.g. setting `trajectory_type=circular_lsq_fit` and `max_angle=4*pi` will
+            generate a trajectory of camera poses rotating the total of 720 deg around
+            the object.
+        trajectory_scale: The extent of the trajectory.
+        scene_center: The center of the scene in world coordinates which all
+            the cameras from the generated trajectory look at.
+        up: The "up" vector of the scene (=the normal of the scene floor).
+            Active for the `trajectory_type="circular"`.
+        traj_offset: A scalar offset along the canonical z-axis of the estimated
+            trajectory, added to each trajectory point (passed to
+            `generate_eval_video_cameras` as `traj_offset_canonical=(0, 0, traj_offset)`).
+        n_source_views: The number of source views sampled from the known views of the
+            training sequence and added to each evaluation batch.
+        visdom_show_preds: If `True`, exports the visualizations to visdom.
+        visdom_environment: The name of the visdom environment.
+        visdom_server: The address of the visdom server.
+        visdom_port: The visdom port.
+        num_workers: The number of workers used to load the training data.
+        seed: The random seed used for reproducible sampling of the source views.
+        video_resize: Optionally, defines the size of the output video.
+        output_video_frames_dir: If specified, the frames of the output video are going
+            to be permanently stored in this directory.
+        visualize_preds_keys: The names of the model predictions to visualize.
+    """
+
+    if seed is None:
+        seed = hash(sequence_name)
+
+    if visdom_show_preds:
+        viz = get_visdom_connection(server=visdom_server, port=visdom_port)
+    else:
+        viz = None
+
+    logger.info(f"Loading all data of sequence '{sequence_name}'.")
+    seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
+    train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
+    assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
+    sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
+    logger.info(f"Sequence set = {sequence_set_name}.")
+    train_cameras = train_data.camera
+    time = torch.linspace(0, max_angle, n_flyaround_poses + 1)[:n_flyaround_poses]
+    test_cameras = generate_eval_video_cameras(
+        train_cameras,
+        time=time,
+        n_eval_cams=n_flyaround_poses,
+        trajectory_type=trajectory_type,
+        trajectory_scale=trajectory_scale,
+        scene_center=scene_center,
+        up=up,
+        focal_length=None,
+        principal_point=torch.zeros(n_flyaround_poses, 2),
+        traj_offset_canonical=(0.0, 0.0, traj_offset),
+    )
+
+    # sample the source views reproducibly
+    with torch.random.fork_rng():
+        torch.manual_seed(seed)
+        source_views_i = torch.randperm(len(seq_idx))[:n_source_views]
+
+    # add the first dummy view that will get replaced with the target camera
+    source_views_i = Fu.pad(source_views_i, [1, 0])
+    source_views = [seq_idx[i] for i in source_views_i.tolist()]
+    batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
+    assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)
+
+    preds_total = []
+    for n in tqdm(range(n_flyaround_poses), total=n_flyaround_poses):
+        # set the first batch camera to the target camera
+        for k in ("R", "T", "focal_length", "principal_point"):
+            getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)
+
+        # Move to the target device
+        net_input = batch.to(device)
+        with torch.no_grad():
+            preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})
+
+        # make sure we don't overwrite anything
+        assert all(k not in preds for k in net_input.keys())
+        preds.update(net_input)  # merge everything into one big dict
+
+        # Render the predictions to images
+        rendered_pred = _images_from_preds(preds)
+        preds_total.append(rendered_pred)
+
+        # show the preds every 5% of the export iterations
+        if visdom_show_preds and (
+            n % max(n_flyaround_poses // 20, 1) == 0 or n == n_flyaround_poses - 1
+        ):
+            assert viz is not None
+            _show_predictions(
+                preds_total,
+                sequence_name=batch.sequence_name[0],
+                viz=viz,
+                viz_env=visdom_environment,
+            )
+
+    logger.info(f"Exporting videos for sequence {sequence_name} ...")
+    _generate_prediction_videos(
+        preds_total,
+        sequence_name=batch.sequence_name[0],
+        viz=viz,
+        viz_env=visdom_environment,
+        fps=fps,
+        video_path=output_video_path,
+        resize=video_resize,
+        video_frames_dir=output_video_frames_dir,
+    )
+
+
+def _load_whole_dataset(
+    dataset: torch.utils.data.Dataset, idx: Sequence[int], num_workers: int = 10
+):
+    load_all_dataloader = torch.utils.data.DataLoader(
+        torch.utils.data.Subset(dataset, idx),
+        batch_size=len(idx),
+        num_workers=num_workers,
+        shuffle=False,
+        collate_fn=FrameData.collate,
+    )
+    return next(iter(load_all_dataloader))
+
+
+def _images_from_preds(preds: Dict[str, Any]):
+    imout = {}
+    for k in (
+        "image_rgb",
+        "images_render",
+        "fg_probability",
+        "masks_render",
+        "depths_render",
+        "depth_map",
+        "_all_source_images",
+    ):
+        if k == "_all_source_images" and "image_rgb" in preds:
+            src_ims = preds["image_rgb"][1:].cpu().detach().clone()
+            v = _stack_images(src_ims, None)[None]
+        else:
+            if k not in preds or preds[k] is None:
+                logger.warning(f"Cannot show {k}.")
+                continue
+            v = preds[k].cpu().detach().clone()
+        if k.startswith("depth"):
+            mask_resize = Fu.interpolate(
+                preds["masks_render"],
+                size=preds[k].shape[2:],
+                mode="nearest",
+            )
+            v = make_depth_image(preds[k], mask_resize)
+        if v.shape[1] == 1:
+            v = v.repeat(1, 3, 1, 1)
+        imout[k] = v.detach().cpu()
+
+    return imout
+
+
+def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]):
+    ba = ims.shape[0]
+    H = int(np.ceil(np.sqrt(ba)))
+    W = H
+    n_add = H * W - ba
+    if n_add > 0:
+        ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))
+
+    ims = ims.view(H, W, *ims.shape[1:])
+    cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
+    if size is not None:
+        cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
+    return cated.clamp(0.0, 1.0)
+
+
+def _show_predictions(
+    preds: List[Dict[str, Any]],
+    sequence_name: str,
+    viz: Visdom,
+    viz_env: str = "visualizer",
+    predicted_keys: Sequence[str] = (
+        "images_render",
+        "masks_render",
+        "depths_render",
+        "_all_source_images",
+    ),
+    n_samples: int = 10,
+    one_image_width: int = 200,
+):
+    """Given a list of predictions, visualize them as a single image using visdom."""
+    assert isinstance(preds, list)
+
+    pred_all = []
+    # Randomly choose a subset of the rendered images, sort by order in the sequence
+    n_samples = min(n_samples, len(preds))
+    pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
+    for predi in pred_idx:
+        # Make the concatenation for the same camera vertically
+        pred_all.append(
+            torch.cat(
+                [
+                    torch.nn.functional.interpolate(
+                        preds[predi][k].cpu(),
+                        scale_factor=one_image_width / preds[predi][k].shape[3],
+                        mode="bilinear",
+                    ).clamp(0.0, 1.0)
+                    for k in predicted_keys
+                ],
+                dim=2,
+            )
+        )
+    # Concatenate the images horizontally
+    pred_all_cat = torch.cat(pred_all, dim=3)[0]
+    viz.image(
+        pred_all_cat,
+        win="show_predictions",
+        env=viz_env,
+        opts={"title": f"pred_{sequence_name}"},
+    )
+
+
+def _generate_prediction_videos(
+    preds: List[Dict[str, Any]],
+    sequence_name: str,
+    viz: Optional[Visdom] = None,
+    viz_env: str = "visualizer",
+    predicted_keys: Sequence[str] = (
+        "images_render",
+        "masks_render",
+        "depths_render",
+        "_all_source_images",
+    ),
+    fps: int = 20,
+    video_path: str = "/tmp/video",
+    video_frames_dir: Optional[str] = None,
+    resize: Optional[Tuple[int, int]] = None,
+):
+    """Given a list of predictions, create and visualize rotating videos of the
+    objects using visdom.
+    """
+
+    # make sure the target video directory exists
+    os.makedirs(os.path.dirname(video_path), exist_ok=True)
+
+    # init a video writer for each predicted key
+    vws = {}
+    for k in predicted_keys:
+        vws[k] = VideoWriter(
+            fps=fps,
+            out_path=f"{video_path}_{sequence_name}_{k}.mp4",
+            cache_dir=(
+                None
+                if video_frames_dir is None
+                else os.path.join(video_frames_dir, f"{sequence_name}_{k}")
+            ),
+        )
+
+    for rendered_pred in tqdm(preds):
+        for k in predicted_keys:
+            vws[k].write_frame(
+                rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
+                resize=resize,
+            )
+
+    for k in predicted_keys:
+        vws[k].get_video(quiet=True)
+        logger.info(f"Generated {vws[k].out_path}.")
+        if viz is not None:
+            viz.video(
+                videofile=vws[k].out_path,
+                env=viz_env,
+                win=k,  # we reuse the same window otherwise visdom dies
+                opts={"title": sequence_name + " " + k},
+            )
diff --git a/pytorch3d/implicitron/tools/eval_video_trajectory.py b/pytorch3d/implicitron/tools/eval_video_trajectory.py
index 31c2b181..bda9ec29 100644
--- a/pytorch3d/implicitron/tools/eval_video_trajectory.py
+++ b/pytorch3d/implicitron/tools/eval_video_trajectory.py
@@ -37,7 +37,7 @@ def generate_eval_video_cameras(
     Generate a camera trajectory rendering a scene from multiple viewpoints.
 
     Args:
-        train_dataset: The training dataset object.
+        train_cameras: The set of cameras from the training dataset object.
         n_eval_cams: Number of cameras in the trajectory.
         trajectory_type: The type of the camera trajectory. Can be one of:
            circular_lsq_fit: Camera centers follow a trajectory obtained
@@ -51,16 +51,30 @@
             of a figure-eight knot
             (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
         trajectory_scale: The extent of the trajectory.
-        up: The "up" vector of the scene (=the normal of the scene floor).
-            Active for the `trajectory_type="circular"`.
         scene_center: The center of the scene in world coordinates which all
             the cameras from the generated trajectory look at.
+        up: The "up" vector of the scene (=the normal of the scene floor).
+            Active for the `trajectory_type="circular"`.
+        focal_length: The focal length of the output cameras. If `None`, an average
+            focal length of the train_cameras is used.
+        principal_point: The principal point of the output cameras. If `None`, an average
+            principal point of all train_cameras is used.
+        time: Defines the total length of the generated camera trajectory. All possible
+            trajectories (set with the `trajectory_type` argument) are periodic with
+            the period of `time=2pi`.
+            E.g. setting `trajectory_type=circular_lsq_fit` and `time=4pi`, will generate
+            a trajectory of camera poses rotating the total of 720 deg around the object.
+        infer_up_as_plane_normal: Infer the camera `up` vector automatically as the normal
+            of the plane fit to the optical centers of `train_cameras`.
+ traj_offset: 3D offset vector added to each point of the trajectory. + traj_offset_canonical: 3D offset vector expressed in the local coordinates of + the estimated trajectory which is added to each point of the trajectory. remove_outliers_rate: the number between 0 and 1; if > 0, some outlier train_cameras will be removed from trajectory estimation; the filtering is based on camera center coordinates; top and bottom `remove_outliers_rate` cameras on each dimension are removed. Returns: - Dictionary of camera instances which can be used as the test dataset + Batch of camera instances which can be used as the test dataset """ if remove_outliers_rate > 0.0: train_cameras = _remove_outlier_cameras(train_cameras, remove_outliers_rate) diff --git a/tests/implicitron/common_resources.py b/tests/implicitron/common_resources.py index 41a83575..935f3ee1 100644 --- a/tests/implicitron/common_resources.py +++ b/tests/implicitron/common_resources.py @@ -68,7 +68,7 @@ def get_skateboard_data( if not os.environ.get("FB_TEST", False): if os.getenv("FAIR_ENV_CLUSTER", "") == "": raise unittest.SkipTest("Unknown environment. Data not available.") - yield "/checkpoint/dnovotny/datasets/co3d/download_aws_22_02_18", PathManager() + yield "/datasets01/co3d/081922", PathManager() elif avoid_manifold or os.environ.get("INSIDE_RE_WORKER", False): from libfb.py.parutil import get_file_path diff --git a/tests/implicitron/test_model_visualize.py b/tests/implicitron/test_model_visualize.py new file mode 100644 index 00000000..1815d678 --- /dev/null +++ b/tests/implicitron/test_model_visualize.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
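+
+# This test renders fly-around videos of a CO3D "skateboard" sequence using a simple
+# point-cloud rendering model, exercising `render_flyaround` both with and without
+# loading the dataset point cloud. It only runs when interactive testing is requested.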
+
+import contextlib
+import math
+import os
+import unittest
+from typing import Tuple
+
+import torch
+from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
+from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
+
+from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
+from pytorch3d.implicitron.tools.config import expand_args_fields
+from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
+from pytorch3d.renderer.cameras import CamerasBase
+from tests.common_testing import interactive_testing_requested
+from visdom import Visdom
+
+from .common_resources import get_skateboard_data
+
+
+class TestModelVisualize(unittest.TestCase):
+    def test_flyaround_one_sequence(
+        self,
+        image_size: int = 256,
+    ):
+        if not interactive_testing_requested():
+            return
+        category = "skateboard"
+        stack = contextlib.ExitStack()
+        dataset_root, path_manager = stack.enter_context(get_skateboard_data())
+        self.addCleanup(stack.close)
+        frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
+        sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
+        subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
+        expand_args_fields(JsonIndexDataset)
+        train_dataset = JsonIndexDataset(
+            frame_annotations_file=frame_file,
+            sequence_annotations_file=sequence_file,
+            subset_lists_file=subset_lists_file,
+            dataset_root=dataset_root,
+            image_height=image_size,
+            image_width=image_size,
+            box_crop=True,
+            load_point_clouds=True,
+            path_manager=path_manager,
+            subsets=[
+                "train_known",
+            ],
+        )
+
+        # list the sequence names in the dataset
+        sequence_names = list(train_dataset.seq_annots.keys())
+
+        # select the first sequence name
+        show_sequence_name = sequence_names[0]
+
+        output_dir = os.path.split(os.path.abspath(__file__))[0]
+
+        visdom_show_preds = Visdom().check_connection()
+
+        for load_dataset_pointcloud in [True, False]:
+
+            model = _PointcloudRenderingModel(
+                train_dataset,
+                show_sequence_name,
+                device="cuda:0",
+                load_dataset_pointcloud=load_dataset_pointcloud,
+            )
+
+            video_path = os.path.join(
+                output_dir,
+                f"load_pcl_{load_dataset_pointcloud}",
+            )
+
+            os.makedirs(output_dir, exist_ok=True)
+
+            render_flyaround(
+                train_dataset,
+                show_sequence_name,
+                model,
+                video_path,
+                n_flyaround_poses=40,
+                fps=20,
+                max_angle=2 * math.pi,
+                trajectory_type="circular_lsq_fit",
+                trajectory_scale=1.1,
+                scene_center=(0.0, 0.0, 0.0),
+                up=(0.0, 1.0, 0.0),
+                traj_offset=1.0,
+                n_source_views=1,
+                visdom_show_preds=visdom_show_preds,
+                visdom_environment="test_model_visualize",
+                visdom_server="http://127.0.0.1",
+                visdom_port=8097,
+                num_workers=10,
+                seed=None,
+                video_resize=None,
+                visualize_preds_keys=[
+                    "images_render",
+                    "depths_render",
+                    "masks_render",
+                    "_all_source_images",
+                ],
+                output_video_frames_dir=video_path,
+            )
+
+
+class _PointcloudRenderingModel(torch.nn.Module):
+    def __init__(
+        self,
+        train_dataset: JsonIndexDataset,
+        sequence_name: str,
+        render_size: Tuple[int, int] = (400, 400),
+        device=None,
+        load_dataset_pointcloud: bool = False,
+        max_frames: int = 30,
+        num_workers: int = 10,
+    ):
+        super().__init__()
+        self._render_size = render_size
+        point_cloud, _ = get_implicitron_sequence_pointcloud(
+            train_dataset,
+            sequence_name=sequence_name,
+            mask_points=True,
+            max_frames=max_frames,
+            num_workers=num_workers,
+            load_dataset_point_cloud=load_dataset_pointcloud,
+        )
+        self._point_cloud = 
point_cloud.to(device) + + def forward( + self, + camera: CamerasBase, + **kwargs, + ): + image_render, mask_render, depth_render = render_point_cloud_pytorch3d( + camera[0], + self._point_cloud, + render_size=self._render_size, + point_radius=1e-2, + topk=10, + bg_color=0.0, + ) + return { + "images_render": image_render.clamp(0.0, 1.0), + "masks_render": mask_render, + "depths_render": depth_render, + }
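
Usage sketch: a minimal example of driving the new `render_flyaround` entry point directly. It assumes a `train_dataset` (for example the `JsonIndexDataset` built in `test_model_visualize.py` above) and a `model` following the Implicitron prediction interface (for example `_PointcloudRenderingModel`, or a model restored via `Experiment(config).model_factory(force_resume=True)` as in `visualize_reconstruction.py`); the output path and parameter values are illustrative only.

    import torch
    from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround

    # pick the first sequence of the dataset and render a 40-pose fly-around of it;
    # videos are written as <output_video_path>_<sequence_name>_<predicted_key>.mp4
    sequence_name = next(iter(train_dataset.sequence_names()))
    with torch.no_grad():
        render_flyaround(
            dataset=train_dataset,
            sequence_name=sequence_name,
            model=model,
            output_video_path="/tmp/flyaround/video",
            n_flyaround_poses=40,
            n_source_views=9,
            visdom_show_preds=False,
            video_resize=(256, 256),
        )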