ImplicitronDatasetBase -> DatasetBase

Summary: Just a rename

Reviewed By: shapovalov

Differential Revision: D36516885

fbshipit-source-id: 2126e3aee26d89a95afdb31e06942d61cbe88d5a
This commit is contained in:
Jeremy Reizenstein 2022-05-20 07:50:30 -07:00 committed by Facebook GitHub Bot
parent 0f12c51646
commit 9fe15da3cd
7 changed files with 17 additions and 17 deletions

View File

@ -10,7 +10,7 @@ from typing import Optional, Sequence
import torch import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from .dataset_base import FrameData, ImplicitronDatasetBase from .dataset_base import DatasetBase, FrameData
from .dataset_map_provider import DatasetMap from .dataset_map_provider import DatasetMap
from .scene_batch_sampler import SceneBatchSampler from .scene_batch_sampler import SceneBatchSampler
@ -102,7 +102,7 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
} }
def train_or_val_loader( def train_or_val_loader(
dataset: Optional[ImplicitronDatasetBase], num_batches: int dataset: Optional[DatasetBase], num_batches: int
) -> Optional[torch.utils.data.DataLoader]: ) -> Optional[torch.utils.data.DataLoader]:
if dataset is None: if dataset is None:
return None return None

View File

@ -183,7 +183,7 @@ class FrameData(Mapping[str, Any]):
@dataclass(eq=False) @dataclass(eq=False)
class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]): class DatasetBase(torch.utils.data.Dataset[FrameData]):
""" """
Base class to describe a dataset to be used with Implicitron. Base class to describe a dataset to be used with Implicitron.

View File

@ -10,7 +10,7 @@ from typing import Iterator, Optional
from pytorch3d.implicitron.tools.config import ReplaceableBase from pytorch3d.implicitron.tools.config import ReplaceableBase
from .dataset_base import ImplicitronDatasetBase from .dataset_base import DatasetBase
@dataclass @dataclass
@ -25,11 +25,11 @@ class DatasetMap:
test: a dataset for final evaluation test: a dataset for final evaluation
""" """
train: Optional[ImplicitronDatasetBase] train: Optional[DatasetBase]
val: Optional[ImplicitronDatasetBase] val: Optional[DatasetBase]
test: Optional[ImplicitronDatasetBase] test: Optional[DatasetBase]
def __getitem__(self, split: str) -> Optional[ImplicitronDatasetBase]: def __getitem__(self, split: str) -> Optional[DatasetBase]:
""" """
Get one of the datasets by key (name of data split) Get one of the datasets by key (name of data split)
""" """
@ -37,7 +37,7 @@ class DatasetMap:
raise ValueError(f"{split} was not a valid split name (train/val/test)") raise ValueError(f"{split} was not a valid split name (train/val/test)")
return getattr(self, split) return getattr(self, split)
def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]: def iter_datasets(self) -> Iterator[DatasetBase]:
""" """
Iterator over all datasets. Iterator over all datasets.
""" """

View File

@ -37,7 +37,7 @@ from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds from pytorch3d.structures.pointclouds import Pointclouds
from . import types from . import types
from .dataset_base import FrameData, ImplicitronDatasetBase from .dataset_base import DatasetBase, FrameData
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,7 +49,7 @@ class FrameAnnotsEntry(TypedDict):
@dataclass(eq=False) @dataclass(eq=False)
class ImplicitronDataset(ImplicitronDatasetBase): class ImplicitronDataset(DatasetBase):
""" """
A class for the Common Objects in 3D (CO3D) dataset. A class for the Common Objects in 3D (CO3D) dataset.

View File

@ -12,7 +12,7 @@ from typing import Iterable, Iterator, List, Sequence, Tuple
import numpy as np import numpy as np
from torch.utils.data.sampler import Sampler from torch.utils.data.sampler import Sampler
from .dataset_base import ImplicitronDatasetBase from .dataset_base import DatasetBase
@dataclass(eq=False) # TODO: do we need this if not init from config? @dataclass(eq=False) # TODO: do we need this if not init from config?
@ -22,7 +22,7 @@ class SceneBatchSampler(Sampler[List[int]]):
of sequences. of sequences.
""" """
dataset: ImplicitronDatasetBase dataset: DatasetBase
batch_size: int batch_size: int
num_batches: int num_batches: int
# the sampler first samples a random element k from this list and then # the sampler first samples a random element k from this list and then

View File

@ -13,7 +13,7 @@ import lpips
import torch import torch
from iopath.common.file_io import PathManager from iopath.common.file_io import PathManager
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import ( from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
CO3D_CATEGORIES, CO3D_CATEGORIES,
@ -188,7 +188,7 @@ def _print_aggregate_results(
def _get_all_source_cameras( def _get_all_source_cameras(
dataset: ImplicitronDatasetBase, sequence_name: str, num_workers: int = 8 dataset: DatasetBase, sequence_name: str, num_workers: int = 8
): ):
""" """
Loads all training cameras of a given sequence. Loads all training cameras of a given sequence.

View File

@ -9,7 +9,7 @@ import unittest
from collections import defaultdict from collections import defaultdict
from dataclasses import dataclass from dataclasses import dataclass
from pytorch3d.implicitron.dataset.dataset_base import ImplicitronDatasetBase from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@ -19,7 +19,7 @@ class MockFrameAnnotation:
frame_timestamp: float = 0.0 frame_timestamp: float = 0.0
class MockDataset(ImplicitronDatasetBase): class MockDataset(DatasetBase):
def __init__(self, num_seq, max_frame_gap=1): def __init__(self, num_seq, max_frame_gap=1):
""" """
Makes a gap of max_frame_gap frame numbers in the middle of each sequence Makes a gap of max_frame_gap frame numbers in the middle of each sequence