Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-11-04 18:02:14 +08:00)
visualize_reconstruction fixes

Summary: Various fixes to get visualize_reconstruction running, and an interactive test for it.

Reviewed By: kjchalup

Differential Revision: D39286691

fbshipit-source-id: 88735034cc01736b24735bcb024577e6ab7ed336

parent 34ad77b841
commit 6e25fe8cb3
@@ -66,7 +66,7 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace
 To run training, pass a yaml config file, followed by a list of overridden arguments.
 For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
 ```shell
-dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
+dataset_args=data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
 pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf \
     $dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' \
     $dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
@@ -87,7 +87,7 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad
 
 E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
 ```shell
-dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
+dataset_args=data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
 pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf \
     $dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' \
     $dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
@@ -13,16 +13,7 @@ from hydra import compose, initialize_config_dir
 from omegaconf import OmegaConf
 
 from .. import experiment
-from .utils import intercept_logs
-
-
-def interactive_testing_requested() -> bool:
-    """
-    Certain tests are only useful when run interactively, and so are not regularly run.
-    These are activated by this funciton returning True, which the user requests by
-    setting the environment variable `PYTORCH3D_INTERACTIVE_TESTING` to 1.
-    """
-    return os.environ.get("PYTORCH3D_INTERACTIVE_TESTING", "") == "1"
+from .utils import interactive_testing_requested, intercept_logs
 
 
 internal = os.environ.get("FB_TEST", False)
projects/implicitron_trainer/tests/test_visualize.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import unittest
+
+from .. import visualize_reconstruction
+from .utils import interactive_testing_requested
+
+internal = os.environ.get("FB_TEST", False)
+
+
+class TestVisualize(unittest.TestCase):
+    def test_from_defaults(self):
+        if not interactive_testing_requested():
+            return
+        checkpoint_dir = os.environ["exp_dir"]
+        argv = [
+            f"exp_dir={checkpoint_dir}",
+            "n_eval_cameras=40",
+            "render_size=[64,64]",
+            "video_size=[256,256]",
+        ]
+        visualize_reconstruction.main(argv)
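The new test is skipped unless interactive testing is opted into. A minimal sketch of driving it from Python follows; the dotted module path and the unittest invocation are assumptions, not part of this commit, and `<CHECKPOINT_DIR>` stands for a trained Implicitron experiment directory:

```python
# Sketch only: opt in to interactive tests, then load and run the new test module.
import os
import unittest

os.environ["PYTORCH3D_INTERACTIVE_TESTING"] = "1"  # checked by interactive_testing_requested()
os.environ["exp_dir"] = "<CHECKPOINT_DIR>"         # read by TestVisualize.test_from_defaults

suite = unittest.defaultTestLoader.loadTestsFromName(
    "projects.implicitron_trainer.tests.test_visualize"  # assumed module path
)
unittest.TextTestRunner().run(suite)
```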
@@ -6,6 +6,7 @@
 
 import contextlib
 import logging
+import os
 import re
 
 
@@ -28,3 +29,12 @@ def intercept_logs(logger_name: str, regexp: str):
         yield intercepted_messages
     finally:
         logger.removeFilter(interceptor)
+
+
+def interactive_testing_requested() -> bool:
+    """
+    Certain tests are only useful when run interactively, and so are not regularly run.
+    These are activated by this funciton returning True, which the user requests by
+    setting the environment variable `PYTORCH3D_INTERACTIVE_TESTING` to 1.
+    """
+    return os.environ.get("PYTORCH3D_INTERACTIVE_TESTING", "") == "1"
@@ -5,10 +5,11 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-"""Script to visualize a previously trained model. Example call:
-    projects/implicitron_trainer/visualize_reconstruction.py
-    exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097
+"""
+Script to visualize a previously trained model. Example call:
+
+    pytorch3d_implicitron_visualizer \
+    exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097 \
     n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
 """
 
@@ -18,9 +19,9 @@ from typing import Optional, Tuple
 
 import numpy as np
 import torch
-from omegaconf import OmegaConf
-from pytorch3d.implicitron.models.visualization import render_flyaround
-from pytorch3d.implicitron.tools.configurable import get_default_args
+from omegaconf import DictConfig, OmegaConf
+from pytorch3d.implicitron.models.visualization.render_flyaround import render_flyaround
+from pytorch3d.implicitron.tools.config import enable_get_default_args, get_default_args
 
 from .experiment import Experiment
 
@@ -38,7 +39,7 @@ def visualize_reconstruction(
     visdom_server: str = "http://127.0.0.1",
     visdom_port: int = 8097,
     visdom_env: Optional[str] = None,
-):
+) -> None:
     """
     Given an `exp_dir` containing a trained Implicitron model, generates videos consisting
     of renderes of sequences from the dataset used to train and evaluate the trained
@@ -76,22 +77,27 @@ def visualize_reconstruction(
     config = _get_config_from_experiment_directory(exp_dir)
     config.exp_dir = exp_dir
     # important so that the CO3D dataset gets loaded in full
-    dataset_args = (
-        config.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
-    )
-    dataset_args.test_on_train = False
+    data_source_args = config.data_source_ImplicitronDataSource_args
+    if "dataset_map_provider_JsonIndexDatasetMapProvider_args" in data_source_args:
+        dataset_args = (
+            data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
+        )
+        dataset_args.test_on_train = False
+        if restrict_sequence_name is not None:
+            dataset_args.restrict_sequence_name = restrict_sequence_name
+
     # Set the rendering image size
     model_factory_args = config.model_factory_ImplicitronModelFactory_args
+    model_factory_args.force_resume = True
     model_args = model_factory_args.model_GenericModel_args
     model_args.render_image_width = render_size[0]
     model_args.render_image_height = render_size[1]
-    if restrict_sequence_name is not None:
-        dataset_args.restrict_sequence_name = restrict_sequence_name
 
     # Load the previously trained model
-    experiment = Experiment(config)
-    model = experiment.model_factory(force_resume=True)
-    model.cuda()
+    experiment = Experiment(**config)
+    model = experiment.model_factory(exp_dir=exp_dir)
+    device = torch.device("cuda")
+    model.to(device)
     model.eval()
 
     # Setup the dataset
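A small illustration (not part of the commit) of why the membership check matters: with OmegaConf, reading a missing provider key off the loaded config would fail, while the `in` test lets experiments that used a different dataset provider pass through untouched. The config contents below are made up:

```python
# Illustrative sketch only; the keys mirror the diff, the values are placeholders.
from omegaconf import OmegaConf

config = OmegaConf.create(
    {"data_source_ImplicitronDataSource_args": {"dataset_map_provider_class_type": "Other"}}
)
data_source_args = config.data_source_ImplicitronDataSource_args
if "dataset_map_provider_JsonIndexDatasetMapProvider_args" in data_source_args:
    # CO3D-style provider: apply the test_on_train / restrict_sequence_name overrides
    pass
else:
    print("Non-JsonIndex provider: leaving dataset args untouched")
```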
@@ -101,6 +107,11 @@ def visualize_reconstruction(
     if dataset is None:
         raise ValueError(f"{split} dataset not provided")
 
+    if visdom_env is None:
+        visdom_env = (
+            "visualizer_" + config.training_loop_ImplicitronTrainingLoop_args.visdom_env
+        )
+
     # iterate over the sequences in the dataset
     for sequence_name in dataset.sequence_names():
         with torch.no_grad():
@@ -114,23 +125,26 @@ def visualize_reconstruction(
                 n_flyaround_poses=n_eval_cameras,
                 visdom_server=visdom_server,
                 visdom_port=visdom_port,
-                visdom_environment=f"visualizer_{config.visdom_env}"
-                if visdom_env is None
-                else visdom_env,
+                visdom_environment=visdom_env,
                 video_resize=video_size,
+                device=device,
             )
 
 
-def _get_config_from_experiment_directory(experiment_directory):
+enable_get_default_args(visualize_reconstruction)
+
+
+def _get_config_from_experiment_directory(experiment_directory) -> DictConfig:
     cfg_file = os.path.join(experiment_directory, "expconfig.yaml")
     config = OmegaConf.load(cfg_file)
+    # pyre-ignore[7]
     return config
 
 
-def main(argv):
+def main(argv) -> None:
     # automatically parses arguments of visualize_reconstruction
     cfg = OmegaConf.create(get_default_args(visualize_reconstruction))
-    cfg.update(OmegaConf.from_cli())
+    cfg.update(OmegaConf.from_cli(argv))
     with torch.no_grad():
         visualize_reconstruction(**cfg)
 
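The switch from `OmegaConf.from_cli()` to `OmegaConf.from_cli(argv)` is what lets the new test call `main` with an explicit argument list instead of relying on `sys.argv`. A short sketch of that behaviour, with placeholder values:

```python
# Sketch: from_cli(argv) parses an explicit dotlist instead of reading sys.argv.
from omegaconf import OmegaConf

argv = ["exp_dir=./exps/checkpoint_dir", "n_eval_cameras=40", "render_size=[64,64]"]
overrides = OmegaConf.from_cli(argv)

print(overrides.exp_dir)            # ./exps/checkpoint_dir
print(overrides.n_eval_cameras)     # 40 (parsed as an int)
print(list(overrides.render_size))  # [64, 64]
```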
@@ -9,7 +9,7 @@
 # provide data for a single scene.
 
 from dataclasses import field
-from typing import Iterable, List, Optional
+from typing import Iterable, Iterator, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -46,6 +46,12 @@ class SingleSceneDataset(DatasetBase, Configurable):
     def __len__(self) -> int:
         return len(self.poses)
 
+    def sequence_frames_in_order(
+        self, seq_name: str
+    ) -> Iterator[Tuple[float, int, int]]:
+        for i in range(len(self)):
+            yield (0.0, i, i)
+
     def __getitem__(self, index) -> FrameData:
         if index >= len(self):
             raise IndexError(f"index {index} out of range {len(self)}")
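For context (an illustration, not from the commit): the new override simply enumerates every frame of the single scene with a dummy timestamp, yielding triples matching its `Iterator[Tuple[float, int, int]]` annotation, so sequence-order helpers see indices 0 through len(dataset)-1:

```python
# Standalone sketch of what the override yields for a 3-frame single-scene dataset.
def sequence_frames_in_order(num_frames):
    for i in range(num_frames):
        yield (0.0, i, i)

print(list(sequence_frames_in_order(3)))
# [(0.0, 0, 0), (0.0, 1, 1), (0.0, 2, 2)]
```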
@@ -61,7 +61,7 @@ def render_flyaround(
         "depths_render",
         "_all_source_images",
     ),
-):
+) -> None:
     """
     Uses `model` to generate a video consisting of renders of a scene imaged from
     a camera flying around the scene. The scene is specified with the `dataset` object and
@@ -133,6 +133,7 @@ def render_flyaround(
     seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
     train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
     assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
+    # pyre-ignore[6]
     sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
     logger.info(f"Sequence set = {sequence_set_name}.")
     train_cameras = train_data.camera
@@ -209,7 +210,7 @@ def render_flyaround(
 
 def _load_whole_dataset(
     dataset: torch.utils.data.Dataset, idx: Sequence[int], num_workers: int = 10
-):
+) -> FrameData:
     load_all_dataloader = torch.utils.data.DataLoader(
         torch.utils.data.Subset(dataset, idx),
         batch_size=len(idx),
@@ -220,7 +221,7 @@ def _load_whole_dataset(
     return next(iter(load_all_dataloader))
 
 
-def _images_from_preds(preds: Dict[str, Any]):
+def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]:
     imout = {}
     for k in (
         "image_rgb",
@@ -253,7 +254,7 @@ def _images_from_preds(preds: Dict[str, Any]):
     return imout
 
 
-def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]):
+def _stack_images(ims: torch.Tensor, size: Optional[Tuple[int, int]]) -> torch.Tensor:
     ba = ims.shape[0]
     H = int(np.ceil(np.sqrt(ba)))
     W = H
@@ -281,7 +282,7 @@ def _show_predictions(
     ),
     n_samples=10,
     one_image_width=200,
-):
+) -> None:
     """Given a list of predictions visualize them into a single image using visdom."""
     assert isinstance(preds, list)
 
@@ -329,7 +330,7 @@ def _generate_prediction_videos(
     video_path: str = "/tmp/video",
     video_frames_dir: Optional[str] = None,
     resize: Optional[Tuple[int, int]] = None,
-):
+) -> None:
     """Given a list of predictions create and visualize rotating videos of the
     objects using visdom.
     """
@@ -359,7 +360,7 @@ def _generate_prediction_videos(
             )
 
     for k in predicted_keys:
-        vws[k].get_video(quiet=True)
+        vws[k].get_video()
        logger.info(f"Generated {vws[k].out_path}.")
         if viz is not None:
             viz.video(
@@ -6,6 +6,7 @@
 
 import os
 import shutil
+import subprocess
 import tempfile
 import warnings
 from typing import Optional, Tuple, Union
@@ -15,6 +16,7 @@ import matplotlib.pyplot as plt
 import numpy as np
 from PIL import Image
 
+_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
+
 matplotlib.use("Agg")
 
@@ -27,13 +29,13 @@ class VideoWriter:
     def __init__(
         self,
         cache_dir: Optional[str] = None,
-        ffmpeg_bin: str = "ffmpeg",
+        ffmpeg_bin: str = _DEFAULT_FFMPEG,
         out_path: str = "/tmp/video.mp4",
         fps: int = 20,
         output_format: str = "visdom",
         rmdir_allowed: bool = False,
         **kwargs,
-    ):
+    ) -> None:
         """
         Args:
             cache_dir: A directory for storing the video frames. If `None`,
@@ -74,7 +76,7 @@ class VideoWriter:
         self,
         frame: Union[matplotlib.figure.Figure, np.ndarray, Image.Image, str],
         resize: Optional[Union[float, Tuple[int, int]]] = None,
-    ):
+    ) -> None:
         """
         Write a frame to the video.
 
@@ -114,7 +116,7 @@ class VideoWriter:
         self.frames.append(outfile)
         self.frame_num += 1
 
-    def get_video(self, quiet: bool = True):
+    def get_video(self) -> str:
         """
         Generate the video from the written frames.
 
@@ -127,23 +129,39 @@ class VideoWriter:
 
         regexp = os.path.join(self.cache_dir, self.regexp)
 
-        if self.output_format == "visdom":  # works for ppt too
-            ffmcmd_ = (
-                "%s -r %d -i %s -vcodec h264 -f mp4 \
-                       -y -crf 18 -b 2000k -pix_fmt yuv420p '%s'"
-                % (self.ffmpeg_bin, self.fps, regexp, self.out_path)
+        if shutil.which(self.ffmpeg_bin) is None:
+            raise ValueError(
+                f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
+                + "Please set FFMPEG in the environment or ffmpeg_bin on this class."
             )
+
+        if self.output_format == "visdom":  # works for ppt too
+            args = [
+                self.ffmpeg_bin,
+                "-r",
+                str(self.fps),
+                "-i",
+                regexp,
+                "-vcodec",
+                "h264",
+                "-f",
+                "mp4",
+                "-y",
+                "-crf",
+                "18",
+                "-b",
+                "2000k",
+                "-pix_fmt",
+                "yuv420p",
+                self.out_path,
+            ]
+
+            subprocess.check_call(args)
         else:
             raise ValueError("no such output type %s" % str(self.output_format))
 
-        if quiet:
-            ffmcmd_ += " > /dev/null 2>&1"
-        else:
-            print(ffmcmd_)
-        os.system(ffmcmd_)
-
         return self.out_path
 
-    def __del__(self):
+    def __del__(self) -> None:
         if self.tmp_dir is not None:
             self.tmp_dir.cleanup()
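A usage sketch of the reworked `VideoWriter` follows. The import path and the `write_frame` method name are assumptions (neither appears in this diff); it illustrates the two user-visible changes: the ffmpeg binary now defaults to the `FFMPEG` environment variable, and `get_video()` no longer takes a `quiet` flag because the command runs through `subprocess.check_call`, which raises on failure instead of failing silently.

```python
# Hedged sketch; import path and write_frame are assumed, not shown in this diff.
import numpy as np

from pytorch3d.implicitron.tools.video_writer import VideoWriter

# ffmpeg_bin defaults to os.environ.get("FFMPEG", "ffmpeg"), read at import time,
# so running with `FFMPEG=/path/to/ffmpeg` swaps the binary without code changes.
writer = VideoWriter(out_path="/tmp/demo_video.mp4", fps=20)
for _ in range(10):
    writer.write_frame(np.random.rand(64, 64, 3))  # any HxWx3 frame
print(writer.get_video())  # raises via subprocess.check_call if ffmpeg fails
```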