mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	get_all_train_cameras
Summary: As part of removing Task, make the dataset code generate the source cameras for itself. There's a small optimization available here, in that the JsonIndexDataset could avoid loading images. Reviewed By: shapovalov Differential Revision: D37313423 fbshipit-source-id: 3e5e0b2aabbf9cc51f10547a3523e98c72ad8755
This commit is contained in:
		
							parent
							
								
									771cf8a328
								
							
						
					
					
						commit
						4e87c2b7f1
					
				@ -66,9 +66,7 @@ from packaging import version
 | 
			
		||||
from pytorch3d.implicitron.dataset import utils as ds_utils
 | 
			
		||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
 | 
			
		||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
 | 
			
		||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
 | 
			
		||||
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
 | 
			
		||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 | 
			
		||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
 | 
			
		||||
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
 | 
			
		||||
from pytorch3d.implicitron.tools import model_io, vis_utils
 | 
			
		||||
@ -456,6 +454,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
 | 
			
		||||
    datasource = ImplicitronDataSource(**cfg.data_source_args)
 | 
			
		||||
    datasets, dataloaders = datasource.get_datasets_and_dataloaders()
 | 
			
		||||
    task = datasource.get_task()
 | 
			
		||||
    all_train_cameras = datasource.get_all_train_cameras()
 | 
			
		||||
 | 
			
		||||
    # init the model
 | 
			
		||||
    model, stats, optimizer_state = init_model(cfg)
 | 
			
		||||
