mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-11-04 18:02:14 +08:00

	Rename and move render_flyaround into core implicitron
Summary: Move the flyaround rendering function into core implicitron. This unblocks an example in the facebookresearch/co3d repo.

Reviewed By: bottler

Differential Revision: D39257801

fbshipit-source-id: 6841a88a43d4aa364dd86ba83ca2d4c3cf0435a4
This commit is contained in:
parent 438c194ec6
commit c79c954dea
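For context: after this change the flyaround renderer is importable from the core library, as the trainer script's import in the diff below shows. A minimal sketch of the new import:

    from pytorch3d.implicitron.models.visualization import render_flyaround

Callers of the old trainer-local helper also need the renamed keyword arguments visible in the diff: video_path becomes output_video_path, n_eval_cameras becomes n_flyaround_poses, and viz_env becomes visdom_environment.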
					
@@ -12,311 +12,60 @@
    n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
"""

import math
import os
import random
import sys
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as Fu
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import EvaluationMode
from pytorch3d.implicitron.models.visualization import render_flyaround
from pytorch3d.implicitron.tools.configurable import get_default_args
from pytorch3d.implicitron.tools.eval_video_trajectory import (
    generate_eval_video_cameras,
)
from pytorch3d.implicitron.tools.video_writer import VideoWriter
from pytorch3d.implicitron.tools.vis_utils import (
    get_visdom_connection,
    make_depth_image,
)
from tqdm import tqdm

from .experiment import Experiment


def render_sequence(
    dataset: DatasetBase,
    sequence_name: str,
    model: torch.nn.Module,
    video_path,
    n_eval_cameras=40,
    fps=20,
    max_angle=2 * math.pi,
    trajectory_type="circular_lsq_fit",
    trajectory_scale=1.1,
    scene_center=(0.0, 0.0, 0.0),
    up=(0.0, -1.0, 0.0),
    traj_offset=0.0,
    n_source_views=9,
    viz_env="debug",
    visdom_show_preds=False,
    visdom_server="http://127.0.0.1",
    visdom_port=8097,
    num_workers=10,
    seed=None,
    video_resize=None,
):
    if seed is None:
        seed = hash(sequence_name)

    if visdom_show_preds:
        viz = get_visdom_connection(server=visdom_server, port=visdom_port)
    else:
        viz = None

    print(f"Loading all data of sequence '{sequence_name}'.")
    seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
    train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
    assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
    sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
    print(f"Sequence set = {sequence_set_name}.")
    train_cameras = train_data.camera
    time = torch.linspace(0, max_angle, n_eval_cameras + 1)[:n_eval_cameras]
    test_cameras = generate_eval_video_cameras(
        train_cameras,
        time=time,
        n_eval_cams=n_eval_cameras,
        trajectory_type=trajectory_type,
        trajectory_scale=trajectory_scale,
        scene_center=scene_center,
        up=up,
        focal_length=None,
        principal_point=torch.zeros(n_eval_cameras, 2),
        traj_offset_canonical=(0.0, 0.0, traj_offset),
    )

    # sample the source views reproducibly
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        source_views_i = torch.randperm(len(seq_idx))[:n_source_views]
    # add the first dummy view that will get replaced with the target camera
    source_views_i = Fu.pad(source_views_i, [1, 0])
    source_views = [seq_idx[i] for i in source_views_i.tolist()]
    batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
    assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)

    preds_total = []
    for n in tqdm(range(n_eval_cameras), total=n_eval_cameras):
        # set the first batch camera to the target camera
        for k in ("R", "T", "focal_length", "principal_point"):
            getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)

        # Move to cuda
        net_input = batch.cuda()
        with torch.no_grad():
            preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})

            # make sure we don't overwrite something
            assert all(k not in preds for k in net_input.keys())
            preds.update(net_input)  # merge everything into one big dict

            # Render the predictions to images
            rendered_pred = images_from_preds(preds)
            preds_total.append(rendered_pred)

            # show the preds every 5% of the export iterations
            if visdom_show_preds and (
                n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
            ):
                show_predictions(
                    preds_total,
                    sequence_name=batch.sequence_name[0],
                    viz=viz,
                    viz_env=viz_env,
                )

    print(f"Exporting videos for sequence {sequence_name} ...")
    generate_prediction_videos(
        preds_total,
        sequence_name=batch.sequence_name[0],
        viz=viz,
        viz_env=viz_env,
        fps=fps,
        video_path=video_path,
        resize=video_resize,
    )


def _load_whole_dataset(dataset, idx, num_workers=10):
    load_all_dataloader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(dataset, idx),
        batch_size=len(idx),
        num_workers=num_workers,
        shuffle=False,
        collate_fn=FrameData.collate,
    )
    return next(iter(load_all_dataloader))


def images_from_preds(preds):
    imout = {}
    for k in (
        "image_rgb",
        "images_render",
        "fg_probability",
        "masks_render",
        "depths_render",
        "depth_map",
        "_all_source_images",
    ):
        if k == "_all_source_images" and "image_rgb" in preds:
            src_ims = preds["image_rgb"][1:].cpu().detach().clone()
            v = _stack_images(src_ims, None)[None]
        else:
            if k not in preds or preds[k] is None:
                print(f"can't show {k}")
                continue
            v = preds[k].cpu().detach().clone()
        if k.startswith("depth"):
            mask_resize = Fu.interpolate(
                preds["masks_render"],
                size=preds[k].shape[2:],
                mode="nearest",
            )
            v = make_depth_image(preds[k], mask_resize)
        if v.shape[1] == 1:
            v = v.repeat(1, 3, 1, 1)
        imout[k] = v.detach().cpu()

    return imout


def _stack_images(ims, size):
    ba = ims.shape[0]
    H = int(np.ceil(np.sqrt(ba)))
    W = H
    n_add = H * W - ba
    if n_add > 0:
        ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))

    ims = ims.view(H, W, *ims.shape[1:])
    cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
    if size is not None:
        cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
    return cated.clamp(0.0, 1.0)


def show_predictions(
    preds,
    sequence_name,
    viz,
    viz_env="visualizer",
    predicted_keys=(
        "images_render",
        "masks_render",
        "depths_render",
        "_all_source_images",
    ),
    n_samples=10,
    one_image_width=200,
):
    """Given a list of predictions visualize them into a single image using visdom."""
    assert isinstance(preds, list)

    pred_all = []
    # Randomly choose a subset of the rendered images, sort by order in the sequence
    n_samples = min(n_samples, len(preds))
    pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
    for predi in pred_idx:
        # Make the concatenation for the same camera vertically
        pred_all.append(
            torch.cat(
                [
                    torch.nn.functional.interpolate(
                        preds[predi][k].cpu(),
                        scale_factor=one_image_width / preds[predi][k].shape[3],
                        mode="bilinear",
                    ).clamp(0.0, 1.0)
                    for k in predicted_keys
                ],
                dim=2,
            )
        )
    # Concatenate the images horizontally
    pred_all_cat = torch.cat(pred_all, dim=3)[0]
    viz.image(
        pred_all_cat,
        win="show_predictions",
        env=viz_env,
        opts={"title": f"pred_{sequence_name}"},
    )


def generate_prediction_videos(
    preds,
    sequence_name,
    viz=None,
    viz_env="visualizer",
    predicted_keys=(
        "images_render",
        "masks_render",
        "depths_render",
        "_all_source_images",
    ),
    fps=20,
    video_path="/tmp/video",
    resize=None,
):
    """Given a list of predictions create and visualize rotating videos of the
    objects using visdom.
    """
    assert isinstance(preds, list)

    # make sure the target video directory exists
    os.makedirs(os.path.dirname(video_path), exist_ok=True)

    # init a video writer for each predicted key
    vws = {}
    for k in predicted_keys:
        vws[k] = VideoWriter(out_path=f"{video_path}_{sequence_name}_{k}.mp4", fps=fps)

    for rendered_pred in tqdm(preds):
        for k in predicted_keys:
            vws[k].write_frame(
                rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
                resize=resize,
            )

    for k in predicted_keys:
        vws[k].get_video(quiet=True)
        print(f"Generated {vws[k].out_path}.")
        if viz is not None:
            viz.video(
                videofile=vws[k].out_path,
                env=viz_env,
                win=k,  # we reuse the same window otherwise visdom dies
                opts={"title": sequence_name + " " + k},
            )


def export_scenes(
def visualize_reconstruction(
    exp_dir: str = "",
    restrict_sequence_name: Optional[str] = None,
    output_directory: Optional[str] = None,
    render_size: Tuple[int, int] = (512, 512),
    video_size: Optional[Tuple[int, int]] = None,
    split: str = "train",  # train | val | test
    split: str = "train",
    n_source_views: int = 9,
    n_eval_cameras: int = 40,
    visdom_server="http://127.0.0.1",
    visdom_port=8097,
    visdom_show_preds: bool = False,
    visdom_server: str = "http://127.0.0.1",
    visdom_port: int = 8097,
    visdom_env: Optional[str] = None,
    gpu_idx: int = 0,
):
    """
    Given an `exp_dir` containing a trained Implicitron model, generates videos consisting
    of renders of sequences from the dataset used to train and evaluate the trained
    Implicitron model.

    Args:
        exp_dir: Implicitron experiment directory.
        restrict_sequence_name: If set, defines the list of sequences to visualize.
        output_directory: If set, defines a custom directory to output visualizations to.
        render_size: The size (HxW) of the generated renders.
        video_size: The size (HxW) of the output video.
        split: The dataset split to use for visualization.
            Can be "train" / "val" / "test".
        n_source_views: The number of source views added to each rendered batch. These
            views are required inputs for models such as NeRFormer / NeRF-WCE.
        n_eval_cameras: The number of cameras in each fly-around trajectory.
        visdom_show_preds: If `True`, outputs visualizations to visdom.
        visdom_server: The address of the visdom server.
        visdom_port: The port of the visdom server.
        visdom_env: If set, defines a custom name for the visdom environment.
    """

    # In case an output directory is specified use it. If no output_directory
    # is specified create a vis folder inside the experiment directory
    if output_directory is None:
        output_directory = os.path.join(exp_dir, "vis")
    else:
        output_directory = output_directory
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)
    os.makedirs(output_directory, exist_ok=True)

    # Set the random seeds
    torch.manual_seed(0)
@@ -325,7 +74,6 @@ def export_scenes(
    # Get the config from the experiment_directory,
    # and overwrite relevant fields
    config = _get_config_from_experiment_directory(exp_dir)
    config.gpu_idx = gpu_idx
    config.exp_dir = exp_dir
    # important so that the CO3D dataset gets loaded in full
    dataset_args = (
@@ -340,10 +88,6 @@ def export_scenes(
    if restrict_sequence_name is not None:
        dataset_args.restrict_sequence_name = restrict_sequence_name

    # Set up the CUDA env for the visualization
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)

    # Load the previously trained model
    experiment = Experiment(config)
    model = experiment.model_factory(force_resume=True)
@@ -360,17 +104,17 @@ def export_scenes(
    # iterate over the sequences in the dataset
    for sequence_name in dataset.sequence_names():
        with torch.no_grad():
            render_sequence(
                dataset,
                sequence_name,
                model,
                video_path="{}/video".format(output_directory),
            render_flyaround(
                dataset=dataset,
                sequence_name=sequence_name,
                model=model,
                output_video_path=os.path.join(output_directory, "video"),
                n_source_views=n_source_views,
                visdom_show_preds=visdom_show_preds,
                n_eval_cameras=n_eval_cameras,
                n_flyaround_poses=n_eval_cameras,
                visdom_server=visdom_server,
                visdom_port=visdom_port,
                viz_env=f"visualizer_{config.visdom_env}"
                visdom_environment=f"visualizer_{config.visdom_env}"
                if visdom_env is None
                else visdom_env,
                video_resize=video_size,
@@ -384,11 +128,11 @@ def _get_config_from_experiment_directory(experiment_directory):


def main(argv):
    # automatically parses arguments of export_scenes
    cfg = OmegaConf.create(get_default_args(export_scenes))
    # automatically parses arguments of visualize_reconstruction
    cfg = OmegaConf.create(get_default_args(visualize_reconstruction))
    cfg.update(OmegaConf.from_cli())
    with torch.no_grad():
        export_scenes(**cfg)
        visualize_reconstruction(**cfg)


if __name__ == "__main__":
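A hedged sketch of driving the renamed entry point programmatically, mirroring main() above; the exp_dir value is a placeholder, and the same keys can be given on the command line as in the module docstring (e.g. n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"):

    import torch
    from omegaconf import OmegaConf
    from pytorch3d.implicitron.tools.configurable import get_default_args

    cfg = OmegaConf.create(get_default_args(visualize_reconstruction))
    cfg.exp_dir = "/path/to/experiment"  # placeholder experiment directory
    with torch.no_grad():
        visualize_reconstruction(**cfg)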
							
								
								
									
6  pytorch3d/implicitron/models/visualization/__init__.py  Normal file
@@ -0,0 +1,6 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
							
								
								
									
363  pytorch3d/implicitron/models/visualization/render_flyaround.py  Normal file
@@ -0,0 +1,363 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import logging
import math
import os
import random
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as Fu
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import EvaluationMode
from pytorch3d.implicitron.tools.eval_video_trajectory import (
    generate_eval_video_cameras,
)
from pytorch3d.implicitron.tools.video_writer import VideoWriter
from pytorch3d.implicitron.tools.vis_utils import (
    get_visdom_connection,
    make_depth_image,
)
from tqdm import tqdm
from visdom import Visdom

logger = logging.getLogger(__name__)


def render_flyaround(
    dataset: DatasetBase,
    sequence_name: str,
    model: torch.nn.Module,
    output_video_path: str,
    n_flyaround_poses: int = 40,
    fps: int = 20,
    trajectory_type: str = "circular_lsq_fit",
    max_angle: float = 2 * math.pi,
    trajectory_scale: float = 1.1,
    scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
    up: Tuple[float, float, float] = (0.0, -1.0, 0.0),
    traj_offset: float = 0.0,
    n_source_views: int = 9,
    visdom_show_preds: bool = False,
    visdom_environment: str = "render_flyaround",
    visdom_server: str = "http://127.0.0.1",
    visdom_port: int = 8097,
    num_workers: int = 10,
    device: Union[str, torch.device] = "cuda",
    seed: Optional[int] = None,
    video_resize: Optional[Tuple[int, int]] = None,
    output_video_frames_dir: Optional[str] = None,
    visualize_preds_keys: Sequence[str] = (
        "images_render",
        "masks_render",
        "depths_render",
        "_all_source_images",
    ),
):
    """
    Uses `model` to generate a video consisting of renders of a scene imaged from
    a camera flying around the scene. The scene is specified with the `dataset` object and
    `sequence_name` which denotes the name of the scene whose frames are in `dataset`.

    Args:
        dataset: The dataset object containing frames from a sequence in `sequence_name`.
        sequence_name: Name of a sequence from `dataset`.
        model: The model whose predictions are going to be visualized.
        output_video_path: The path to the video output by this script.
        n_flyaround_poses: The number of camera poses of the flyaround trajectory.
        fps: Framerate of the output video.
        trajectory_type: The type of the camera trajectory. Can be one of:
            circular_lsq_fit: Camera centers follow a trajectory obtained
                by fitting a 3D circle to train_cameras centers.
                All cameras are looking towards scene_center.
            figure_eight: Figure-of-8 trajectory around the center of the
                central camera of the training dataset.
            trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
                of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
            figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
                of a figure-eight knot
                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
        max_angle: Defines the total length of the generated camera trajectory.
            All possible trajectories (set with the `trajectory_type` argument) are
            periodic with the period of `time==2pi`.
            E.g. setting `trajectory_type=circular_lsq_fit` and `time=4pi` will generate
            a trajectory of camera poses rotating the total of 720 deg around the object.
        trajectory_scale: The extent of the trajectory.
        scene_center: The center of the scene in world coordinates which all
            the cameras from the generated trajectory look at.
        up: The "up" vector of the scene (=the normal of the scene floor).
            Active for the `trajectory_type="circular"`.
        traj_offset: 3D offset vector added to each point of the trajectory.
        n_source_views: The number of source views sampled from the known views of the
            training sequence added to each evaluation batch.
        visdom_show_preds: If `True`, exports the visualizations to visdom.
        visdom_environment: The name of the visdom environment.
        visdom_server: The address of the visdom server.
        visdom_port: The visdom port.
        num_workers: The number of workers used to load the training data.
        seed: The random seed used for reproducible sampling of the source views.
        video_resize: Optionally, defines the size of the output video.
        output_video_frames_dir: If specified, the frames of the output video are going
            to be permanently stored in this directory.
        visualize_preds_keys: The names of the model predictions to visualize.
    """

    if seed is None:
        seed = hash(sequence_name)

    if visdom_show_preds:
        viz = get_visdom_connection(server=visdom_server, port=visdom_port)
    else:
        viz = None

    logger.info(f"Loading all data of sequence '{sequence_name}'.")
    seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
    train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
    assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
    sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
    logger.info(f"Sequence set = {sequence_set_name}.")
    train_cameras = train_data.camera
    time = torch.linspace(0, max_angle, n_flyaround_poses + 1)[:n_flyaround_poses]
    test_cameras = generate_eval_video_cameras(
        train_cameras,
        time=time,
        n_eval_cams=n_flyaround_poses,
        trajectory_type=trajectory_type,
        trajectory_scale=trajectory_scale,
        scene_center=scene_center,
        up=up,
        focal_length=None,
        principal_point=torch.zeros(n_flyaround_poses, 2),
        traj_offset_canonical=(0.0, 0.0, traj_offset),
    )

    # sample the source views reproducibly
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        source_views_i = torch.randperm(len(seq_idx))[:n_source_views]

    # add the first dummy view that will get replaced with the target camera
    source_views_i = Fu.pad(source_views_i, [1, 0])
    source_views = [seq_idx[i] for i in source_views_i.tolist()]
    batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
    assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)

    preds_total = []
    for n in tqdm(range(n_flyaround_poses), total=n_flyaround_poses):
        # set the first batch camera to the target camera
        for k in ("R", "T", "focal_length", "principal_point"):
            getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)

        # Move to cuda
        net_input = batch.to(device)
        with torch.no_grad():
            preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})

            # make sure we don't overwrite something
            assert all(k not in preds for k in net_input.keys())
            preds.update(net_input)  # merge everything into one big dict

            # Render the predictions to images
            rendered_pred = _images_from_preds(preds)
            preds_total.append(rendered_pred)

            # show the preds every 5% of the export iterations
            if visdom_show_preds and (
                n % max(n_flyaround_poses // 20, 1) == 0 or n == n_flyaround_poses - 1
            ):
                assert viz is not None
                _show_predictions(
                    preds_total,
                    sequence_name=batch.sequence_name[0],
                    viz=viz,
                    viz_env=visdom_environment,
                )

    logger.info(f"Exporting videos for sequence {sequence_name} ...")
    _generate_prediction_videos(
        preds_total,
        sequence_name=batch.sequence_name[0],
        viz=viz,
        viz_env=visdom_environment,
        fps=fps,
        video_path=output_video_path,
        resize=video_resize,
        video_frames_dir=output_video_frames_dir,
    )


def _load_whole_dataset(
    dataset: torch.utils.data.Dataset, idx: Sequence[int], num_workers: int = 10
):
    load_all_dataloader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(dataset, idx),
        batch_size=len(idx),
        num_workers=num_workers,
        shuffle=False,
        collate_fn=FrameData.collate,
    )
    return next(iter(load_all_dataloader))


def _images_from_preds(preds: Dict[str, Any]):
    imout = {}
    for k in (
        "image_rgb",
        "images_render",
        "fg_probability",
        "masks_render",
        "depths_render",
        "depth_map",
        "_all_source_images",
    ):
        if k == "_all_source_images" and "image_rgb" in preds:
            src_ims = preds["image_rgb"][1:].cpu().detach().clone()
            v = _stack_images(src_ims, None)[None]
        else:
            if k not in preds or preds[k] is None:
                print(f"can't show {k}")
                continue
            v = preds[k].cpu().detach().clone()
        if k.startswith("depth"):
            mask_resize = Fu.interpolate(
                preds["masks_render"],
                size=preds[k].shape[2:],
                mode="nearest",
            )
            v = make_depth_image(preds[k], mask_resize)
        if v.shape[1] == 1:
            v = v.repeat(1, 3, 1, 1)
        imout[k] = v.detach().cpu()

    return imout


def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]):
    ba = ims.shape[0]
    H = int(np.ceil(np.sqrt(ba)))
    W = H
    n_add = H * W - ba
    if n_add > 0:
        ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))

    ims = ims.view(H, W, *ims.shape[1:])
    cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
    if size is not None:
        cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
    return cated.clamp(0.0, 1.0)


def _show_predictions(
    preds: List[Dict[str, Any]],
    sequence_name: str,
    viz: Visdom,
    viz_env: str = "visualizer",
    predicted_keys: Sequence[str] = (
        "images_render",
        "masks_render",
        "depths_render",
        "_all_source_images",
    ),
    n_samples=10,
    one_image_width=200,
):
    """Given a list of predictions visualize them into a single image using visdom."""
    assert isinstance(preds, list)

    pred_all = []
    # Randomly choose a subset of the rendered images, sort by order in the sequence
    n_samples = min(n_samples, len(preds))
    pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
    for predi in pred_idx:
        # Make the concatenation for the same camera vertically
        pred_all.append(
            torch.cat(
                [
                    torch.nn.functional.interpolate(
                        preds[predi][k].cpu(),
                        scale_factor=one_image_width / preds[predi][k].shape[3],
                        mode="bilinear",
                    ).clamp(0.0, 1.0)
                    for k in predicted_keys
                ],
                dim=2,
            )
        )
    # Concatenate the images horizontally
    pred_all_cat = torch.cat(pred_all, dim=3)[0]
    viz.image(
        pred_all_cat,
        win="show_predictions",
        env=viz_env,
        opts={"title": f"pred_{sequence_name}"},
    )


def _generate_prediction_videos(
    preds: List[Dict[str, Any]],
    sequence_name: str,
    viz: Optional[Visdom] = None,
    viz_env: str = "visualizer",
    predicted_keys: Sequence[str] = (
        "images_render",
        "masks_render",
        "depths_render",
        "_all_source_images",
    ),
    fps: int = 20,
    video_path: str = "/tmp/video",
    video_frames_dir: Optional[str] = None,
    resize: Optional[Tuple[int, int]] = None,
):
    """Given a list of predictions create and visualize rotating videos of the
    objects using visdom.
    """

    # make sure the target video directory exists
    os.makedirs(os.path.dirname(video_path), exist_ok=True)

    # init a video writer for each predicted key
    vws = {}
    for k in predicted_keys:
        vws[k] = VideoWriter(
            fps=fps,
            out_path=f"{video_path}_{sequence_name}_{k}.mp4",
            cache_dir=os.path.join(video_frames_dir, f"{sequence_name}_{k}"),
        )

    for rendered_pred in tqdm(preds):
        for k in predicted_keys:
            vws[k].write_frame(
                rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(),
                resize=resize,
            )

    for k in predicted_keys:
        vws[k].get_video(quiet=True)
        logger.info(f"Generated {vws[k].out_path}.")
        if viz is not None:
            viz.video(
                videofile=vws[k].out_path,
                env=viz_env,
                win=k,  # we reuse the same window otherwise visdom dies
                opts={"title": sequence_name + " " + k},
            )
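The forward call above, model(**net_input, evaluation_mode=EvaluationMode.EVALUATION), pins down the contract a model must satisfy to be driven by render_flyaround: it receives the collated FrameData fields as keyword arguments and returns a dict containing the keys named in visualize_preds_keys. A minimal hypothetical stand-in (compare _PointcloudRenderingModel in the new test further below):

    import torch

    class _ToyModel(torch.nn.Module):  # hypothetical stand-in model
        def forward(self, camera, **kwargs):
            # render only the first (target) camera, as the real models do
            h = w = 128
            return {
                "images_render": torch.zeros(1, 3, h, w),
                "masks_render": torch.ones(1, 1, h, w),
                "depths_render": torch.zeros(1, 1, h, w),
            }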
pytorch3d/implicitron/tools/eval_video_trajectory.py
@@ -37,7 +37,7 @@ def generate_eval_video_cameras(
    Generate a camera trajectory rendering a scene from multiple viewpoints.

    Args:
        train_dataset: The training dataset object.
        train_cameras: The set of cameras from the training dataset object.
        n_eval_cams: Number of cameras in the trajectory.
        trajectory_type: The type of the camera trajectory. Can be one of:
            circular_lsq_fit: Camera centers follow a trajectory obtained
@@ -51,16 +51,30 @@ def generate_eval_video_cameras(
                of a figure-eight knot
                (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
        trajectory_scale: The extent of the trajectory.
        up: The "up" vector of the scene (=the normal of the scene floor).
            Active for the `trajectory_type="circular"`.
        scene_center: The center of the scene in world coordinates which all
            the cameras from the generated trajectory look at.
        up: The "up" vector of the scene (=the normal of the scene floor).
            Active for the `trajectory_type="circular"`.
        focal_length: The focal length of the output cameras. If `None`, an average
            focal length of the train_cameras is used.
        principal_point: The principal point of the output cameras. If `None`, an average
            principal point of all train_cameras is used.
        time: Defines the total length of the generated camera trajectory. All possible
            trajectories (set with the `trajectory_type` argument) are periodic with
            the period of `time=2pi`.
            E.g. setting `trajectory_type=circular_lsq_fit` and `time=4pi`, will generate
            a trajectory of camera poses rotating the total of 720 deg around the object.
        infer_up_as_plane_normal: Infer the camera `up` vector automatically as the normal
            of the plane fit to the optical centers of `train_cameras`.
        traj_offset: 3D offset vector added to each point of the trajectory.
        traj_offset_canonical: 3D offset vector expressed in the local coordinates of
            the estimated trajectory which is added to each point of the trajectory.
        remove_outliers_rate: the number between 0 and 1; if > 0,
            some outlier train_cameras will be removed from trajectory estimation;
            the filtering is based on camera center coordinates; top and
            bottom `remove_outliers_rate` cameras on each dimension are removed.
    Returns:
        Dictionary of camera instances which can be used as the test dataset
        Batch of camera instances which can be used as the test dataset
    """
    if remove_outliers_rate > 0.0:
        train_cameras = _remove_outlier_cameras(train_cameras, remove_outliers_rate)

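A hedged sketch of calling generate_eval_video_cameras directly, using the same arguments render_flyaround passes; train_cameras is a placeholder for the cameras of a training sequence:

    import math
    import torch
    from pytorch3d.implicitron.tools.eval_video_trajectory import (
        generate_eval_video_cameras,
    )

    n_poses = 40
    test_cameras = generate_eval_video_cameras(
        train_cameras,  # placeholder: cameras of the training frames
        time=torch.linspace(0, 2 * math.pi, n_poses + 1)[:n_poses],
        n_eval_cams=n_poses,
        trajectory_type="circular_lsq_fit",
        trajectory_scale=1.1,
        scene_center=(0.0, 0.0, 0.0),
        up=(0.0, -1.0, 0.0),
        focal_length=None,
        principal_point=torch.zeros(n_poses, 2),
        traj_offset_canonical=(0.0, 0.0, 0.0),
    )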
tests/implicitron/common_resources.py
@@ -68,7 +68,7 @@ def get_skateboard_data(
    if not os.environ.get("FB_TEST", False):
        if os.getenv("FAIR_ENV_CLUSTER", "") == "":
            raise unittest.SkipTest("Unknown environment. Data not available.")
        yield "/checkpoint/dnovotny/datasets/co3d/download_aws_22_02_18", PathManager()
        yield "/datasets01/co3d/081922", PathManager()

    elif avoid_manifold or os.environ.get("INSIDE_RE_WORKER", False):
        from libfb.py.parutil import get_file_path

							
								
								
									
154  tests/implicitron/test_model_visualize.py  Normal file
@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import math
import os
import unittest
from typing import Tuple

import torch
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud

from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
from pytorch3d.renderer.cameras import CamerasBase
from tests.common_testing import interactive_testing_requested
from visdom import Visdom

from .common_resources import get_skateboard_data


class TestModelVisualize(unittest.TestCase):
    def test_flyaround_one_sequence(
        self,
        image_size: int = 256,
    ):
        if not interactive_testing_requested():
            return
        category = "skateboard"
        stack = contextlib.ExitStack()
        dataset_root, path_manager = stack.enter_context(get_skateboard_data())
        self.addCleanup(stack.close)
        frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
        sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
        subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
        expand_args_fields(JsonIndexDataset)
        train_dataset = JsonIndexDataset(
            frame_annotations_file=frame_file,
            sequence_annotations_file=sequence_file,
            subset_lists_file=subset_lists_file,
            dataset_root=dataset_root,
            image_height=image_size,
            image_width=image_size,
            box_crop=True,
            load_point_clouds=True,
            path_manager=path_manager,
            subsets=[
                "train_known",
            ],
        )

        # select few sequences to visualize
        sequence_names = list(train_dataset.seq_annots.keys())

        # select the first sequence name
        show_sequence_name = sequence_names[0]

        output_dir = os.path.split(os.path.abspath(__file__))[0]

        visdom_show_preds = Visdom().check_connection()

        for load_dataset_pointcloud in [True, False]:

            model = _PointcloudRenderingModel(
                train_dataset,
                show_sequence_name,
                device="cuda:0",
                load_dataset_pointcloud=load_dataset_pointcloud,
            )

            video_path = os.path.join(
                output_dir,
                f"load_pcl_{load_dataset_pointcloud}",
            )

            os.makedirs(output_dir, exist_ok=True)

            render_flyaround(
                train_dataset,
                show_sequence_name,
                model,
                video_path,
                n_flyaround_poses=40,
                fps=20,
                max_angle=2 * math.pi,
                trajectory_type="circular_lsq_fit",
                trajectory_scale=1.1,
                scene_center=(0.0, 0.0, 0.0),
                up=(0.0, 1.0, 0.0),
                traj_offset=1.0,
                n_source_views=1,
                visdom_show_preds=visdom_show_preds,
                visdom_environment="test_model_visualize",
                visdom_server="http://127.0.0.1",
                visdom_port=8097,
                num_workers=10,
                seed=None,
                video_resize=None,
                visualize_preds_keys=[
                    "images_render",
                    "depths_render",
                    "masks_render",
                    "_all_source_images",
                ],
                output_video_frames_dir=video_path,
            )


class _PointcloudRenderingModel(torch.nn.Module):
    def __init__(
        self,
        train_dataset: JsonIndexDataset,
        sequence_name: str,
        render_size: Tuple[int, int] = (400, 400),
        device=None,
        load_dataset_pointcloud: bool = False,
        max_frames: int = 30,
        num_workers: int = 10,
    ):
        super().__init__()
        self._render_size = render_size
        point_cloud, _ = get_implicitron_sequence_pointcloud(
            train_dataset,
            sequence_name=sequence_name,
            mask_points=True,
            max_frames=max_frames,
            num_workers=num_workers,
            load_dataset_point_cloud=load_dataset_pointcloud,
        )
        self._point_cloud = point_cloud.to(device)

    def forward(
        self,
        camera: CamerasBase,
        **kwargs,
    ):
        image_render, mask_render, depth_render = render_point_cloud_pytorch3d(
            camera[0],
            self._point_cloud,
            render_size=self._render_size,
            point_radius=1e-2,
            topk=10,
            bg_color=0.0,
        )
        return {
            "images_render": image_render.clamp(0.0, 1.0),
            "masks_render": mask_render,
            "depths_render": depth_render,
        }