diff --git a/pytorch3d/implicitron/dataset/data_loader_map_provider.py b/pytorch3d/implicitron/dataset/data_loader_map_provider.py index d22a4678..187d670e 100644 --- a/pytorch3d/implicitron/dataset/data_loader_map_provider.py +++ b/pytorch3d/implicitron/dataset/data_loader_map_provider.py @@ -10,7 +10,7 @@ from typing import Optional, Sequence import torch from pytorch3d.implicitron.tools.config import registry, ReplaceableBase -from .dataset_base import FrameData, ImplicitronDatasetBase +from .dataset_base import DatasetBase, FrameData from .dataset_map_provider import DatasetMap from .scene_batch_sampler import SceneBatchSampler @@ -102,7 +102,7 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase): } def train_or_val_loader( - dataset: Optional[ImplicitronDatasetBase], num_batches: int + dataset: Optional[DatasetBase], num_batches: int ) -> Optional[torch.utils.data.DataLoader]: if dataset is None: return None diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index 9eb98d4f..83859f6b 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -183,7 +183,7 @@ class FrameData(Mapping[str, Any]): @dataclass(eq=False) -class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]): +class DatasetBase(torch.utils.data.Dataset[FrameData]): """ Base class to describe a dataset to be used with Implicitron. diff --git a/pytorch3d/implicitron/dataset/dataset_map_provider.py b/pytorch3d/implicitron/dataset/dataset_map_provider.py index 3177a0d7..810ce234 100644 --- a/pytorch3d/implicitron/dataset/dataset_map_provider.py +++ b/pytorch3d/implicitron/dataset/dataset_map_provider.py @@ -10,7 +10,7 @@ from typing import Iterator, Optional from pytorch3d.implicitron.tools.config import ReplaceableBase -from .dataset_base import ImplicitronDatasetBase +from .dataset_base import DatasetBase @dataclass @@ -25,11 +25,11 @@ class DatasetMap: test: a dataset for final evaluation """ - train: Optional[ImplicitronDatasetBase] - val: Optional[ImplicitronDatasetBase] - test: Optional[ImplicitronDatasetBase] + train: Optional[DatasetBase] + val: Optional[DatasetBase] + test: Optional[DatasetBase] - def __getitem__(self, split: str) -> Optional[ImplicitronDatasetBase]: + def __getitem__(self, split: str) -> Optional[DatasetBase]: """ Get one of the datasets by key (name of data split) """ @@ -37,7 +37,7 @@ class DatasetMap: raise ValueError(f"{split} was not a valid split name (train/val/test)") return getattr(self, split) - def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]: + def iter_datasets(self) -> Iterator[DatasetBase]: """ Iterator over all datasets. """ diff --git a/pytorch3d/implicitron/dataset/implicitron_dataset.py b/pytorch3d/implicitron/dataset/implicitron_dataset.py index d5c158bd..78e8c00f 100644 --- a/pytorch3d/implicitron/dataset/implicitron_dataset.py +++ b/pytorch3d/implicitron/dataset/implicitron_dataset.py @@ -37,7 +37,7 @@ from pytorch3d.renderer.cameras import PerspectiveCameras from pytorch3d.structures.pointclouds import Pointclouds from . import types -from .dataset_base import FrameData, ImplicitronDatasetBase +from .dataset_base import DatasetBase, FrameData logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ class FrameAnnotsEntry(TypedDict): @dataclass(eq=False) -class ImplicitronDataset(ImplicitronDatasetBase): +class ImplicitronDataset(DatasetBase): """ A class for the Common Objects in 3D (CO3D) dataset. diff --git a/pytorch3d/implicitron/dataset/scene_batch_sampler.py b/pytorch3d/implicitron/dataset/scene_batch_sampler.py index 65973772..2012e706 100644 --- a/pytorch3d/implicitron/dataset/scene_batch_sampler.py +++ b/pytorch3d/implicitron/dataset/scene_batch_sampler.py @@ -12,7 +12,7 @@ from typing import Iterable, Iterator, List, Sequence, Tuple import numpy as np from torch.utils.data.sampler import Sampler -from .dataset_base import ImplicitronDatasetBase +from .dataset_base import DatasetBase @dataclass(eq=False) # TODO: do we need this if not init from config? @@ -22,7 +22,7 @@ class SceneBatchSampler(Sampler[List[int]]): of sequences. """ - dataset: ImplicitronDatasetBase + dataset: DatasetBase batch_size: int num_batches: int # the sampler first samples a random element k from this list and then diff --git a/pytorch3d/implicitron/eval_demo.py b/pytorch3d/implicitron/eval_demo.py index f8de2560..2be1c0df 100644 --- a/pytorch3d/implicitron/eval_demo.py +++ b/pytorch3d/implicitron/eval_demo.py @@ -13,7 +13,7 @@ import lpips import torch from iopath.common.file_io import PathManager from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task -from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import ( CO3D_CATEGORIES, @@ -188,7 +188,7 @@ def _print_aggregate_results( def _get_all_source_cameras( - dataset: ImplicitronDatasetBase, sequence_name: str, num_workers: int = 8 + dataset: DatasetBase, sequence_name: str, num_workers: int = 8 ): """ Loads all training cameras of a given sequence. diff --git a/tests/implicitron/test_batch_sampler.py b/tests/implicitron/test_batch_sampler.py index 0029783a..571c95f8 100644 --- a/tests/implicitron/test_batch_sampler.py +++ b/tests/implicitron/test_batch_sampler.py @@ -9,7 +9,7 @@ import unittest from collections import defaultdict from dataclasses import dataclass -from pytorch3d.implicitron.dataset.dataset_base import ImplicitronDatasetBase +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler @@ -19,7 +19,7 @@ class MockFrameAnnotation: frame_timestamp: float = 0.0 -class MockDataset(ImplicitronDatasetBase): +class MockDataset(DatasetBase): def __init__(self, num_seq, max_frame_gap=1): """ Makes a gap of max_frame_gap frame numbers in the middle of each sequence