allow get_default_args(JsonIndexDataset)

Summary: Changes to JsonIndexDataset so that it works with OmegaConf.structured (and hence with get_default_args). Also change some default values to match the provider's defaults.

Reviewed By: davnov134

Differential Revision: D36666704

fbshipit-source-id: 65b059a1dbaa240ce85c3e8762b7c3db3b5a6e75
Authored by Jeremy Reizenstein on 2022-06-10 12:22:46 -07:00; committed by Facebook GitHub Bot
parent 8bc0a04e86
commit 1fb268dea6
5 changed files with 90 additions and 19 deletions
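
With this change the dataset's configurable fields can be generated from the dataclass directly, mirroring the new unit test. A minimal sketch of the intended usage (assuming, as the test suggests, that get_default_args returns an omegaconf DictConfig of the dataclass defaults):

from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import get_default_args

cfg = get_default_args(JsonIndexDataset)  # previously failed on this dataclass
print(cfg.image_height, cfg.box_crop)     # with this diff: 800 True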

View File

@ -8,7 +8,6 @@ from collections import defaultdict
from dataclasses import dataclass, field, fields
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
@ -182,8 +181,28 @@ class FrameData(Mapping[str, Any]):
return torch.utils.data._utils.collate.default_collate(batch)
class _GenericWorkaround:
"""
OmegaConf.structured has a weirdness when you try to apply
it to a dataclass whose first base class is a Generic which is not
Dict. The issue is with a function called get_dict_key_value_types
in omegaconf/_utils.py.
For example this fails:
@dataclass(eq=False)
class D(torch.utils.data.Dataset[int]):
a: int = 3
OmegaConf.structured(D)
We avoid the problem by adding this class as an extra base class.
"""
pass
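A minimal sketch of the failure and the fix described in the docstring above, using the _GenericWorkaround class defined just above (the names Broken and Fixed are illustrative; assumes torch and omegaconf are installed):

from dataclasses import dataclass

import torch.utils.data
from omegaconf import OmegaConf


@dataclass(eq=False)
class Broken(torch.utils.data.Dataset[int]):
    a: int = 3


@dataclass(eq=False)
class Fixed(_GenericWorkaround, torch.utils.data.Dataset[int]):
    a: int = 3


# OmegaConf.structured(Broken)  # raises inside omegaconf's get_dict_key_value_types
OmegaConf.structured(Fixed)  # works: the first base class is no longer a Generic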
@dataclass(eq=False)
class DatasetBase(torch.utils.data.Dataset[FrameData]):
class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
"""
Base class to describe a dataset to be used with Implicitron.
@ -195,10 +214,11 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]):
which will describe one frame in one sequence.
"""
# Maps sequence name to the sequence's global frame indices.
# _seq_to_idx is a member which implementations can define.
# It maps sequence name to the sequence's global frame indices.
# It is used for the default implementations of some functions in this class.
# Implementations which override them are free to ignore this member.
_seq_to_idx: Dict[str, List[int]] = field(init=False)
# Implementations which override them are free to ignore it.
# _seq_to_idx: Dict[str, List[int]] = field(init=False)
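# Illustrative example of its contents: {"seq_name_A": [0, 1, 2], "seq_name_B": [3, 4]}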
def __len__(self) -> int:
raise NotImplementedError
@ -232,6 +252,7 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]):
def sequence_names(self) -> Iterable[str]:
"""Returns an iterator over sequence names in the dataset."""
# pyre-ignore[16]
return self._seq_to_idx.keys()
def sequence_frames_in_order(
@ -250,6 +271,7 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]):
`dataset_idx` is the index within the dataset.
`None` timestamps are replaced with 0s.
"""
# pyre-ignore[16]
seq_frame_indices = self._seq_to_idx[seq_name]
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)

View File

@ -13,12 +13,12 @@ import os
import random
import warnings
from collections import defaultdict
from dataclasses import dataclass, field
from dataclasses import dataclass
from itertools import islice
from pathlib import Path
from typing import (
Any,
ClassVar,
Dict,
List,
Optional,
Sequence,
@ -30,7 +30,6 @@ from typing import (
import numpy as np
import torch
from iopath.common.file_io import PathManager
from PIL import Image
from pytorch3d.io import IO
from pytorch3d.renderer.cameras import PerspectiveCameras
@ -116,7 +115,7 @@ class JsonIndexDataset(DatasetBase):
Type[types.FrameAnnotation]
] = types.FrameAnnotation
path_manager: Optional[PathManager] = None
path_manager: Any = None
frame_annotations_file: str = ""
sequence_annotations_file: str = ""
subset_lists_file: str = ""
@ -135,18 +134,18 @@ class JsonIndexDataset(DatasetBase):
max_points: int = 0
mask_images: bool = False
mask_depths: bool = False
image_height: Optional[int] = 256
image_width: Optional[int] = 256
box_crop: bool = False
image_height: Optional[int] = 800
image_width: Optional[int] = 800
box_crop: bool = True
box_crop_mask_thr: float = 0.4
box_crop_context: float = 1.0
remove_empty_masks: bool = False
box_crop_context: float = 0.3
remove_empty_masks: bool = True
n_frames_per_sequence: int = -1
seed: int = 0
sort_frames: bool = False
eval_batches: Optional[List[List[int]]] = None
frame_annots: List[FrameAnnotsEntry] = field(init=False)
seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
eval_batches: Any = None
# frame_annots: List[FrameAnnotsEntry] = field(init=False)
# seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
def __post_init__(self) -> None:
# pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
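Because box_crop, remove_empty_masks and the 800x800 image size are now enabled by default, callers relying on the previous behaviour need to override them explicitly. A minimal sketch of doing so through the generated config (assuming get_default_args returns a mutable omegaconf DictConfig):

from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import get_default_args

cfg = get_default_args(JsonIndexDataset)
cfg.image_height = 256          # previous default
cfg.image_width = 256           # previous default
cfg.box_crop = False            # previous default
cfg.box_crop_context = 1.0      # previous default
cfg.remove_empty_masks = False  # previous default; the evaluation test below sets this explicitly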
@ -172,9 +171,11 @@ class JsonIndexDataset(DatasetBase):
# TODO: check the frame numbers are unique
_dataset_seq_frame_n_index = {
seq: {
# pyre-ignore[16]
self.frame_annots[idx]["frame_annotation"].frame_number: idx
for idx in seq_idx
}
# pyre-ignore[16]
for seq, seq_idx in self._seq_to_idx.items()
}
@ -184,6 +185,7 @@ class JsonIndexDataset(DatasetBase):
# Check that the loaded frame path is consistent
# with the one stored in self.frame_annots.
assert os.path.normpath(
# pyre-ignore[16]
self.frame_annots[idx]["frame_annotation"].image.path
) == os.path.normpath(
path
@ -194,19 +196,23 @@ class JsonIndexDataset(DatasetBase):
return batches_idx
def __str__(self) -> str:
# pyre-ignore[16]
return f"JsonIndexDataset #frames={len(self.frame_annots)}"
def __len__(self) -> int:
# pyre-ignore[16]
return len(self.frame_annots)
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
return entry["subset"]
def __getitem__(self, index) -> FrameData:
# pyre-ignore[16]
if index >= len(self.frame_annots):
raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
entry = self.frame_annots[index]["frame_annotation"]
# pyre-ignore[16]
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
frame_data = FrameData(
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
@ -441,6 +447,7 @@ class JsonIndexDataset(DatasetBase):
)
if not frame_annots_list:
raise ValueError("Empty dataset!")
# pyre-ignore[16]
self.frame_annots = [
FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
]
@ -452,6 +459,7 @@ class JsonIndexDataset(DatasetBase):
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
if not seq_annots:
raise ValueError("Empty sequences file!")
# pyre-ignore[16]
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
def _load_subset_lists(self) -> None:
@ -467,7 +475,7 @@ class JsonIndexDataset(DatasetBase):
for subset, frames in subset_to_seq_frame.items()
for _, _, path in frames
}
# pyre-ignore[16]
for frame in self.frame_annots:
frame["subset"] = frame_path_to_subset.get(
frame["frame_annotation"].image.path, None
@ -480,6 +488,7 @@ class JsonIndexDataset(DatasetBase):
def _sort_frames(self) -> None:
# Sort frames to have them grouped by sequence, ordered by timestamp
# pyre-ignore[16]
self.frame_annots = sorted(
self.frame_annots,
key=lambda f: (
@ -491,6 +500,7 @@ class JsonIndexDataset(DatasetBase):
def _filter_db(self) -> None:
if self.remove_empty_masks:
logger.info("Removing images with empty masks.")
# pyre-ignore[16]
old_len = len(self.frame_annots)
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
@ -531,6 +541,7 @@ class JsonIndexDataset(DatasetBase):
if len(self.limit_category_to) > 0:
logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
# pyre-ignore[16]
self.seq_annots = {
name: entry
for name, entry in self.seq_annots.items()
@ -568,6 +579,7 @@ class JsonIndexDataset(DatasetBase):
if self.n_frames_per_sequence > 0:
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
keep_idx = []
# pyre-ignore[16]
for seq, seq_indices in self._seq_to_idx.items():
# infer the seed from the sequence name, this is reproducible
# and makes the selection differ for different sequences
@ -597,14 +609,20 @@ class JsonIndexDataset(DatasetBase):
self._invalidate_seq_to_idx()
if filter_seq_annots:
# pyre-ignore[16]
self.seq_annots = {
k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx
k: v
for k, v in self.seq_annots.items()
# pyre-ignore[16]
if k in self._seq_to_idx
}
def _invalidate_seq_to_idx(self) -> None:
seq_to_idx = defaultdict(list)
# pyre-ignore[16]
for idx, entry in enumerate(self.frame_annots):
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
# pyre-ignore[16]
self._seq_to_idx = seq_to_idx
def _resize_image(
@ -644,6 +662,7 @@ class JsonIndexDataset(DatasetBase):
) -> List[Tuple[int, float]]:
out: List[Tuple[int, float]] = []
for idx in idxs:
# pyre-ignore[16]
frame_annotation = self.frame_annots[idx]["frame_annotation"]
out.append(
(frame_annotation.frame_number, frame_annotation.frame_timestamp)

View File

@ -44,6 +44,7 @@ def get_implicitron_sequence_pointcloud(
sequence_entries = [
ei
for ei in sequence_entries
# pyre-ignore[16]
if dataset.frame_annots[ei]["frame_annotation"].sequence_name
== sequence_name
]

View File

@ -9,6 +9,7 @@ import unittest
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import get_default_args
from tests.common_testing import get_tests_dir
@ -20,6 +21,33 @@ class TestDataSource(unittest.TestCase):
def setUp(self):
self.maxDiff = None
def _test_omegaconf_generic_failure(self):
# OmegaConf possible bug - this is why we need _GenericWorkaround
from dataclasses import dataclass
import torch
@dataclass
class D(torch.utils.data.Dataset[int]):
a: int = 3
OmegaConf.structured(D)
def _test_omegaconf_ListList(self):
# Demo that OmegaConf doesn't support nested lists
from dataclasses import dataclass
from typing import Sequence
@dataclass
class A:
a: Sequence[Sequence[int]] = ((32,),)
OmegaConf.structured(A)
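# This limitation is presumably why eval_batches in JsonIndexDataset is now
# typed as Any rather than Optional[List[List[int]]].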
def test_JsonIndexDataset_args(self):
# test that JsonIndexDataset works with get_default_args
get_default_args(JsonIndexDataset)
def test_one(self):
with unittest.mock.patch.dict(os.environ, {"CO3D_DATASET_ROOT": ""}):
cfg = get_default_args(ImplicitronDataSource)

View File

@ -51,6 +51,7 @@ class TestEvaluation(unittest.TestCase):
image_height=self.image_size,
image_width=self.image_size,
box_crop=True,
remove_empty_masks=False,
path_manager=path_manager,
)
self.bg_color = (0.0, 0.0, 0.0)