mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
ImplicitronDatasetBase -> DatasetBase
Summary: Just a rename Reviewed By: shapovalov Differential Revision: D36516885 fbshipit-source-id: 2126e3aee26d89a95afdb31e06942d61cbe88d5a
This commit is contained in:
parent
0f12c51646
commit
9fe15da3cd
@ -10,7 +10,7 @@ from typing import Optional, Sequence
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
|
||||
from .dataset_base import FrameData, ImplicitronDatasetBase
|
||||
from .dataset_base import DatasetBase, FrameData
|
||||
from .dataset_map_provider import DatasetMap
|
||||
from .scene_batch_sampler import SceneBatchSampler
|
||||
|
||||
@ -102,7 +102,7 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
|
||||
}
|
||||
|
||||
def train_or_val_loader(
|
||||
dataset: Optional[ImplicitronDatasetBase], num_batches: int
|
||||
dataset: Optional[DatasetBase], num_batches: int
|
||||
) -> Optional[torch.utils.data.DataLoader]:
|
||||
if dataset is None:
|
||||
return None
|
||||
|
@ -183,7 +183,7 @@ class FrameData(Mapping[str, Any]):
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
class DatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
"""
|
||||
Base class to describe a dataset to be used with Implicitron.
|
||||
|
||||
|
@ -10,7 +10,7 @@ from typing import Iterator, Optional
|
||||
|
||||
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||
|
||||
from .dataset_base import ImplicitronDatasetBase
|
||||
from .dataset_base import DatasetBase
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -25,11 +25,11 @@ class DatasetMap:
|
||||
test: a dataset for final evaluation
|
||||
"""
|
||||
|
||||
train: Optional[ImplicitronDatasetBase]
|
||||
val: Optional[ImplicitronDatasetBase]
|
||||
test: Optional[ImplicitronDatasetBase]
|
||||
train: Optional[DatasetBase]
|
||||
val: Optional[DatasetBase]
|
||||
test: Optional[DatasetBase]
|
||||
|
||||
def __getitem__(self, split: str) -> Optional[ImplicitronDatasetBase]:
|
||||
def __getitem__(self, split: str) -> Optional[DatasetBase]:
|
||||
"""
|
||||
Get one of the datasets by key (name of data split)
|
||||
"""
|
||||
@ -37,7 +37,7 @@ class DatasetMap:
|
||||
raise ValueError(f"{split} was not a valid split name (train/val/test)")
|
||||
return getattr(self, split)
|
||||
|
||||
def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]:
|
||||
def iter_datasets(self) -> Iterator[DatasetBase]:
|
||||
"""
|
||||
Iterator over all datasets.
|
||||
"""
|
||||
|
@ -37,7 +37,7 @@ from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
|
||||
from . import types
|
||||
from .dataset_base import FrameData, ImplicitronDatasetBase
|
||||
from .dataset_base import DatasetBase, FrameData
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -49,7 +49,7 @@ class FrameAnnotsEntry(TypedDict):
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class ImplicitronDataset(ImplicitronDatasetBase):
|
||||
class ImplicitronDataset(DatasetBase):
|
||||
"""
|
||||
A class for the Common Objects in 3D (CO3D) dataset.
|
||||
|
||||
|
@ -12,7 +12,7 @@ from typing import Iterable, Iterator, List, Sequence, Tuple
|
||||
import numpy as np
|
||||
from torch.utils.data.sampler import Sampler
|
||||
|
||||
from .dataset_base import ImplicitronDatasetBase
|
||||
from .dataset_base import DatasetBase
|
||||
|
||||
|
||||
@dataclass(eq=False) # TODO: do we need this if not init from config?
|
||||
@ -22,7 +22,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
||||
of sequences.
|
||||
"""
|
||||
|
||||
dataset: ImplicitronDatasetBase
|
||||
dataset: DatasetBase
|
||||
batch_size: int
|
||||
num_batches: int
|
||||
# the sampler first samples a random element k from this list and then
|
||||
|
@ -13,7 +13,7 @@ import lpips
|
||||
import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
|
||||
CO3D_CATEGORIES,
|
||||
@ -188,7 +188,7 @@ def _print_aggregate_results(
|
||||
|
||||
|
||||
def _get_all_source_cameras(
|
||||
dataset: ImplicitronDatasetBase, sequence_name: str, num_workers: int = 8
|
||||
dataset: DatasetBase, sequence_name: str, num_workers: int = 8
|
||||
):
|
||||
"""
|
||||
Loads all training cameras of a given sequence.
|
||||
|
@ -9,7 +9,7 @@ import unittest
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pytorch3d.implicitron.dataset.dataset_base import ImplicitronDatasetBase
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ class MockFrameAnnotation:
|
||||
frame_timestamp: float = 0.0
|
||||
|
||||
|
||||
class MockDataset(ImplicitronDatasetBase):
|
||||
class MockDataset(DatasetBase):
|
||||
def __init__(self, num_seq, max_frame_gap=1):
|
||||
"""
|
||||
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
|
||||
|
Loading…
x
Reference in New Issue
Block a user