@ -466,7 +465,16 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
 | 
			
		||||
 | 
			
		||||
    # only run evaluation on the test dataloader
 | 
			
		||||
    if cfg.eval_only:
 | 
			
		||||
        _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
 | 
			
		||||
        _eval_and_dump(
 | 
			
		||||
            cfg,
 | 
			
		||||
            task,
 | 
			
		||||
            all_train_cameras,
 | 
			
		||||
            datasets,
 | 
			
		||||
            dataloaders,
 | 
			
		||||
            model,
 | 
			
		||||
            stats,
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    # init the optimizer
 | 
			
		||||
@ -528,7 +536,9 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
 | 
			
		||||
                and cfg.test_interval > 0
 | 
			
		||||
                and epoch % cfg.test_interval == 0
 | 
			
		||||
            ):
 | 
			
		||||
                _run_eval(model, stats, dataloaders.test, task, device=device)
 | 
			
		||||
                _run_eval(
 | 
			
		||||
                    model, all_train_cameras, dataloaders.test, task, device=device
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            assert stats.epoch == epoch, "inconsistent stats!"
 | 
			
		||||
 | 
			
		||||
@ -548,12 +558,22 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
 | 
			
		||||
                logger.info(f"LR change! {cur_lr} -> {new_lr}")
 | 
			
		||||
 | 
			
		||||
    if cfg.test_when_finished:
 | 
			
		||||
        _eval_and_dump(cfg, task, datasets, dataloaders, model, stats, device=device)
 | 
			
		||||
        _eval_and_dump(
 | 
			
		||||
            cfg,
 | 
			
		||||
            task,
 | 
			
		||||
            all_train_cameras,
 | 
			
		||||
            datasets,
 | 
			
		||||
            dataloaders,
 | 
			
		||||
            model,
 | 
			
		||||
            stats,
 | 
			
		||||
            device=device,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _eval_and_dump(
 | 
			
		||||
    cfg,
 | 
			
		||||
    task: Task,
 | 
			
		||||
    all_train_cameras: Optional[CamerasBase],
 | 
			
		||||
    datasets: DatasetMap,
 | 
			
		||||
    dataloaders: DataLoaderMap,
 | 
			
		||||
    model,
 | 
			
		||||
@ -570,13 +590,7 @@ def _eval_and_dump(
 | 
			
		||||
    if dataloader is None:
 | 
			
		||||
        raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
 | 
			
		||||
 | 
			
		||||
    if task == Task.SINGLE_SEQUENCE:
 | 
			
		||||
        if datasets.train is None:
 | 
			
		||||
            raise ValueError("train dataset must be provided")
 | 
			
		||||
        all_source_cameras = _get_all_source_cameras(datasets.train)
 | 
			
		||||
    else:
 | 
			
		||||
        all_source_cameras = None
 | 
			
		||||
    results = _run_eval(model, all_source_cameras, dataloader, task, device=device)
 | 
			
		||||
    results = _run_eval(model, all_train_cameras, dataloader, task, device=device)
 | 
			
		||||
 | 
			
		||||
    # add the evaluation epoch to the results
 | 
			
		||||
    for r in results:
 | 
			
		||||
@ -603,7 +617,7 @@ def _get_eval_frame_data(frame_data):
 | 
			
		||||
    return frame_data_for_eval
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _run_eval(model, all_source_cameras, loader, task: Task, device):
 | 
			
		||||
def _run_eval(model, all_train_cameras, loader, task: Task, device):
 | 
			
		||||
    """
 | 
			
		||||
    Run the evaluation loop on the test dataloader
 | 
			
		||||
    """
 | 
			
		||||
@ -631,7 +645,7 @@ def _run_eval(model, all_source_cameras, loader, task: Task, device):
 | 
			
		||||
                    implicitron_render,
 | 
			
		||||
                    bg_color="black",
 | 
			
		||||
                    lpips_model=lpips_model,
 | 
			
		||||
                    source_cameras=all_source_cameras,
 | 
			
		||||
                    source_cameras=all_train_cameras,
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
@ -642,31 +656,6 @@ def _run_eval(model, all_source_cameras, loader, task: Task, device):
 | 
			
		||||
    return category_result["results"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_all_source_cameras(
 | 
			
		||||
    dataset: JsonIndexDataset,
 | 
			
		||||
    num_workers: int = 8,
 | 
			
		||||
) -> CamerasBase:
 | 
			
		||||
    """
 | 
			
		||||
    Load and return all the source cameras in the training dataset
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    all_frame_data = next(
 | 
			
		||||
        iter(
 | 
			
		||||
            torch.utils.data.DataLoader(
 | 
			
		||||
                dataset,
 | 
			
		||||
                shuffle=False,
 | 
			
		||||
                batch_size=len(dataset),
 | 
			
		||||
                num_workers=num_workers,
 | 
			
		||||
                collate_fn=FrameData.collate,
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    is_source = ds_utils.is_known_frame(all_frame_data.frame_type)
 | 
			
		||||
    source_cameras = all_frame_data.camera[torch.where(is_source)[0]]
 | 
			
		||||
    return source_cameras
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _seed_all_random_engines(seed: int):
 | 
			
		||||
    np.random.seed(seed)
 | 
			
		||||
    torch.manual_seed(seed)
 | 
			
		||||
 | 
			
		||||
@ -4,13 +4,14 @@
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
from typing import Tuple
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
    registry,
 | 
			
		||||
    ReplaceableBase,
 | 
			
		||||
    run_auto_creation,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase
 | 
			
		||||
 | 
			
		||||
from .blender_dataset_map_provider import BlenderDatasetMapProvider  # noqa
 | 
			
		||||
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
 | 
			
		||||
@ -28,6 +29,15 @@ class DataSourceBase(ReplaceableBase):
 | 
			
		||||
    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def get_all_train_cameras(self) -> Optional[CamerasBase]:
 | 
			
		||||
        """
 | 
			
		||||
        If the data is all for a single scene, returns a list
 | 
			
		||||
        of the known training cameras for that scene, which is
 | 
			
		||||
        used for evaluating the viewpoint difficulty of the
 | 
			
		||||
        unseen cameras.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
 | 
			
		||||
@ -56,3 +66,6 @@ class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
 | 
			
		||||
 | 
			
		||||
    def get_task(self) -> Task:
 | 
			
		||||
        return self.dataset_map_provider.get_task()
 | 
			
		||||
 | 
			
		||||
    def get_all_train_cameras(self) -> Optional[CamerasBase]:
 | 
			
		||||
        return self.dataset_map_provider.get_all_train_cameras()
 | 
			
		||||
 | 
			
		||||
@ -224,7 +224,7 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
 | 
			
		||||
    # _seq_to_idx: Dict[str, List[int]] = field(init=False)
 | 
			
		||||
 | 
			
		||||
    def __len__(self) -> int:
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def get_frame_numbers_and_timestamps(
 | 
			
		||||
        self, idxs: Sequence[int]
 | 
			
		||||
 | 
			
		||||
@ -12,6 +12,7 @@ from typing import Iterator, Optional
 | 
			
		||||
 | 
			
		||||
from iopath.common.file_io import PathManager
 | 
			
		||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase
 | 
			
		||||
 | 
			
		||||
from .dataset_base import DatasetBase
 | 
			
		||||
 | 
			
		||||
@ -73,6 +74,15 @@ class DatasetMapProviderBase(ReplaceableBase):
 | 
			
		||||
    def get_task(self) -> Task:
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def get_all_train_cameras(self) -> Optional[CamerasBase]:
 | 
			
		||||
        """
 | 
			
		||||
        If the data is all for a single scene, returns a list
 | 
			
		||||
        of the known training cameras for that scene, which is
 | 
			
		||||
        used for evaluating the difficulty of the unknown
 | 
			
		||||
        cameras. Otherwise return None.
 | 
			
		||||
        """
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class PathManagerFactory(ReplaceableBase):
 | 
			
		||||
 | 
			
		||||
@ -32,11 +32,13 @@ import torch
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 | 
			
		||||
from pytorch3d.io import IO
 | 
			
		||||
from pytorch3d.renderer.cameras import PerspectiveCameras
 | 
			
		||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
 | 
			
		||||
from pytorch3d.structures.pointclouds import Pointclouds
 | 
			
		||||
 | 
			
		||||
from . import types
 | 
			
		||||
from .dataset_base import DatasetBase, FrameData
 | 
			
		||||
from .utils import is_known_frame_scalar
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
@ -206,6 +208,20 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
    def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
 | 
			
		||||
        return entry["subset"]
 | 
			
		||||
 | 
			
		||||
    def get_all_train_cameras(self) -> CamerasBase:
 | 
			
		||||
        """
 | 
			
		||||
        Returns the cameras corresponding to all the known frames.
 | 
			
		||||
        """
 | 
			
		||||
        cameras = []
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        for frame_idx, frame_annot in enumerate(self.frame_annots):
 | 
			
		||||
            frame_type = self._get_frame_type(frame_annot)
 | 
			
		||||
            if frame_type is None:
 | 
			
		||||
                raise ValueError("subsets not loaded")
 | 
			
		||||
            if is_known_frame_scalar(frame_type):
 | 
			
		||||
                cameras.append(self[frame_idx].camera)
 | 
			
		||||
        return join_cameras_as_batch(cameras)
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, index) -> FrameData:
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        if index >= len(self.frame_annots):
 | 
			
		||||
 | 
			
		||||
@ -7,7 +7,7 @@
 | 
			
		||||
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
from typing import Dict, List, Tuple, Type
 | 
			
		||||
from typing import Dict, List, Optional, Tuple, Type
 | 
			
		||||
 | 
			
		||||
from omegaconf import DictConfig, open_dict
 | 
			
		||||
from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
@ -15,6 +15,7 @@ from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
    registry,
 | 
			
		||||
    run_auto_creation,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase
 | 
			
		||||
 | 
			
		||||
from .dataset_map_provider import (
 | 
			
		||||
    DatasetMap,
 | 
			
		||||
@ -267,6 +268,15 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
			
		||||
    def get_task(self) -> Task:
 | 
			
		||||
        return Task(self.task_str)
 | 
			
		||||
 | 
			
		||||
    def get_all_train_cameras(self) -> Optional[CamerasBase]:
 | 
			
		||||
        if Task(self.task_str) == Task.MULTI_SEQUENCE:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        train_dataset = self.dataset_map.train
 | 
			
		||||
        assert isinstance(train_dataset, JsonIndexDataset)
 | 
			
		||||
        return train_dataset.get_all_train_cameras()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_co3d_set_names_mapping(
 | 
			
		||||
    task: Task,
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,7 @@ from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
    expand_args_fields,
 | 
			
		||||
    run_auto_creation,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.renderer import PerspectiveCameras
 | 
			
		||||
from pytorch3d.renderer import CamerasBase, join_cameras_as_batch, PerspectiveCameras
 | 
			
		||||
 | 
			
		||||
from .dataset_base import DatasetBase, FrameData
 | 
			
		||||
from .dataset_map_provider import (
 | 
			
		||||
@ -110,7 +110,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
 | 
			
		||||
        # - images: [N, 3, H, W] tensor of rgb images - floats in [0,1]
 | 
			
		||||
        # - fg_probabilities: None or [N, 1, H, W] of floats in [0,1]
 | 
			
		||||
        # - splits: List[List[int]] of indices for train/val/test subsets.
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
        raise NotImplementedError()
 | 
			
		||||
 | 
			
		||||
    def _get_dataset(
 | 
			
		||||
        self, split_idx: int, frame_type: str, set_eval_batches: bool = False
 | 
			
		||||
@ -162,6 +162,11 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
 | 
			
		||||
    def get_task(self) -> Task:
 | 
			
		||||
        return Task.SINGLE_SEQUENCE
 | 
			
		||||
 | 
			
		||||
    def get_all_train_cameras(self) -> Optional[CamerasBase]:
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        cameras = [self.poses[i] for i in self.i_split[0]]
 | 
			
		||||
        return join_cameras_as_batch(cameras)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _interpret_blender_cameras(
 | 
			
		||||
    poses: torch.Tensor, H: int, W: int, focal: float
 | 
			
		||||
 | 
			
		||||
@ -12,12 +12,10 @@ from typing import Any, cast, Dict, List, Optional, Tuple
 | 
			
		||||
import lpips
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
 | 
			
		||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
 | 
			
		||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 | 
			
		||||
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
 | 
			
		||||
    CO3D_CATEGORIES,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.dataset.utils import is_known_frame
 | 
			
		||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
 | 
			
		||||
    aggregate_nvs_results,
 | 
			
		||||
    eval_batch,
 | 
			
		||||
@ -120,16 +118,7 @@ def evaluate_dbir_for_category(
 | 
			
		||||
    if test_dataset is None or test_dataloader is None:
 | 
			
		||||
        raise ValueError("must have a test dataset.")
 | 
			
		||||
 | 
			
		||||
    if task == Task.SINGLE_SEQUENCE:
 | 
			
		||||
        # all_source_cameras are needed for evaluation of the
 | 
			
		||||
        # target camera difficulty
 | 
			
		||||
        # pyre-fixme[16]: `JsonIndexDataset` has no attribute `frame_annots`.
 | 
			
		||||
        sequence_name = test_dataset.frame_annots[0]["frame_annotation"].sequence_name
 | 
			
		||||
        all_source_cameras = _get_all_source_cameras(
 | 
			
		||||
            test_dataset, sequence_name, num_workers=num_workers
 | 
			
		||||
        )
 | 
			
		||||
    else:
 | 
			
		||||
        all_source_cameras = None
 | 
			
		||||
    all_train_cameras = data_source.get_all_train_cameras()
 | 
			
		||||
 | 
			
		||||
    image_size = cast(JsonIndexDataset, test_dataset).image_width
 | 
			
		||||
 | 
			
		||||
@ -160,7 +149,7 @@ def evaluate_dbir_for_category(
 | 
			
		||||
                preds["implicitron_render"],
 | 
			
		||||
                bg_color=bg_color,
 | 
			
		||||
                lpips_model=lpips_model,
 | 
			
		||||
                source_cameras=all_source_cameras,
 | 
			
		||||
                source_cameras=all_train_cameras,
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -184,35 +173,5 @@ def _print_aggregate_results(
 | 
			
		||||
    print("")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_all_source_cameras(
 | 
			
		||||
    dataset: DatasetBase, sequence_name: str, num_workers: int = 8
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Loads all training cameras of a given sequence.
 | 
			
		||||
 | 
			
		||||
    The set of all seen cameras is needed for evaluating the viewpoint difficulty
 | 
			
		||||
    for the singlescene evaluation.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        dataset: Co3D dataset object.
 | 
			
		||||
        sequence_name: The name of the sequence.
 | 
			
		||||
        num_workers: The number of for the utilized dataloader.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    # load all source cameras of the sequence
 | 
			
		||||
    seq_idx = list(dataset.sequence_indices_in_order(sequence_name))
 | 
			
		||||
    dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
 | 
			
		||||
    (all_frame_data,) = torch.utils.data.DataLoader(
 | 
			
		||||
        dataset_for_loader,
 | 
			
		||||
        shuffle=False,
 | 
			
		||||
        batch_size=len(dataset_for_loader),
 | 
			
		||||
        num_workers=num_workers,
 | 
			
		||||
        collate_fn=dataset.frame_data_type.collate,
 | 
			
		||||
    )
 | 
			
		||||
    is_known = is_known_frame(all_frame_data.frame_type)
 | 
			
		||||
    source_cameras = all_frame_data.camera[torch.where(is_known)[0]]
 | 
			
		||||
    return source_cameras
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    main()
 | 
			
		||||
 | 
			
		||||
@ -143,7 +143,7 @@ def eval_batch(
 | 
			
		||||
    visualize: bool = False,
 | 
			
		||||
    visualize_visdom_env: str = "eval_debug",
 | 
			
		||||
    break_after_visualising: bool = True,
 | 
			
		||||
    source_cameras: Optional[List[CamerasBase]] = None,
 | 
			
		||||
    source_cameras: Optional[CamerasBase] = None,
 | 
			
		||||
) -> Dict[str, Any]:
 | 
			
		||||
    """
 | 
			
		||||
    Produce performance metrics for a single batch of new-view synthesis
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										42
									
								
								tests/implicitron/test_data_json_index.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								tests/implicitron/test_data_json_index.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,42 @@
 | 
			
		||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
 | 
			
		||||
# All rights reserved.
 | 
			
		||||
#
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import os
 | 
			
		||||
import unittest
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
 | 
			
		||||
from pytorch3d.implicitron.tools.config import get_default_args
 | 
			
		||||
from pytorch3d.renderer import PerspectiveCameras
 | 
			
		||||
from tests.common_testing import TestCaseMixin
 | 
			
		||||
 | 
			
		||||
# These tests are only run internally, where the data is available.
 | 
			
		||||
internal = os.environ.get("FB_TEST", False)
 | 
			
		||||
inside_re_worker = os.environ.get("INSIDE_RE_WORKER", False)
 | 
			
		||||
skip_tests = not internal or inside_re_worker
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@unittest.skipIf(skip_tests, "no data")
 | 
			
		||||
class TestDataJsonIndex(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def test_loaders(self):
 | 
			
		||||
        args = get_default_args(ImplicitronDataSource)
 | 
			
		||||
        args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
 | 
			
		||||
        dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
 | 
			
		||||
        dataset_args.category = "skateboard"
 | 
			
		||||
        dataset_args.dataset_root = "manifold://co3d/tree/extracted"
 | 
			
		||||
        dataset_args.test_restrict_sequence_id = 0
 | 
			
		||||
        dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 1
 | 
			
		||||
 | 
			
		||||
        data_source = ImplicitronDataSource(**args)
 | 
			
		||||
 | 
			
		||||
        cameras = data_source.get_all_train_cameras()
 | 
			
		||||
        self.assertIsInstance(cameras, PerspectiveCameras)
 | 
			
		||||
        self.assertEqual(len(cameras), 81)
 | 
			
		||||
 | 
			
		||||
        data_sets, data_loaders = data_source.get_datasets_and_dataloaders()
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(len(data_sets.train), 81)
 | 
			
		||||
        self.assertEqual(len(data_sets.val), 102)
 | 
			
		||||
        self.assertEqual(len(data_sets.test), 102)
 | 
			
		||||
@ -16,6 +16,7 @@ from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
 | 
			
		||||
    LlffDatasetMapProvider,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
 | 
			
		||||
from pytorch3d.renderer import PerspectiveCameras
 | 
			
		||||
from tests.common_testing import TestCaseMixin
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -123,3 +124,7 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        for i in data_loaders.test:
 | 
			
		||||
            self.assertEqual(i.frame_type, ["unseen"])
 | 
			
		||||
            self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
 | 
			
		||||
 | 
			
		||||
        cameras = data_source.get_all_train_cameras()
 | 
			
		||||
        self.assertIsInstance(cameras, PerspectiveCameras)
 | 
			
		||||
        self.assertEqual(len(cameras), 100)
 | 
			
		||||
 | 
			
		||||
@ -24,11 +24,7 @@ from pytorch3d.implicitron.tools.config import expand_args_fields, registry
 | 
			
		||||
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
 | 
			
		||||
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if os.environ.get("FB_TEST", False):
 | 
			
		||||
    from .common_resources import get_skateboard_data, provide_lpips_vgg
 | 
			
		||||
else:
 | 
			
		||||
    from common_resources import get_skateboard_data, provide_lpips_vgg
 | 
			
		||||
from .common_resources import get_skateboard_data, provide_lpips_vgg
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestEvaluation(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user