mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
allow get_default_args(JsonIndexDataset)
Summary: Changes to JsonIndexDataset to make it fit with OmegaConf.structured. Also match some default values to what the provider defaults to. Reviewed By: davnov134 Differential Revision: D36666704 fbshipit-source-id: 65b059a1dbaa240ce85c3e8762b7c3db3b5a6e75
This commit is contained in:
parent
8bc0a04e86
commit
1fb268dea6
@ -8,7 +8,6 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
@ -182,8 +181,28 @@ class FrameData(Mapping[str, Any]):
|
||||
return torch.utils.data._utils.collate.default_collate(batch)
|
||||
|
||||
|
||||
class _GenericWorkaround:
|
||||
"""
|
||||
OmegaConf.structured has a weirdness when you try to apply
|
||||
it to a dataclass whose first base class is a Generic which is not
|
||||
Dict. The issue is with a function called get_dict_key_value_types
|
||||
in omegaconf/_utils.py.
|
||||
For example this fails:
|
||||
|
||||
@dataclass(eq=False)
|
||||
class D(torch.utils.data.Dataset[int]):
|
||||
a: int = 3
|
||||
|
||||
OmegaConf.structured(D)
|
||||
|
||||
We avoid the problem by adding this class as an extra base class.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class DatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
|
||||
"""
|
||||
Base class to describe a dataset to be used with Implicitron.
|
||||
|
||||
@ -195,10 +214,11 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
which will describe one frame in one sequence.
|
||||
"""
|
||||
|
||||
# Maps sequence name to the sequence's global frame indices.
|
||||
# _seq_to_idx is a member which implementations can define.
|
||||
# It maps sequence name to the sequence's global frame indices.
|
||||
# It is used for the default implementations of some functions in this class.
|
||||
# Implementations which override them are free to ignore this member.
|
||||
_seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||
# Implementations which override them are free to ignore it.
|
||||
# _seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||
|
||||
def __len__(self) -> int:
|
||||
raise NotImplementedError
|
||||
@ -232,6 +252,7 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
|
||||
def sequence_names(self) -> Iterable[str]:
|
||||
"""Returns an iterator over sequence names in the dataset."""
|
||||
# pyre-ignore[16]
|
||||
return self._seq_to_idx.keys()
|
||||
|
||||
def sequence_frames_in_order(
|
||||
@ -250,6 +271,7 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]):
|
||||
`dataset_idx` is the index within the dataset.
|
||||
`None` timestamps are replaced with 0s.
|
||||
"""
|
||||
# pyre-ignore[16]
|
||||
seq_frame_indices = self._seq_to_idx[seq_name]
|
||||
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
||||
|
||||
|
@ -13,12 +13,12 @@ import os
|
||||
import random
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
@ -30,7 +30,6 @@ from typing import (
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from iopath.common.file_io import PathManager
|
||||
from PIL import Image
|
||||
from pytorch3d.io import IO
|
||||
from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||
@ -116,7 +115,7 @@ class JsonIndexDataset(DatasetBase):
|
||||
Type[types.FrameAnnotation]
|
||||
] = types.FrameAnnotation
|
||||
|
||||
path_manager: Optional[PathManager] = None
|
||||
path_manager: Any = None
|
||||
frame_annotations_file: str = ""
|
||||
sequence_annotations_file: str = ""
|
||||
subset_lists_file: str = ""
|
||||
@ -135,18 +134,18 @@ class JsonIndexDataset(DatasetBase):
|
||||
max_points: int = 0
|
||||
mask_images: bool = False
|
||||
mask_depths: bool = False
|
||||
image_height: Optional[int] = 256
|
||||
image_width: Optional[int] = 256
|
||||
box_crop: bool = False
|
||||
image_height: Optional[int] = 800
|
||||
image_width: Optional[int] = 800
|
||||
box_crop: bool = True
|
||||
box_crop_mask_thr: float = 0.4
|
||||
box_crop_context: float = 1.0
|
||||
remove_empty_masks: bool = False
|
||||
box_crop_context: float = 0.3
|
||||
remove_empty_masks: bool = True
|
||||
n_frames_per_sequence: int = -1
|
||||
seed: int = 0
|
||||
sort_frames: bool = False
|
||||
eval_batches: Optional[List[List[int]]] = None
|
||||
frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
||||
seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
||||
eval_batches: Any = None
|
||||
# frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
||||
# seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
# pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
|
||||
@ -172,9 +171,11 @@ class JsonIndexDataset(DatasetBase):
|
||||
# TODO: check the frame numbers are unique
|
||||
_dataset_seq_frame_n_index = {
|
||||
seq: {
|
||||
# pyre-ignore[16]
|
||||
self.frame_annots[idx]["frame_annotation"].frame_number: idx
|
||||
for idx in seq_idx
|
||||
}
|
||||
# pyre-ignore[16]
|
||||
for seq, seq_idx in self._seq_to_idx.items()
|
||||
}
|
||||
|
||||
@ -184,6 +185,7 @@ class JsonIndexDataset(DatasetBase):
|
||||
# Check that the loaded frame path is consistent
|
||||
# with the one stored in self.frame_annots.
|
||||
assert os.path.normpath(
|
||||
# pyre-ignore[16]
|
||||
self.frame_annots[idx]["frame_annotation"].image.path
|
||||
) == os.path.normpath(
|
||||
path
|
||||
@ -194,19 +196,23 @@ class JsonIndexDataset(DatasetBase):
|
||||
return batches_idx
|
||||
|
||||
def __str__(self) -> str:
|
||||
# pyre-ignore[16]
|
||||
return f"JsonIndexDataset #frames={len(self.frame_annots)}"
|
||||
|
||||
def __len__(self) -> int:
|
||||
# pyre-ignore[16]
|
||||
return len(self.frame_annots)
|
||||
|
||||
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
|
||||
return entry["subset"]
|
||||
|
||||
def __getitem__(self, index) -> FrameData:
|
||||
# pyre-ignore[16]
|
||||
if index >= len(self.frame_annots):
|
||||
raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
|
||||
|
||||
entry = self.frame_annots[index]["frame_annotation"]
|
||||
# pyre-ignore[16]
|
||||
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
|
||||
frame_data = FrameData(
|
||||
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
|
||||
@ -441,6 +447,7 @@ class JsonIndexDataset(DatasetBase):
|
||||
)
|
||||
if not frame_annots_list:
|
||||
raise ValueError("Empty dataset!")
|
||||
# pyre-ignore[16]
|
||||
self.frame_annots = [
|
||||
FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
|
||||
]
|
||||
@ -452,6 +459,7 @@ class JsonIndexDataset(DatasetBase):
|
||||
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
|
||||
if not seq_annots:
|
||||
raise ValueError("Empty sequences file!")
|
||||
# pyre-ignore[16]
|
||||
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
|
||||
|
||||
def _load_subset_lists(self) -> None:
|
||||
@ -467,7 +475,7 @@ class JsonIndexDataset(DatasetBase):
|
||||
for subset, frames in subset_to_seq_frame.items()
|
||||
for _, _, path in frames
|
||||
}
|
||||
|
||||
# pyre-ignore[16]
|
||||
for frame in self.frame_annots:
|
||||
frame["subset"] = frame_path_to_subset.get(
|
||||
frame["frame_annotation"].image.path, None
|
||||
@ -480,6 +488,7 @@ class JsonIndexDataset(DatasetBase):
|
||||
|
||||
def _sort_frames(self) -> None:
|
||||
# Sort frames to have them grouped by sequence, ordered by timestamp
|
||||
# pyre-ignore[16]
|
||||
self.frame_annots = sorted(
|
||||
self.frame_annots,
|
||||
key=lambda f: (
|
||||
@ -491,6 +500,7 @@ class JsonIndexDataset(DatasetBase):
|
||||
def _filter_db(self) -> None:
|
||||
if self.remove_empty_masks:
|
||||
logger.info("Removing images with empty masks.")
|
||||
# pyre-ignore[16]
|
||||
old_len = len(self.frame_annots)
|
||||
|
||||
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
|
||||
@ -531,6 +541,7 @@ class JsonIndexDataset(DatasetBase):
|
||||
|
||||
if len(self.limit_category_to) > 0:
|
||||
logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
|
||||
# pyre-ignore[16]
|
||||
self.seq_annots = {
|
||||
name: entry
|
||||
for name, entry in self.seq_annots.items()
|
||||
@ -568,6 +579,7 @@ class JsonIndexDataset(DatasetBase):
|
||||
if self.n_frames_per_sequence > 0:
|
||||
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
|
||||
keep_idx = []
|
||||
# pyre-ignore[16]
|
||||
for seq, seq_indices in self._seq_to_idx.items():
|
||||
# infer the seed from the sequence name, this is reproducible
|
||||
# and makes the selection differ for different sequences
|
||||
@ -597,14 +609,20 @@ class JsonIndexDataset(DatasetBase):
|
||||
self._invalidate_seq_to_idx()
|
||||
|
||||
if filter_seq_annots:
|
||||
# pyre-ignore[16]
|
||||
self.seq_annots = {
|
||||
k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx
|
||||
k: v
|
||||
for k, v in self.seq_annots.items()
|
||||
# pyre-ignore[16]
|
||||
if k in self._seq_to_idx
|
||||
}
|
||||
|
||||
def _invalidate_seq_to_idx(self) -> None:
|
||||
seq_to_idx = defaultdict(list)
|
||||
# pyre-ignore[16]
|
||||
for idx, entry in enumerate(self.frame_annots):
|
||||
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
|
||||
# pyre-ignore[16]
|
||||
self._seq_to_idx = seq_to_idx
|
||||
|
||||
def _resize_image(
|
||||
@ -644,6 +662,7 @@ class JsonIndexDataset(DatasetBase):
|
||||
) -> List[Tuple[int, float]]:
|
||||
out: List[Tuple[int, float]] = []
|
||||
for idx in idxs:
|
||||
# pyre-ignore[16]
|
||||
frame_annotation = self.frame_annots[idx]["frame_annotation"]
|
||||
out.append(
|
||||
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
|
||||
|
@ -44,6 +44,7 @@ def get_implicitron_sequence_pointcloud(
|
||||
sequence_entries = [
|
||||
ei
|
||||
for ei in sequence_entries
|
||||
# pyre-ignore[16]
|
||||
if dataset.frame_annots[ei]["frame_annotation"].sequence_name
|
||||
== sequence_name
|
||||
]
|
||||
|
@ -9,6 +9,7 @@ import unittest
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.tools.config import get_default_args
|
||||
from tests.common_testing import get_tests_dir
|
||||
|
||||
@ -20,6 +21,33 @@ class TestDataSource(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.maxDiff = None
|
||||
|
||||
def _test_omegaconf_generic_failure(self):
|
||||
# OmegaConf possible bug - this is why we need _GenericWorkaround
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
@dataclass
|
||||
class D(torch.utils.data.Dataset[int]):
|
||||
a: int = 3
|
||||
|
||||
OmegaConf.structured(D)
|
||||
|
||||
def _test_omegaconf_ListList(self):
|
||||
# Demo that OmegaConf doesn't support nested lists
|
||||
from dataclasses import dataclass
|
||||
from typing import Sequence
|
||||
|
||||
@dataclass
|
||||
class A:
|
||||
a: Sequence[Sequence[int]] = ((32,),)
|
||||
|
||||
OmegaConf.structured(A)
|
||||
|
||||
def test_JsonIndexDataset_args(self):
|
||||
# test that JsonIndexDataset works with get_default_args
|
||||
get_default_args(JsonIndexDataset)
|
||||
|
||||
def test_one(self):
|
||||
with unittest.mock.patch.dict(os.environ, {"CO3D_DATASET_ROOT": ""}):
|
||||
cfg = get_default_args(ImplicitronDataSource)
|
||||
|
@ -51,6 +51,7 @@ class TestEvaluation(unittest.TestCase):
|
||||
image_height=self.image_size,
|
||||
image_width=self.image_size,
|
||||
box_crop=True,
|
||||
remove_empty_masks=False,
|
||||
path_manager=path_manager,
|
||||
)
|
||||
self.bg_color = (0.0, 0.0, 0.0)
|
||||
|
Loading…
x
Reference in New Issue
Block a user