mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
return types for dataset_zoo, dataloader_zoo
Summary: Stronger typing for these functions.
Reviewed By: shapovalov
Differential Revision: D36170489
fbshipit-source-id: a2104b29dbbbcfcf91ae1d076cd6b0e3d2030c0b
This commit is contained in:
parent
90ab219d88
commit
2c1901522a
@ -64,8 +64,8 @@ import tqdm
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from packaging import version
|
||||
from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
|
||||
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo, Dataloaders
|
||||
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo, Datasets
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
@ -453,7 +453,6 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
|
||||
# setup datasets
|
||||
datasets = dataset_zoo(**cfg.dataset_args)
|
||||
cfg.dataloader_args["dataset_name"] = cfg.dataset_args["dataset_name"]
|
||||
dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)
|
||||
|
||||
# init the model
|
||||
@ -499,7 +498,7 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
model,
|
||||
stats,
|
||||
epoch,
|
||||
dataloaders["train"],
|
||||
dataloaders.train,
|
||||
optimizer,
|
||||
False,
|
||||
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
||||
@ -508,12 +507,12 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
)
|
||||
|
||||
# val loop (optional)
|
||||
if "val" in dataloaders and epoch % cfg.validation_interval == 0:
|
||||
if dataloaders.val is not None and epoch % cfg.validation_interval == 0:
|
||||
trainvalidate(
|
||||
model,
|
||||
stats,
|
||||
epoch,
|
||||
dataloaders["val"],
|
||||
dataloaders.val,
|
||||
optimizer,
|
||||
True,
|
||||
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
||||
@ -523,11 +522,11 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
|
||||
# eval loop (optional)
|
||||
if (
|
||||
"test" in dataloaders
|
||||
dataloaders.test is not None
|
||||
and cfg.test_interval > 0
|
||||
and epoch % cfg.test_interval == 0
|
||||
):
|
||||
run_eval(cfg, model, stats, dataloaders["test"], device=device)
|
||||
run_eval(cfg, model, stats, dataloaders.test, device=device)
|
||||
|
||||
assert stats.epoch == epoch, "inconsistent stats!"
|
||||
|
||||
@ -550,23 +549,28 @@ def run_training(cfg: DictConfig, device: str = "cpu"):
|
||||
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
|
||||
|
||||
|
||||
def _eval_and_dump(cfg, datasets, dataloaders, model, stats, device):
|
||||
def _eval_and_dump(
|
||||
cfg, datasets: Datasets, dataloaders: Dataloaders, model, stats, device
|
||||
) -> None:
|
||||
"""
|
||||
Run the evaluation loop with the test data loader and
|
||||
save the predictions to the `exp_dir`.
|
||||
"""
|
||||
|
||||
if "test" not in dataloaders:
|
||||
dataloader = dataloaders.test
|
||||
|
||||
if dataloader is None:
|
||||
raise ValueError('Dataloaders have to contain the "test" entry for eval!')
|
||||
|
||||
eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]
|
||||
all_source_cameras = (
|
||||
_get_all_source_cameras(datasets["train"])
|
||||
if eval_task == "singlesequence"
|
||||
else None
|
||||
)
|
||||
if eval_task == "singlesequence":
|
||||
if datasets.train is None:
|
||||
raise ValueError("train dataset must be provided")
|
||||
all_source_cameras = _get_all_source_cameras(datasets.train)
|
||||
else:
|
||||
all_source_cameras = None
|
||||
results = run_eval(
|
||||
cfg, model, all_source_cameras, dataloaders["test"], eval_task, device=device
|
||||
cfg, model, all_source_cameras, dataloader, eval_task, device=device
|
||||
)
|
||||
|
||||
# add the evaluation epoch to the results
|
||||
|
@ -27,6 +27,7 @@ from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
ImplicitronDatasetBase,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.utils import is_train_frame
|
||||
from pytorch3d.implicitron.models.base_model import EvaluationMode
|
||||
@ -342,7 +343,10 @@ def export_scenes(
|
||||
model.eval()
|
||||
|
||||
# Setup the dataset
|
||||
dataset = dataset_zoo(**config.dataset_args)[split]
|
||||
datasets = dataset_zoo(**config.dataset_args)
|
||||
dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None)
|
||||
if dataset is None:
|
||||
raise ValueError(f"{split} dataset not provided")
|
||||
|
||||
# iterate over the sequences in the dataset
|
||||
for sequence_name in dataset.sequence_names():
|
||||
|
@ -4,18 +4,36 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Dict, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Sequence
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||
|
||||
from .dataset_zoo import Datasets
|
||||
from .implicitron_dataset import FrameData, ImplicitronDatasetBase
|
||||
from .scene_batch_sampler import SceneBatchSampler
|
||||
|
||||
|
||||
@dataclass
class Dataloaders:
    """
    Container bundling the implicitron dataloaders, one per data subset.

    Each member is optional and is None when no dataloader exists for
    that subset.

    Members:

        train: the dataloader used for training
        val: the dataloader used for validation during training
        test: the dataloader used for the final evaluation
    """

    train: Optional[torch.utils.data.DataLoader[FrameData]]
    val: Optional[torch.utils.data.DataLoader[FrameData]]
    test: Optional[torch.utils.data.DataLoader[FrameData]]
