mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
ImplicitronDatasetBase -> DatasetBase
Summary: Just a rename Reviewed By: shapovalov Differential Revision: D36516885 fbshipit-source-id: 2126e3aee26d89a95afdb31e06942d61cbe88d5a
This commit is contained in:
parent
0f12c51646
commit
9fe15da3cd
@ -10,7 +10,7 @@ from typing import Optional, Sequence
|
|||||||
import torch
|
import torch
|
||||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||||
|
|
||||||
from .dataset_base import FrameData, ImplicitronDatasetBase
|
from .dataset_base import DatasetBase, FrameData
|
||||||
from .dataset_map_provider import DatasetMap
|
from .dataset_map_provider import DatasetMap
|
||||||
from .scene_batch_sampler import SceneBatchSampler
|
from .scene_batch_sampler import SceneBatchSampler
|
||||||
|
|
||||||
@ -102,7 +102,7 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def train_or_val_loader(
|
def train_or_val_loader(
|
||||||
dataset: Optional[ImplicitronDatasetBase], num_batches: int
|
dataset: Optional[DatasetBase], num_batches: int
|
||||||
) -> Optional[torch.utils.data.DataLoader]:
|
) -> Optional[torch.utils.data.DataLoader]:
|
||||||
if dataset is None:
|
if dataset is None:
|
||||||
return None
|
return None
|
||||||
|
@ -183,7 +183,7 @@ class FrameData(Mapping[str, Any]):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False)
|
@dataclass(eq=False)
|
||||||
class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
|
class DatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||||
"""
|
"""
|
||||||
Base class to describe a dataset to be used with Implicitron.
|
Base class to describe a dataset to be used with Implicitron.
|
||||||
|
|
||||||
|
@ -10,7 +10,7 @@ from typing import Iterator, Optional
|
|||||||
|
|
||||||
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||||
|
|
||||||
from .dataset_base import ImplicitronDatasetBase
|
from .dataset_base import DatasetBase
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -25,11 +25,11 @@ class DatasetMap:
|
|||||||
test: a dataset for final evaluation
|
test: a dataset for final evaluation
|
||||||
"""
|
"""
|
||||||
|
|
||||||
train: Optional[ImplicitronDatasetBase]
|
train: Optional[DatasetBase]
|
||||||
val: Optional[ImplicitronDatasetBase]
|
val: Optional[DatasetBase]
|
||||||
test: Optional[ImplicitronDatasetBase]
|
test: Optional[DatasetBase]
|
||||||
|
|
||||||
def __getitem__(self, split: str) -> Optional[ImplicitronDatasetBase]:
|
def __getitem__(self, split: str) -> Optional[DatasetBase]:
|
||||||
"""
|
"""
|
||||||
Get one of the datasets by key (name of data split)
|
Get one of the datasets by key (name of data split)
|
||||||
"""
|
"""
|
||||||
@ -37,7 +37,7 @@ class DatasetMap:
|
|||||||
raise ValueError(f"{split} was not a valid split name (train/val/test)")
|
raise ValueError(f"{split} was not a valid split name (train/val/test)")
|
||||||
return getattr(self, split)
|
return getattr(self, split)
|
||||||
|
|
||||||
def iter_datasets(self) -> Iterator[ImplicitronDatasetBase]:
|
def iter_datasets(self) -> Iterator[DatasetBase]:
|
||||||
"""
|
"""
|
||||||
Iterator over all datasets.
|
Iterator over all datasets.
|
||||||
"""
|
"""
|
||||||
|
@ -37,7 +37,7 @@ from pytorch3d.renderer.cameras import PerspectiveCameras
|
|||||||
from pytorch3d.structures.pointclouds import Pointclouds
|
from pytorch3d.structures.pointclouds import Pointclouds
|
||||||
|
|
||||||
from . import types
|
from . import types
|
||||||
from .dataset_base import FrameData, ImplicitronDatasetBase
|
from .dataset_base import DatasetBase, FrameData
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -49,7 +49,7 @@ class FrameAnnotsEntry(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False)
|
@dataclass(eq=False)
|
||||||
class ImplicitronDataset(ImplicitronDatasetBase):
|
class ImplicitronDataset(DatasetBase):
|
||||||
"""
|
"""
|
||||||
A class for the Common Objects in 3D (CO3D) dataset.
|
A class for the Common Objects in 3D (CO3D) dataset.
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ from typing import Iterable, Iterator, List, Sequence, Tuple
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from torch.utils.data.sampler import Sampler
|
from torch.utils.data.sampler import Sampler
|
||||||
|
|
||||||
from .dataset_base import ImplicitronDatasetBase
|
from .dataset_base import DatasetBase
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False) # TODO: do we need this if not init from config?
|
@dataclass(eq=False) # TODO: do we need this if not init from config?
|
||||||
@ -22,7 +22,7 @@ class SceneBatchSampler(Sampler[List[int]]):
|
|||||||
of sequences.
|
of sequences.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
dataset: ImplicitronDatasetBase
|
dataset: DatasetBase
|
||||||
batch_size: int
|
batch_size: int
|
||||||
num_batches: int
|
num_batches: int
|
||||||
# the sampler first samples a random element k from this list and then
|
# the sampler first samples a random element k from this list and then
|
||||||
|
@ -13,7 +13,7 @@ import lpips
|
|||||||
import torch
|
import torch
|
||||||
from iopath.common.file_io import PathManager
|
from iopath.common.file_io import PathManager
|
||||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
|
||||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
|
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
||||||
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||||
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
|
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
|
||||||
CO3D_CATEGORIES,
|
CO3D_CATEGORIES,
|
||||||
@ -188,7 +188,7 @@ def _print_aggregate_results(
|
|||||||
|
|
||||||
|
|
||||||
def _get_all_source_cameras(
|
def _get_all_source_cameras(
|
||||||
dataset: ImplicitronDatasetBase, sequence_name: str, num_workers: int = 8
|
dataset: DatasetBase, sequence_name: str, num_workers: int = 8
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Loads all training cameras of a given sequence.
|
Loads all training cameras of a given sequence.
|
||||||
|
@ -9,7 +9,7 @@ import unittest
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.dataset_base import ImplicitronDatasetBase
|
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||||
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
||||||
|
|
||||||
|
|
||||||
@ -19,7 +19,7 @@ class MockFrameAnnotation:
|
|||||||
frame_timestamp: float = 0.0
|
frame_timestamp: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
class MockDataset(ImplicitronDatasetBase):
|
class MockDataset(DatasetBase):
|
||||||
def __init__(self, num_seq, max_frame_gap=1):
|
def __init__(self, num_seq, max_frame_gap=1):
|
||||||
"""
|
"""
|
||||||
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
|
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
|
||||||
|
Loading…
x
Reference in New Issue
Block a user