ImplicitronDatasetBase -> DatasetBase

Summary: Just a rename

Reviewed By: shapovalov

Differential Revision: D36516885

fbshipit-source-id: 2126e3aee26d89a95afdb31e06942d61cbe88d5a
This commit is contained in:
Jeremy Reizenstein 2022-05-20 07:50:30 -07:00 committed by Facebook GitHub Bot
parent 0f12c51646
commit 9fe15da3cd
7 changed files with 17 additions and 17 deletions

View File

@ -10,7 +10,7 @@ from typing import Optional, Sequence
import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from .dataset_base import FrameData, ImplicitronDatasetBase
from .dataset_base import DatasetBase, FrameData
from .dataset_map_provider import DatasetMap
from .scene_batch_sampler import SceneBatchSampler
@ -102,7 +102,7 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
}
def train_or_val_loader(
dataset: Optional[ImplicitronDatasetBase], num_batches: int
dataset: Optional[DatasetBase], num_batches: int
) -> Optional[torch.utils.data.DataLoader]:
if dataset is None:
return None

View File

@ -183,7 +183,7 @@ class FrameData(Mapping[str, Any]):
@dataclass(eq=False)
class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
class DatasetBase(torch.utils.data.Dataset[FrameData]):
"""
Base class to describe a dataset to be used with Implicitron.

View File

@ -10,7 +10,7 @@ from typing import Iterator, Optional
from pytorch3d.implicitron.tools.config import ReplaceableBase
from .dataset_base import ImplicitronDatasetBase
from .dataset_base import DatasetBase
@dataclass
@ -25,11 +25,11 @@ class DatasetMap:
test: a dataset for final evaluation
"""
train: Optional[ImplicitronDatasetBase]
val: Optional[ImplicitronDatasetBase]
test: Optional[ImplicitronDatasetBase]
train: Optional[DatasetBase]
val: Optional[DatasetBase]
test: Optional[DatasetBase]
def __getitem__(self, split: str) -> Optional[ImplicitronDatasetBase]:
def __getitem__(self, split: str) -> Optional[DatasetBase]:
"""
Get one of the datasets by key (name of data split)
"""
@ -37,7 +37,7 @@ class DatasetMap:
raise ValueError(f"{split} was not a valid split name (train/val/test)")
return getattr(self, split)
def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]:
def iter_datasets(self) -> Iterator[DatasetBase]:
"""
Iterator over all datasets.
"""

View File

@ -37,7 +37,7 @@ from pytorch3d.renderer.cameras import PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds
from . import types
from .dataset_base import FrameData, ImplicitronDatasetBase
from .dataset_base import DatasetBase, FrameData
logger = logging.getLogger(__name__)
@ -49,7 +49,7 @@ class FrameAnnotsEntry(TypedDict):
@dataclass(eq=False)
class ImplicitronDataset(ImplicitronDatasetBase):
class ImplicitronDataset(DatasetBase):
"""
A class for the Common Objects in 3D (CO3D) dataset.

View File

@ -12,7 +12,7 @@ from typing import Iterable, Iterator, List, Sequence, Tuple
import numpy as np
from torch.utils.data.sampler import Sampler
from .dataset_base import ImplicitronDatasetBase
from .dataset_base import DatasetBase
@dataclass(eq=False) # TODO: do we need this if not init from config?
@ -22,7 +22,7 @@ class SceneBatchSampler(Sampler[List[int]]):
of sequences.
"""
dataset: ImplicitronDatasetBase
dataset: DatasetBase
batch_size: int
num_batches: int
# the sampler first samples a random element k from this list and then

View File

@ -13,7 +13,7 @@ import lpips
import torch
from iopath.common.file_io import PathManager
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
CO3D_CATEGORIES,
@ -188,7 +188,7 @@ def _print_aggregate_results(
def _get_all_source_cameras(
dataset: ImplicitronDatasetBase, sequence_name: str, num_workers: int = 8
dataset: DatasetBase, sequence_name: str, num_workers: int = 8
):
"""
Loads all training cameras of a given sequence.

View File

@ -9,7 +9,7 @@ import unittest
from collections import defaultdict
from dataclasses import dataclass
from pytorch3d.implicitron.dataset.dataset_base import ImplicitronDatasetBase
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@ -19,7 +19,7 @@ class MockFrameAnnotation:
frame_timestamp: float = 0.0
class MockDataset(ImplicitronDatasetBase):
class MockDataset(DatasetBase):
def __init__(self, num_seq, max_frame_gap=1):
"""
Makes a gap of max_frame_gap frame numbers in the middle of each sequence