allow get_default_args(JsonIndexDataset)

Summary: Changes to JsonIndexDataset so that it works with OmegaConf.structured. Also match some default values to what the provider defaults to.
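As a rough usage sketch of what this enables (assuming a PyTorch3D build that contains the change; the import paths and get_default_args are the ones used in the new test below):

    from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
    from pytorch3d.implicitron.tools.config import get_default_args

    # Previously this raised inside OmegaConf.structured; now it returns a
    # config holding the dataclass defaults.
    cfg = get_default_args(JsonIndexDataset)
    print(cfg.image_height)  # 800, one of the provider-matching defaults below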

Reviewed By: davnov134

Differential Revision: D36666704

fbshipit-source-id: 65b059a1dbaa240ce85c3e8762b7c3db3b5a6e75
Jeremy Reizenstein 2022-06-10 12:22:46 -07:00 committed by Facebook GitHub Bot
parent 8bc0a04e86
commit 1fb268dea6
5 changed files with 90 additions and 19 deletions

View File

@@ -8,7 +8,6 @@ from collections import defaultdict
 from dataclasses import dataclass, field, fields
 from typing import (
     Any,
-    Dict,
     Iterable,
     Iterator,
     List,
@@ -182,8 +181,28 @@ class FrameData(Mapping[str, Any]):
         return torch.utils.data._utils.collate.default_collate(batch)
 
 
+class _GenericWorkaround:
+    """
+    OmegaConf.structured has a weirdness when you try to apply
+    it to a dataclass whose first base class is a Generic which is not
+    Dict. The issue is with a function called get_dict_key_value_types
+    in omegaconf/_utils.py.
+
+    For example this fails:
+
+        @dataclass(eq=False)
+        class D(torch.utils.data.Dataset[int]):
+            a: int = 3
+
+        OmegaConf.structured(D)
+
+    We avoid the problem by adding this class as an extra base class.
+    """
+
+    pass
+
+
 @dataclass(eq=False)
-class DatasetBase(torch.utils.data.Dataset[FrameData]):
+class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
     """
     Base class to describe a dataset to be used with Implicitron.
@@ -195,10 +214,11 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]):
     which will describe one frame in one sequence.
     """
 
-    # Maps sequence name to the sequence's global frame indices.
+    # _seq_to_idx is a member which implementations can define.
+    # It maps sequence name to the sequence's global frame indices.
     # It is used for the default implementations of some functions in this class.
-    # Implementations which override them are free to ignore this member.
-    _seq_to_idx: Dict[str, List[int]] = field(init=False)
+    # Implementations which override them are free to ignore it.
+    # _seq_to_idx: Dict[str, List[int]] = field(init=False)
 
     def __len__(self) -> int:
         raise NotImplementedError
@@ -232,6 +252,7 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]):
     def sequence_names(self) -> Iterable[str]:
         """Returns an iterator over sequence names in the dataset."""
+        # pyre-ignore[16]
         return self._seq_to_idx.keys()
 
     def sequence_frames_in_order(
@@ -250,6 +271,7 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]):
         `dataset_idx` is the index within the dataset.
         `None` timestamps are replaced with 0s.
         """
+        # pyre-ignore[16]
        seq_frame_indices = self._seq_to_idx[seq_name]
        nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
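For context, here is a standalone sketch of the OmegaConf failure that _GenericWorkaround sidesteps, and of the workaround itself. The class names (Broken, PlainBase, Fixed) are illustrative only; the behaviour depends on the OmegaConf version targeted here:

    from dataclasses import dataclass

    import torch
    from omegaconf import OmegaConf


    @dataclass(eq=False)
    class Broken(torch.utils.data.Dataset[int]):
        a: int = 3


    # Fails: OmegaConf's get_dict_key_value_types chokes when the first base
    # class is a Generic other than Dict.
    # OmegaConf.structured(Broken)


    class PlainBase:  # plays the role of _GenericWorkaround above
        pass


    @dataclass(eq=False)
    class Fixed(PlainBase, torch.utils.data.Dataset[int]):
        a: int = 3


    # Putting a plain class first in the base list avoids the problem,
    # which is exactly how _GenericWorkaround is used by DatasetBase.
    OmegaConf.structured(Fixed)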

View File

@@ -13,12 +13,12 @@ import os
 import random
 import warnings
 from collections import defaultdict
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from itertools import islice
 from pathlib import Path
 from typing import (
+    Any,
     ClassVar,
-    Dict,
     List,
     Optional,
     Sequence,
@@ -30,7 +30,6 @@ from typing import (
 
 import numpy as np
 import torch
-from iopath.common.file_io import PathManager
 from PIL import Image
 from pytorch3d.io import IO
 from pytorch3d.renderer.cameras import PerspectiveCameras
@@ -116,7 +115,7 @@ class JsonIndexDataset(DatasetBase):
         Type[types.FrameAnnotation]
     ] = types.FrameAnnotation
-    path_manager: Optional[PathManager] = None
+    path_manager: Any = None
     frame_annotations_file: str = ""
     sequence_annotations_file: str = ""
     subset_lists_file: str = ""
@@ -135,18 +134,18 @@ class JsonIndexDataset(DatasetBase):
     max_points: int = 0
     mask_images: bool = False
     mask_depths: bool = False
-    image_height: Optional[int] = 256
-    image_width: Optional[int] = 256
-    box_crop: bool = False
+    image_height: Optional[int] = 800
+    image_width: Optional[int] = 800
+    box_crop: bool = True
     box_crop_mask_thr: float = 0.4
-    box_crop_context: float = 1.0
-    remove_empty_masks: bool = False
+    box_crop_context: float = 0.3
+    remove_empty_masks: bool = True
     n_frames_per_sequence: int = -1
     seed: int = 0
     sort_frames: bool = False
-    eval_batches: Optional[List[List[int]]] = None
-    frame_annots: List[FrameAnnotsEntry] = field(init=False)
-    seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
+    eval_batches: Any = None
+    # frame_annots: List[FrameAnnotsEntry] = field(init=False)
+    # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
 
     def __post_init__(self) -> None:
         # pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
@@ -172,9 +171,11 @@ class JsonIndexDataset(DatasetBase):
         # TODO: check the frame numbers are unique
         _dataset_seq_frame_n_index = {
             seq: {
+                # pyre-ignore[16]
                 self.frame_annots[idx]["frame_annotation"].frame_number: idx
                 for idx in seq_idx
             }
+            # pyre-ignore[16]
             for seq, seq_idx in self._seq_to_idx.items()
         }
@@ -184,6 +185,7 @@ class JsonIndexDataset(DatasetBase):
                 # Check that the loaded frame path is consistent
                 # with the one stored in self.frame_annots.
                 assert os.path.normpath(
+                    # pyre-ignore[16]
                     self.frame_annots[idx]["frame_annotation"].image.path
                 ) == os.path.normpath(
                     path
@@ -194,19 +196,23 @@ class JsonIndexDataset(DatasetBase):
         return batches_idx
 
     def __str__(self) -> str:
+        # pyre-ignore[16]
         return f"JsonIndexDataset #frames={len(self.frame_annots)}"
 
     def __len__(self) -> int:
+        # pyre-ignore[16]
         return len(self.frame_annots)
 
     def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
         return entry["subset"]
 
     def __getitem__(self, index) -> FrameData:
+        # pyre-ignore[16]
         if index >= len(self.frame_annots):
             raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
 
         entry = self.frame_annots[index]["frame_annotation"]
+        # pyre-ignore[16]
         point_cloud = self.seq_annots[entry.sequence_name].point_cloud
         frame_data = FrameData(
             frame_number=_safe_as_tensor(entry.frame_number, torch.long),
@@ -441,6 +447,7 @@ class JsonIndexDataset(DatasetBase):
         )
         if not frame_annots_list:
             raise ValueError("Empty dataset!")
+        # pyre-ignore[16]
         self.frame_annots = [
             FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
         ]
@@ -452,6 +459,7 @@ class JsonIndexDataset(DatasetBase):
             seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
         if not seq_annots:
             raise ValueError("Empty sequences file!")
+        # pyre-ignore[16]
        self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
 
     def _load_subset_lists(self) -> None:
@@ -467,7 +475,7 @@ class JsonIndexDataset(DatasetBase):
             for subset, frames in subset_to_seq_frame.items()
             for _, _, path in frames
         }
-
+        # pyre-ignore[16]
         for frame in self.frame_annots:
             frame["subset"] = frame_path_to_subset.get(
                 frame["frame_annotation"].image.path, None
@@ -480,6 +488,7 @@ class JsonIndexDataset(DatasetBase):
     def _sort_frames(self) -> None:
         # Sort frames to have them grouped by sequence, ordered by timestamp
+        # pyre-ignore[16]
         self.frame_annots = sorted(
             self.frame_annots,
             key=lambda f: (
@@ -491,6 +500,7 @@ class JsonIndexDataset(DatasetBase):
     def _filter_db(self) -> None:
         if self.remove_empty_masks:
             logger.info("Removing images with empty masks.")
+            # pyre-ignore[16]
             old_len = len(self.frame_annots)
 
             msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
@@ -531,6 +541,7 @@ class JsonIndexDataset(DatasetBase):
         if len(self.limit_category_to) > 0:
             logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
+            # pyre-ignore[16]
             self.seq_annots = {
                 name: entry
                 for name, entry in self.seq_annots.items()
@@ -568,6 +579,7 @@ class JsonIndexDataset(DatasetBase):
         if self.n_frames_per_sequence > 0:
             logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
             keep_idx = []
+            # pyre-ignore[16]
             for seq, seq_indices in self._seq_to_idx.items():
                 # infer the seed from the sequence name, this is reproducible
                 # and makes the selection differ for different sequences
@@ -597,14 +609,20 @@ class JsonIndexDataset(DatasetBase):
             self._invalidate_seq_to_idx()
 
         if filter_seq_annots:
+            # pyre-ignore[16]
             self.seq_annots = {
-                k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx
+                k: v
+                for k, v in self.seq_annots.items()
+                # pyre-ignore[16]
+                if k in self._seq_to_idx
             }
 
     def _invalidate_seq_to_idx(self) -> None:
         seq_to_idx = defaultdict(list)
+        # pyre-ignore[16]
         for idx, entry in enumerate(self.frame_annots):
             seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
+        # pyre-ignore[16]
         self._seq_to_idx = seq_to_idx
 
     def _resize_image(
@@ -644,6 +662,7 @@ class JsonIndexDataset(DatasetBase):
     ) -> List[Tuple[int, float]]:
         out: List[Tuple[int, float]] = []
         for idx in idxs:
+            # pyre-ignore[16]
            frame_annotation = self.frame_annots[idx]["frame_annotation"]
            out.append(
                (frame_annotation.frame_number, frame_annotation.frame_timestamp)
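A note on the typing changes above: OmegaConf's structured configs only accept a narrow set of field types, which is why path_manager and eval_batches are now annotated as Any, and why the runtime-only members (frame_annots, seq_annots, _seq_to_idx) are left as commented-out declarations with pyre-ignore[16] hints at their uses. A minimal sketch of the Any escape hatch (class name is illustrative, not from this commit):

    from dataclasses import dataclass
    from typing import Any

    from omegaconf import OmegaConf


    @dataclass
    class LooselyTyped:
        # An Any annotation keeps the field out of OmegaConf's type
        # validation, so values such as nested lists or a PathManager can
        # be assigned at runtime without breaking OmegaConf.structured.
        eval_batches: Any = None


    cfg = OmegaConf.structured(LooselyTyped)
    assert cfg.eval_batches is None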

View File

@@ -44,6 +44,7 @@ def get_implicitron_sequence_pointcloud(
     sequence_entries = [
         ei
         for ei in sequence_entries
+        # pyre-ignore[16]
         if dataset.frame_annots[ei]["frame_annotation"].sequence_name
         == sequence_name
     ]

View File

@@ -9,6 +9,7 @@ import unittest
 
 from omegaconf import OmegaConf
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
+from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 from pytorch3d.implicitron.tools.config import get_default_args
 from tests.common_testing import get_tests_dir
 
@@ -20,6 +21,33 @@ class TestDataSource(unittest.TestCase):
     def setUp(self):
         self.maxDiff = None
 
+    def _test_omegaconf_generic_failure(self):
+        # OmegaConf possible bug - this is why we need _GenericWorkaround
+        from dataclasses import dataclass
+
+        import torch
+
+        @dataclass
+        class D(torch.utils.data.Dataset[int]):
+            a: int = 3
+
+        OmegaConf.structured(D)
+
+    def _test_omegaconf_ListList(self):
+        # Demo that OmegaConf doesn't support nested lists
+        from dataclasses import dataclass
+        from typing import Sequence
+
+        @dataclass
+        class A:
+            a: Sequence[Sequence[int]] = ((32,),)
+
+        OmegaConf.structured(A)
+
+    def test_JsonIndexDataset_args(self):
+        # test that JsonIndexDataset works with get_default_args
+        get_default_args(JsonIndexDataset)
+
     def test_one(self):
         with unittest.mock.patch.dict(os.environ, {"CO3D_DATASET_ROOT": ""}):
             cfg = get_default_args(ImplicitronDataSource)

View File

@@ -51,6 +51,7 @@ class TestEvaluation(unittest.TestCase):
             image_height=self.image_size,
             image_width=self.image_size,
             box_crop=True,
+            remove_empty_masks=False,
             path_manager=path_manager,
         )
         self.bg_color = (0.0, 0.0, 0.0)
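Because remove_empty_masks now defaults to True, the test above pins it back to False explicitly. The same kind of override on the config side would look roughly like this sketch, reusing get_default_args as in the new test (the specific overrides are illustrative):

    from omegaconf import OmegaConf

    from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
    from pytorch3d.implicitron.tools.config import get_default_args

    cfg = get_default_args(JsonIndexDataset)
    cfg.remove_empty_masks = False  # back to the old behaviour for small test data
    cfg.image_height = 128          # shrink the new 800px default for speed
    cfg.image_width = 128
    print(OmegaConf.to_yaml(cfg))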