Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-11-04 18:02:14 +08:00)
return types for dataset_zoo, dataloader_zoo

Summary: Stronger typing for these functions.

Reviewed By: shapovalov

Differential Revision: D36170489

fbshipit-source-id: a2104b29dbbbcfcf91ae1d076cd6b0e3d2030c0b
parent 90ab219d88
commit 2c1901522a
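
The commit replaces dict-valued returns with small dataclasses, so dataset and dataloader subsets become typed attributes that may be None. A minimal sketch of a migrated call site, assuming an OmegaConf config `cfg` shaped like the trainer's and a hypothetical `validate` helper:

    from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo, Dataloaders
    from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo, Datasets

    # `cfg` is an assumed OmegaConf config, as in the trainer diff below.
    datasets: Datasets = dataset_zoo(**cfg.dataset_args)
    dataloaders: Dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)

    # Before: dataloaders["val"] with membership tests like '"val" in dataloaders'.
    # After: attribute access; a missing subset is None, not a missing key.
    if dataloaders.val is not None:
        validate(dataloaders.val)  # hypothetical helper
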
@@ -64,8 +64,8 @@ import tqdm
 from omegaconf import DictConfig, OmegaConf
 from packaging import version
 from pytorch3d.implicitron.dataset import utils as ds_utils
-from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
-from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
+from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo, Dataloaders
+from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo, Datasets
 from pytorch3d.implicitron.dataset.implicitron_dataset import (
     FrameData,
     ImplicitronDataset,
@@ -453,7 +453,6 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
 
     # setup datasets
     datasets = dataset_zoo(**cfg.dataset_args)
-    cfg.dataloader_args["dataset_name"] = cfg.dataset_args["dataset_name"]
     dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)
 
     # init the model
@@ -499,7 +498,7 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
                 model,
                 stats,
                 epoch,
-                dataloaders["train"],
+                dataloaders.train,
                 optimizer,
                 False,
                 visdom_env_root=vis_utils.get_visdom_env(cfg),
@@ -508,12 +507,12 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
             )
 
             # val loop (optional)
-            if "val" in dataloaders and epoch % cfg.validation_interval == 0:
+            if dataloaders.val is not None and epoch % cfg.validation_interval == 0:
                 trainvalidate(
                     model,
                     stats,
                     epoch,
-                    dataloaders["val"],
+                    dataloaders.val,
                     optimizer,
                     True,
                     visdom_env_root=vis_utils.get_visdom_env(cfg),
@@ -523,11 +522,11 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
 
             # eval loop (optional)
             if (
-                "test" in dataloaders
+                dataloaders.test is not None
                 and cfg.test_interval > 0
                 and epoch % cfg.test_interval == 0
             ):
-                run_eval(cfg, model, stats, dataloaders["test"], device=device)
+                run_eval(cfg, model, stats, dataloaders.test, device=device)
 
             assert stats.epoch == epoch, "inconsistent stats!"
 
@@ -550,23 +549,28 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
         _eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
 
 
-def _eval_and_dump(cfg, datasets, dataloaders, model, stats, device):
+def _eval_and_dump(
+    cfg, datasets: Datasets, dataloaders: Dataloaders, model, stats, device
+) -> None:
     """
     Run the evaluation loop with the test data loader and
     save the predictions to the `exp_dir`.
     """
 
-    if "test" not in dataloaders:
+    dataloader = dataloaders.test
+
+    if dataloader is None:
         raise ValueError('Dataloaders have to contain the "test" entry for eval!')
 
     eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]
-    all_source_cameras = (
-        _get_all_source_cameras(datasets["train"])
-        if eval_task == "singlesequence"
-        else None
-    )
+    if eval_task == "singlesequence":
+        if datasets.train is None:
+            raise ValueError("train dataset must be provided")
+        all_source_cameras = _get_all_source_cameras(datasets.train)
+    else:
+        all_source_cameras = None
     results = run_eval(
-        cfg, model, all_source_cameras, dataloaders["test"], eval_task, device=device
+        cfg, model, all_source_cameras, dataloader, eval_task, device=device
     )
 
     # add the evaluation epoch to the results
 
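
_eval_and_dump derives the evaluation task from the last underscore-separated token of the dataset name; a tiny illustration with an assumed example value:

    dataset_name = "co3d_singlesequence"  # assumed example value
    eval_task = dataset_name.split("_")[-1]
    assert eval_task == "singlesequence"
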
@@ -27,6 +27,7 @@ from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
 from pytorch3d.implicitron.dataset.implicitron_dataset import (
     FrameData,
     ImplicitronDataset,
+    ImplicitronDatasetBase,
 )
 from pytorch3d.implicitron.dataset.utils import is_train_frame
 from pytorch3d.implicitron.models.base_model import EvaluationMode
@@ -342,7 +343,10 @@ def export_scenes(
     model.eval()
 
     # Setup the dataset
-    dataset = dataset_zoo(**config.dataset_args)[split]
+    datasets = dataset_zoo(**config.dataset_args)
+    dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None)
+    if dataset is None:
+        raise ValueError(f"{split} dataset not provided")
 
     # iterate over the sequences in the dataset
     for sequence_name in dataset.sequence_names():
@@ -4,18 +4,36 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Dict, Sequence
+from dataclasses import dataclass
+from typing import Optional, Sequence
 
 import torch
 from pytorch3d.implicitron.tools.config import enable_get_default_args
 
+from .dataset_zoo import Datasets
 from .implicitron_dataset import FrameData, ImplicitronDatasetBase
 from .scene_batch_sampler import SceneBatchSampler
 
 
+@dataclass
+class Dataloaders:
+    """
+    A provider of dataloaders for implicitron.
+
+    Members:
+
+        train: a dataloader for training
+        val: a dataloader for validating during training
+        test: a dataloader for final evaluation
+    """
+
+    train: Optional[torch.utils.data.DataLoader[FrameData]]
+    val: Optional[torch.utils.data.DataLoader[FrameData]]
+    test: Optional[torch.utils.data.DataLoader[FrameData]]
+
+
 def dataloader_zoo(
-    datasets: Dict[str, ImplicitronDatasetBase],
-    dataset_name: str = "co3d_singlesequence",
+    datasets: Datasets,
     batch_size: int = 1,
     num_workers: int = 0,
     dataset_len: int = 1000,
@@ -24,7 +42,7 @@ def dataloader_zoo(
     sample_consecutive_frames: bool = False,
     consecutive_frames_max_gap: int = 0,
     consecutive_frames_max_gap_seconds: float = 0.1,
-) -> Dict[str, torch.utils.data.DataLoader]:
+) -> Dataloaders:
     """
     Returns a set of dataloaders for a given set of datasets.
 
@@ -57,44 +75,43 @@ def dataloader_zoo(
         dataloaders: A dictionary containing the
             `"dataset_subset_name": torch_dataloader_object` key, value pairs.
     """
-    if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
-        raise ValueError(f"Unsupported dataset: {dataset_name}")
-
-    dataloaders = {}
+    dataloader_kwargs = {"num_workers": num_workers, "collate_fn": FrameData.collate}
 
-    if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
-        for dataset_set, dataset in datasets.items():
-            num_samples = {
-                "train": dataset_len,
-                "val": dataset_len_val,
-                "test": None,
-            }[dataset_set]
+    def train_or_val_loader(
+        dataset: Optional[ImplicitronDatasetBase], num_batches: int
+    ) -> Optional[torch.utils.data.DataLoader]:
+        if dataset is None:
+            return None
+        batch_sampler = SceneBatchSampler(
+            dataset,
+            batch_size,
+            num_batches=len(dataset) if num_batches <= 0 else num_batches,
+            images_per_seq_options=images_per_seq_options,
+            sample_consecutive_frames=sample_consecutive_frames,
+            consecutive_frames_max_gap=consecutive_frames_max_gap,
+            consecutive_frames_max_gap_seconds=consecutive_frames_max_gap_seconds,
+        )
+        return torch.utils.data.DataLoader(
+            dataset,
+            batch_sampler=batch_sampler,
+            **dataloader_kwargs,
+        )
 
-            if dataset_set == "test":
-                batch_sampler = dataset.get_eval_batches()
-            else:
-                assert num_samples is not None
-                num_samples = len(dataset) if num_samples <= 0 else num_samples
-                batch_sampler = SceneBatchSampler(
-                    dataset,
-                    batch_size,
-                    num_batches=num_samples,
-                    images_per_seq_options=images_per_seq_options,
-                    sample_consecutive_frames=sample_consecutive_frames,
-                    consecutive_frames_max_gap=consecutive_frames_max_gap,
-                )
-
-            dataloaders[dataset_set] = torch.utils.data.DataLoader(
-                dataset,
-                num_workers=num_workers,
-                batch_sampler=batch_sampler,
-                collate_fn=FrameData.collate,
-            )
+    train_dataloader = train_or_val_loader(datasets.train, dataset_len)
+    val_dataloader = train_or_val_loader(datasets.val, dataset_len_val)
+
+    test_dataset = datasets.test
+    if test_dataset is not None:
+        test_dataloader = torch.utils.data.DataLoader(
+            test_dataset,
+            batch_sampler=test_dataset.get_eval_batches(),
+            **dataloader_kwargs,
+        )
     else:
-        raise ValueError(f"Unsupported dataset: {dataset_name}")
+        test_dataloader = None
 
-    return dataloaders
+    return Dataloaders(train=train_dataloader, val=val_dataloader, test=test_dataloader)
 
 
 enable_get_default_args(dataloader_zoo)
 
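
A hedged sketch of consuming the Dataloaders value built above: `dataset_name` is no longer a parameter, the test loader draws its batches from the dataset's own eval batches, and any of the three fields may be None.

    dataloaders = dataloader_zoo(datasets, batch_size=2, num_workers=4)

    # Absent subsets come back as None rather than raising KeyError.
    if dataloaders.test is not None:
        frame_data = next(iter(dataloaders.test))  # one collated FrameData batch
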
@@ -5,10 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 
 
-import copy
 import json
 import os
-from typing import Any, Dict, List, Optional, Sequence
+from dataclasses import dataclass
+from typing import Any, Dict, Iterator, List, Optional, Sequence
 
 from iopath.common.file_io import PathManager
 from pytorch3d.implicitron.tools.config import enable_get_default_args
@@ -52,6 +52,34 @@ CO3D_CATEGORIES: List[str] = list(reversed([
 _CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
 
 
+@dataclass
+class Datasets:
+    """
+    A provider of datasets for implicitron.
+
+    Members:
+
+        train: a dataset for training
+        val: a dataset for validating during training
+        test: a dataset for final evaluation
+    """
+
+    train: Optional[ImplicitronDatasetBase]
+    val: Optional[ImplicitronDatasetBase]
+    test: Optional[ImplicitronDatasetBase]
+
+    def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]:
+        """
+        Iterator over all datasets.
+        """
+        if self.train is not None:
+            yield self.train
+        if self.val is not None:
+            yield self.val
+        if self.test is not None:
+            yield self.test
+
+
 def dataset_zoo(
     dataset_name: str = "co3d_singlesequence",
     dataset_root: str = _CO3D_DATASET_ROOT,
@@ -69,7 +97,7 @@ def dataset_zoo(
     only_test_set: bool = False,
     aux_dataset_kwargs: dict = DATASET_CONFIGS["default"],
     path_manager: Optional[PathManager] = None,
-) -> Dict[str, ImplicitronDatasetBase]:
+) -> Datasets:
     """
     Generates the training / validation and testing dataset objects.
 
@@ -101,12 +129,31 @@ def dataset_zoo(
         datasets: A dictionary containing the
             `"dataset_subset_name": torch_dataset_object` key, value pairs.
     """
-    datasets = {}
+    if only_test_set and test_on_train:
+        raise ValueError("Cannot have only_test_set and test_on_train")
 
     # TODO:
     # - implement loading multiple categories
 
     if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
+        frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
+        sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
+        subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
+        common_kwargs = {
+            "dataset_root": dataset_root,
+            "limit_to": limit_to,
+            "limit_sequences_to": limit_sequences_to,
+            "load_point_clouds": load_point_clouds,
+            "mask_images": mask_images,
+            "mask_depths": mask_depths,
+            "pick_sequence": restrict_sequence_name,
+            "path_manager": path_manager,
+            "frame_annotations_file": frame_file,
+            "sequence_annotations_file": sequence_file,
+            "subset_lists_file": subset_lists_file,
+            **aux_dataset_kwargs,
+        }
+
         # This maps the common names of the dataset subsets ("train"/"val"/"test")
         # to the names of the subsets in the CO3D dataset.
         set_names_mapping = _get_co3d_set_names_mapping(
@@ -156,65 +203,48 @@ def dataset_zoo(
             # overwrite the restrict_sequence_name
             restrict_sequence_name = [eval_sequence_name]
 
-        for dataset, subsets in set_names_mapping.items():
-            frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
-
-            sequence_file = os.path.join(
-                dataset_root, category, "sequence_annotations.jgz"
-            )
-
-            subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
-
-            # TODO: maybe directly in param list
-            params = {
-                **copy.deepcopy(aux_dataset_kwargs),
-                "frame_annotations_file": frame_file,
-                "sequence_annotations_file": sequence_file,
-                "subset_lists_file": subset_lists_file,
-                "dataset_root": dataset_root,
-                "limit_to": limit_to,
-                "limit_sequences_to": limit_sequences_to,
-                "n_frames_per_sequence": n_frames_per_sequence
-                if dataset == "train"
-                else -1,
-                "subsets": subsets,
-                "load_point_clouds": load_point_clouds,
-                "mask_images": mask_images,
-                "mask_depths": mask_depths,
-                "pick_sequence": restrict_sequence_name,
-                "path_manager": path_manager,
-            }
-
-            datasets[dataset] = ImplicitronDataset(**params)
-            if dataset == "test":
-                if len(restrict_sequence_name) > 0:
-                    eval_batch_index = [
-                        b for b in eval_batch_index if b[0][0] in restrict_sequence_name
-                    ]
-
-                datasets[dataset].eval_batches = datasets[
-                    dataset
-                ].seq_frame_index_to_dataset_index(eval_batch_index)
-
-        if assert_single_seq:
-            # check theres only one sequence in all datasets
-            assert (
-                len(
-                    {
-                        e["frame_annotation"].sequence_name
-                        for dset in datasets.values()
-                        for e in dset.frame_annots
-                    }
-                )
-                <= 1
-            ), "Multiple sequences loaded but expected one"
+        train_dataset = None
+        if not only_test_set:
+            train_dataset = ImplicitronDataset(
+                n_frames_per_sequence=n_frames_per_sequence,
+                subsets=set_names_mapping["train"],
+                **common_kwargs,
+            )
+        if test_on_train:
+            assert train_dataset is not None
+            val_dataset = test_dataset = train_dataset
+        else:
+            val_dataset = ImplicitronDataset(
+                n_frames_per_sequence=1,
+                subsets=set_names_mapping["val"],
+                **common_kwargs,
+            )
+            test_dataset = ImplicitronDataset(
+                n_frames_per_sequence=1,
+                subsets=set_names_mapping["test"],
+                **common_kwargs,
+            )
+            if len(restrict_sequence_name) > 0:
+                eval_batch_index = [
+                    b for b in eval_batch_index if b[0][0] in restrict_sequence_name
+                ]
+            test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
+                eval_batch_index
+            )
+        datasets = Datasets(train=train_dataset, val=val_dataset, test=test_dataset)
 
     else:
         raise ValueError(f"Unsupported dataset: {dataset_name}")
 
-    if test_on_train:
-        datasets["val"] = datasets["train"]
-        datasets["test"] = datasets["train"]
+    if assert_single_seq:
+        # check there's only one sequence in all datasets
+        sequence_names = {
+            sequence_name
+            for dset in datasets.iter_datasets()
+            for sequence_name in dset.sequence_names()
+        }
+        if len(sequence_names) > 1:
+            raise ValueError("Multiple sequences loaded but expected one")
 
     return datasets
 
@@ -231,6 +261,11 @@ def _get_co3d_set_names_mapping(
     Returns the mapping of the common dataset subset names ("train"/"val"/"test")
     to the names of the corresponding subsets in the CO3D dataset
     ("test_known"/"test_unseen"/"train_known"/"train_unseen").
+
+    The keys returned will be
+        - train (if not only_test)
+        - val (if not test_on_train)
+        - test (if not test_on_train)
     """
     single_seq = dataset_name == "co3d_singlesequence"
 
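
With dataset_zoo returning Datasets, the single-sequence check above iterates iter_datasets() instead of a dict's .values(). A small sketch, with the same assumed `cfg` as earlier:

    datasets = dataset_zoo(**cfg.dataset_args)

    # iter_datasets() yields only the subsets that were built, so callers can
    # treat "all datasets" uniformly without per-field None checks.
    for dset in datasets.iter_datasets():
        print(type(dset).__name__, len(dset), sorted(dset.sequence_names())[:3])
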
@@ -115,13 +115,12 @@ def evaluate_dbir_for_category(
         path_manager=path_manager,
     )
 
-    dataloaders = dataloader_zoo(
-        datasets,
-        dataset_name=f"co3d_{task}",
-    )
+    dataloaders = dataloader_zoo(datasets)
 
-    test_dataset = datasets["test"]
-    test_dataloader = dataloaders["test"]
+    test_dataset = datasets.test
+    test_dataloader = dataloaders.test
+    if test_dataset is None or test_dataloader is None:
+        raise ValueError("must have a test dataset.")
 
     if task == "singlesequence":
         # all_source_cameras are needed for evaluation of the