|
||||
|
||||
|
||||
def dataloader_zoo(
|
||||
datasets: Dict[str, ImplicitronDatasetBase],
|
||||
dataset_name: str = "co3d_singlesequence",
|
||||
datasets: Datasets,
|
||||
batch_size: int = 1,
|
||||
num_workers: int = 0,
|
||||
dataset_len: int = 1000,
|
||||
@ -24,7 +42,7 @@ def dataloader_zoo(
|
||||
sample_consecutive_frames: bool = False,
|
||||
consecutive_frames_max_gap: int = 0,
|
||||
consecutive_frames_max_gap_seconds: float = 0.1,
|
||||
) -> Dict[str, torch.utils.data.DataLoader]:
|
||||
) -> Dataloaders:
|
||||
"""
|
||||
Returns a set of dataloaders for a given set of datasets.
|
||||
|
||||
@ -57,44 +75,43 @@ def dataloader_zoo(
|
||||
dataloaders: A `Dataloaders` object whose `train`, `val` and `test`
|
||||
members hold the corresponding torch dataloaders (None where the dataset is absent).
|
||||
"""
|
||||
if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
|
||||
raise ValueError(f"Unsupported dataset: {dataset_name}")
|
||||
|
||||
dataloaders = {}
|
||||
dataloader_kwargs = {"num_workers": num_workers, "collate_fn": FrameData.collate}
|
||||
|
||||
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
|
||||
for dataset_set, dataset in datasets.items():
|
||||
num_samples = {
|
||||
"train": dataset_len,
|
||||
"val": dataset_len_val,
|
||||
"test": None,
|
||||
}[dataset_set]
|
||||
def train_or_val_loader(
|
||||
dataset: Optional[ImplicitronDatasetBase], num_batches: int
|
||||
) -> Optional[torch.utils.data.DataLoader]:
|
||||
if dataset is None:
|
||||
return None
|
||||
batch_sampler = SceneBatchSampler(
|
||||
dataset,
|
||||
batch_size,
|
||||
num_batches=len(dataset) if num_batches <= 0 else num_batches,
|
||||
images_per_seq_options=images_per_seq_options,
|
||||
sample_consecutive_frames=sample_consecutive_frames,
|
||||
consecutive_frames_max_gap=consecutive_frames_max_gap,
|
||||
consecutive_frames_max_gap_seconds=consecutive_frames_max_gap_seconds,
|
||||
)
|
||||
return torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
**dataloader_kwargs,
|
||||
)
|
||||
|
||||
if dataset_set == "test":
|
||||
batch_sampler = dataset.get_eval_batches()
|
||||
else:
|
||||
assert num_samples is not None
|
||||
num_samples = len(dataset) if num_samples <= 0 else num_samples
|
||||
batch_sampler = SceneBatchSampler(
|
||||
dataset,
|
||||
batch_size,
|
||||
num_batches=num_samples,
|
||||
images_per_seq_options=images_per_seq_options,
|
||||
sample_consecutive_frames=sample_consecutive_frames,
|
||||
consecutive_frames_max_gap=consecutive_frames_max_gap,
|
||||
)
|
||||
|
||||
dataloaders[dataset_set] = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=FrameData.collate,
|
||||
)
|
||||
train_dataloader = train_or_val_loader(datasets.train, dataset_len)
|
||||
val_dataloader = train_or_val_loader(datasets.val, dataset_len_val)
|
||||
|
||||
test_dataset = datasets.test
|
||||
if test_dataset is not None:
|
||||
test_dataloader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_sampler=test_dataset.get_eval_batches(),
|
||||
**dataloader_kwargs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset: {dataset_name}")
|
||||
test_dataloader = None
|
||||
|
||||
return dataloaders
|
||||
return Dataloaders(train=train_dataloader, val=val_dataloader, test=test_dataloader)
|
||||
|
||||
|
||||
enable_get_default_args(dataloader_zoo)
|
||||
|
@ -5,10 +5,10 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Iterator, List, Optional, Sequence
|
||||
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||
@ -52,6 +52,34 @@ CO3D_CATEGORIES: List[str] = list(reversed([
|
||||
_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
|
||||
|
||||
|
||||
@dataclass
class Datasets:
    """
    Container bundling the implicitron datasets, one per data subset.

    Each member is optional and may be None.

    Members:

        train: the dataset used for training
        val: the dataset used for validation during training
        test: the dataset used for the final evaluation
    """

    train: Optional[ImplicitronDatasetBase]
    val: Optional[ImplicitronDatasetBase]
    test: Optional[ImplicitronDatasetBase]

    def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]:
        """
        Yield each member that is not None, in train, val, test order.
        """
        # Members are looked up one at a time so that each attribute is
        # read only when the generator reaches it.
        for member_name in ("train", "val", "test"):
            member = getattr(self, member_name)
            if member is not None:
                yield member
|
||||
|
||||
|
||||
def dataset_zoo(
|
||||
dataset_name: str = "co3d_singlesequence",
|
||||
dataset_root: str = _CO3D_DATASET_ROOT,
|
||||
@ -69,7 +97,7 @@ def dataset_zoo(
|
||||
only_test_set: bool = False,
|
||||
aux_dataset_kwargs: dict = DATASET_CONFIGS["default"],
|
||||
path_manager: Optional[PathManager] = None,
|
||||
) -> Dict[str, ImplicitronDatasetBase]:
|
||||
) -> Datasets:
|
||||
"""
|
||||
Generates the training / validation and testing dataset objects.
|
||||
|
||||
@ -101,12 +129,31 @@ def dataset_zoo(
|
||||
datasets: A `Datasets` object whose `train`, `val` and `test` members
|
||||
hold the corresponding torch dataset objects (`train` is None if `only_test_set`).
|
||||
"""
|
||||
datasets = {}
|
||||
if only_test_set and test_on_train:
|
||||
raise ValueError("Cannot have only_test_set and test_on_train")
|
||||
|
||||
# TODO:
|
||||
# - implement loading multiple categories
|
||||
|
||||
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
|
||||
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
|
||||
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
|
||||
subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
|
||||
common_kwargs = {
|
||||
"dataset_root": dataset_root,
|
||||
"limit_to": limit_to,
|
||||
"limit_sequences_to": limit_sequences_to,
|
||||
"load_point_clouds": load_point_clouds,
|
||||
"mask_images": mask_images,
|
||||
"mask_depths": mask_depths,
|
||||
"pick_sequence": restrict_sequence_name,
|
||||
"path_manager": path_manager,
|
||||
"frame_annotations_file": frame_file,
|
||||
"sequence_annotations_file": sequence_file,
|
||||
"subset_lists_file": subset_lists_file,
|
||||
**aux_dataset_kwargs,
|
||||
}
|
||||
|
||||
# This maps the common names of the dataset subsets ("train"/"val"/"test")
|
||||
# to the names of the subsets in the CO3D dataset.
|
||||
set_names_mapping = _get_co3d_set_names_mapping(
|
||||
@ -156,65 +203,48 @@ def dataset_zoo(
|
||||
# overwrite the restrict_sequence_name
|
||||
restrict_sequence_name = [eval_sequence_name]
|
||||
|
||||
for dataset, subsets in set_names_mapping.items():
|
||||
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
|
||||
|
||||
sequence_file = os.path.join(
|
||||
dataset_root, category, "sequence_annotations.jgz"
|
||||
train_dataset = None
|
||||
if not only_test_set:
|
||||
train_dataset = ImplicitronDataset(
|
||||
n_frames_per_sequence=n_frames_per_sequence,
|
||||
subsets=set_names_mapping["train"],
|
||||
**common_kwargs,
|
||||
)
|
||||
|
||||
subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
|
||||
|
||||
# TODO: maybe directly in param list
|
||||
params = {
|
||||
**copy.deepcopy(aux_dataset_kwargs),
|
||||
"frame_annotations_file": frame_file,
|
||||
"sequence_annotations_file": sequence_file,
|
||||
"subset_lists_file": subset_lists_file,
|
||||
"dataset_root": dataset_root,
|
||||
"limit_to": limit_to,
|
||||
"limit_sequences_to": limit_sequences_to,
|
||||
"n_frames_per_sequence": n_frames_per_sequence
|
||||
if dataset == "train"
|
||||
else -1,
|
||||
"subsets": subsets,
|
||||
"load_point_clouds": load_point_clouds,
|
||||
"mask_images": mask_images,
|
||||
"mask_depths": mask_depths,
|
||||
"pick_sequence": restrict_sequence_name,
|
||||
"path_manager": path_manager,
|
||||
}
|
||||
|
||||
datasets[dataset] = ImplicitronDataset(**params)
|
||||
if dataset == "test":
|
||||
if len(restrict_sequence_name) > 0:
|
||||
eval_batch_index = [
|
||||
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
|
||||
]
|
||||
|
||||
datasets[dataset].eval_batches = datasets[
|
||||
dataset
|
||||
].seq_frame_index_to_dataset_index(eval_batch_index)
|
||||
|
||||
if assert_single_seq:
|
||||
# check theres only one sequence in all datasets
|
||||
assert (
|
||||
len(
|
||||
{
|
||||
e["frame_annotation"].sequence_name
|
||||
for dset in datasets.values()
|
||||
for e in dset.frame_annots
|
||||
}
|
||||
)
|
||||
<= 1
|
||||
), "Multiple sequences loaded but expected one"
|
||||
if test_on_train:
|
||||
assert train_dataset is not None
|
||||
val_dataset = test_dataset = train_dataset
|
||||
else:
|
||||
val_dataset = ImplicitronDataset(
|
||||
n_frames_per_sequence=1,
|
||||
subsets=set_names_mapping["val"],
|
||||
**common_kwargs,
|
||||
)
|
||||
test_dataset = ImplicitronDataset(
|
||||
n_frames_per_sequence=1,
|
||||
subsets=set_names_mapping["test"],
|
||||
**common_kwargs,
|
||||
)
|
||||
if len(restrict_sequence_name) > 0:
|
||||
eval_batch_index = [
|
||||
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
|
||||
]
|
||||
test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
|
||||
eval_batch_index
|
||||
)
|
||||
datasets = Datasets(train=train_dataset, val=val_dataset, test=test_dataset)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported dataset: {dataset_name}")
|
||||
|
||||
if test_on_train:
|
||||
datasets["val"] = datasets["train"]
|
||||
datasets["test"] = datasets["train"]
|
||||
if assert_single_seq:
|
||||
# check there's only one sequence in all datasets
|
||||
sequence_names = {
|
||||
sequence_name
|
||||
for dset in datasets.iter_datasets()
|
||||
for sequence_name in dset.sequence_names()
|
||||
}
|
||||
if len(sequence_names) > 1:
|
||||
raise ValueError("Multiple sequences loaded but expected one")
|
||||
|
||||
return datasets
|
||||
|
||||
@ -231,6 +261,11 @@ def _get_co3d_set_names_mapping(
|
||||
Returns the mapping of the common dataset subset names ("train"/"val"/"test")
|
||||
to the names of the corresponding subsets in the CO3D dataset
|
||||
("test_known"/"test_unseen"/"train_known"/"train_unseen").
|
||||
|
||||
The keys returned will be
|
||||
- train (if not only_test)
|
||||
- val (if not test_on_train)
|
||||
- test (if not test_on_train)
|
||||
"""
|
||||
single_seq = dataset_name == "co3d_singlesequence"
|
||||
|
||||
|
@ -115,13 +115,12 @@ def evaluate_dbir_for_category(
|
||||
path_manager=path_manager,
|
||||
)
|
||||
|
||||
dataloaders = dataloader_zoo(
|
||||
datasets,
|
||||
dataset_name=f"co3d_{task}",
|
||||
)
|
||||
dataloaders = dataloader_zoo(datasets)
|
||||
|
||||
test_dataset = datasets["test"]
|
||||
test_dataloader = dataloaders["test"]
|
||||
test_dataset = datasets.test
|
||||
test_dataloader = dataloaders.test
|
||||
if test_dataset is None or test_dataloader is None:
|
||||
raise ValueError("must have a test dataset.")
|
||||
|
||||
if task == "singlesequence":
|
||||
# all_source_cameras are needed for evaluation of the
|
||||
|
Loading…
x
Reference in New Issue
Block a user