mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
allow get_default_args(JsonIndexDataset)
Summary: Changes to JsonIndexDataset to make it fit with OmegaConf.structured. Also match some default values to what the provider defaults to. Reviewed By: davnov134 Differential Revision: D36666704 fbshipit-source-id: 65b059a1dbaa240ce85c3e8762b7c3db3b5a6e75
This commit is contained in:
parent
8bc0a04e86
commit
1fb268dea6
@ -8,7 +8,6 @@ from collections import defaultdict
|
|||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
|
||||||
Iterable,
|
Iterable,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
@ -182,8 +181,28 @@ class FrameData(Mapping[str, Any]):
|
|||||||
return torch.utils.data._utils.collate.default_collate(batch)
|
return torch.utils.data._utils.collate.default_collate(batch)
|
||||||
|
|
||||||
|
|
||||||
|
class _GenericWorkaround:
|
||||||
|
"""
|
||||||
|
OmegaConf.structured has a weirdness when you try to apply
|
||||||
|
it to a dataclass whose first base class is a Generic which is not
|
||||||
|
Dict. The issue is with a function called get_dict_key_value_types
|
||||||
|
in omegaconf/_utils.py.
|
||||||
|
For example this fails:
|
||||||
|
|
||||||
|
@dataclass(eq=False)
|
||||||
|
class D(torch.utils.data.Dataset[int]):
|
||||||
|
a: int = 3
|
||||||
|
|
||||||
|
OmegaConf.structured(D)
|
||||||
|
|
||||||
|
We avoid the problem by adding this class as an extra base class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=False)
|
@dataclass(eq=False)
|
||||||
class DatasetBase(torch.utils.data.Dataset[FrameData]):
|
class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
|
||||||
"""
|
"""
|
||||||
Base class to describe a dataset to be used with Implicitron.
|
Base class to describe a dataset to be used with Implicitron.
|
||||||
|
|
||||||
@ -195,10 +214,11 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]):
|
|||||||
which will describe one frame in one sequence.
|
which will describe one frame in one sequence.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Maps sequence name to the sequence's global frame indices.
|
# _seq_to_idx is a member which implementations can define.
|
||||||
|
# It maps sequence name to the sequence's global frame indices.
|
||||||
# It is used for the default implementations of some functions in this class.
|
# It is used for the default implementations of some functions in this class.
|
||||||
# Implementations which override them are free to ignore this member.
|
# Implementations which override them are free to ignore it.
|
||||||
_seq_to_idx: Dict[str, List[int]] = field(init=False)
|
# _seq_to_idx: Dict[str, List[int]] = field(init=False)
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@ -232,6 +252,7 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]):
|
|||||||
|
|
||||||
def sequence_names(self) -> Iterable[str]:
|
def sequence_names(self) -> Iterable[str]:
|
||||||
"""Returns an iterator over sequence names in the dataset."""
|
"""Returns an iterator over sequence names in the dataset."""
|
||||||
|
# pyre-ignore[16]
|
||||||
return self._seq_to_idx.keys()
|
return self._seq_to_idx.keys()
|
||||||
|
|
||||||
def sequence_frames_in_order(
|
def sequence_frames_in_order(
|
||||||
@ -250,6 +271,7 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]):
|
|||||||
`dataset_idx` is the index within the dataset.
|
`dataset_idx` is the index within the dataset.
|
||||||
`None` timestamps are replaced with 0s.
|
`None` timestamps are replaced with 0s.
|
||||||
"""
|
"""
|
||||||
|
# pyre-ignore[16]
|
||||||
seq_frame_indices = self._seq_to_idx[seq_name]
|
seq_frame_indices = self._seq_to_idx[seq_name]
|
||||||
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices)
|
||||||
|
|
||||||
|
@ -13,12 +13,12 @@ import os
|
|||||||
import random
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
|
Any,
|
||||||
ClassVar,
|
ClassVar,
|
||||||
Dict,
|
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
@ -30,7 +30,6 @@ from typing import (
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from iopath.common.file_io import PathManager
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pytorch3d.io import IO
|
from pytorch3d.io import IO
|
||||||
from pytorch3d.renderer.cameras import PerspectiveCameras
|
from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||||
@ -116,7 +115,7 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
Type[types.FrameAnnotation]
|
Type[types.FrameAnnotation]
|
||||||
] = types.FrameAnnotation
|
] = types.FrameAnnotation
|
||||||
|
|
||||||
path_manager: Optional[PathManager] = None
|
path_manager: Any = None
|
||||||
frame_annotations_file: str = ""
|
frame_annotations_file: str = ""
|
||||||
sequence_annotations_file: str = ""
|
sequence_annotations_file: str = ""
|
||||||
subset_lists_file: str = ""
|
subset_lists_file: str = ""
|
||||||
@ -135,18 +134,18 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
max_points: int = 0
|
max_points: int = 0
|
||||||
mask_images: bool = False
|
mask_images: bool = False
|
||||||
mask_depths: bool = False
|
mask_depths: bool = False
|
||||||
image_height: Optional[int] = 256
|
image_height: Optional[int] = 800
|
||||||
image_width: Optional[int] = 256
|
image_width: Optional[int] = 800
|
||||||
box_crop: bool = False
|
box_crop: bool = True
|
||||||
box_crop_mask_thr: float = 0.4
|
box_crop_mask_thr: float = 0.4
|
||||||
box_crop_context: float = 1.0
|
box_crop_context: float = 0.3
|
||||||
remove_empty_masks: bool = False
|
remove_empty_masks: bool = True
|
||||||
n_frames_per_sequence: int = -1
|
n_frames_per_sequence: int = -1
|
||||||
seed: int = 0
|
seed: int = 0
|
||||||
sort_frames: bool = False
|
sort_frames: bool = False
|
||||||
eval_batches: Optional[List[List[int]]] = None
|
eval_batches: Any = None
|
||||||
frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
# frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
||||||
seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
# seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
|
# pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`.
|
||||||
@ -172,9 +171,11 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
# TODO: check the frame numbers are unique
|
# TODO: check the frame numbers are unique
|
||||||
_dataset_seq_frame_n_index = {
|
_dataset_seq_frame_n_index = {
|
||||||
seq: {
|
seq: {
|
||||||
|
# pyre-ignore[16]
|
||||||
self.frame_annots[idx]["frame_annotation"].frame_number: idx
|
self.frame_annots[idx]["frame_annotation"].frame_number: idx
|
||||||
for idx in seq_idx
|
for idx in seq_idx
|
||||||
}
|
}
|
||||||
|
# pyre-ignore[16]
|
||||||
for seq, seq_idx in self._seq_to_idx.items()
|
for seq, seq_idx in self._seq_to_idx.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -184,6 +185,7 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
# Check that the loaded frame path is consistent
|
# Check that the loaded frame path is consistent
|
||||||
# with the one stored in self.frame_annots.
|
# with the one stored in self.frame_annots.
|
||||||
assert os.path.normpath(
|
assert os.path.normpath(
|
||||||
|
# pyre-ignore[16]
|
||||||
self.frame_annots[idx]["frame_annotation"].image.path
|
self.frame_annots[idx]["frame_annotation"].image.path
|
||||||
) == os.path.normpath(
|
) == os.path.normpath(
|
||||||
path
|
path
|
||||||
@ -194,19 +196,23 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
return batches_idx
|
return batches_idx
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
|
# pyre-ignore[16]
|
||||||
return f"JsonIndexDataset #frames={len(self.frame_annots)}"
|
return f"JsonIndexDataset #frames={len(self.frame_annots)}"
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
|
# pyre-ignore[16]
|
||||||
return len(self.frame_annots)
|
return len(self.frame_annots)
|
||||||
|
|
||||||
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
|
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
|
||||||
return entry["subset"]
|
return entry["subset"]
|
||||||
|
|
||||||
def __getitem__(self, index) -> FrameData:
|
def __getitem__(self, index) -> FrameData:
|
||||||
|
# pyre-ignore[16]
|
||||||
if index >= len(self.frame_annots):
|
if index >= len(self.frame_annots):
|
||||||
raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
|
raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
|
||||||
|
|
||||||
entry = self.frame_annots[index]["frame_annotation"]
|
entry = self.frame_annots[index]["frame_annotation"]
|
||||||
|
# pyre-ignore[16]
|
||||||
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
|
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
|
||||||
frame_data = FrameData(
|
frame_data = FrameData(
|
||||||
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
|
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
|
||||||
@ -441,6 +447,7 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
)
|
)
|
||||||
if not frame_annots_list:
|
if not frame_annots_list:
|
||||||
raise ValueError("Empty dataset!")
|
raise ValueError("Empty dataset!")
|
||||||
|
# pyre-ignore[16]
|
||||||
self.frame_annots = [
|
self.frame_annots = [
|
||||||
FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
|
FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
|
||||||
]
|
]
|
||||||
@ -452,6 +459,7 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
|
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
|
||||||
if not seq_annots:
|
if not seq_annots:
|
||||||
raise ValueError("Empty sequences file!")
|
raise ValueError("Empty sequences file!")
|
||||||
|
# pyre-ignore[16]
|
||||||
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
|
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
|
||||||
|
|
||||||
def _load_subset_lists(self) -> None:
|
def _load_subset_lists(self) -> None:
|
||||||
@ -467,7 +475,7 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
for subset, frames in subset_to_seq_frame.items()
|
for subset, frames in subset_to_seq_frame.items()
|
||||||
for _, _, path in frames
|
for _, _, path in frames
|
||||||
}
|
}
|
||||||
|
# pyre-ignore[16]
|
||||||
for frame in self.frame_annots:
|
for frame in self.frame_annots:
|
||||||
frame["subset"] = frame_path_to_subset.get(
|
frame["subset"] = frame_path_to_subset.get(
|
||||||
frame["frame_annotation"].image.path, None
|
frame["frame_annotation"].image.path, None
|
||||||
@ -480,6 +488,7 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
|
|
||||||
def _sort_frames(self) -> None:
|
def _sort_frames(self) -> None:
|
||||||
# Sort frames to have them grouped by sequence, ordered by timestamp
|
# Sort frames to have them grouped by sequence, ordered by timestamp
|
||||||
|
# pyre-ignore[16]
|
||||||
self.frame_annots = sorted(
|
self.frame_annots = sorted(
|
||||||
self.frame_annots,
|
self.frame_annots,
|
||||||
key=lambda f: (
|
key=lambda f: (
|
||||||
@ -491,6 +500,7 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
def _filter_db(self) -> None:
|
def _filter_db(self) -> None:
|
||||||
if self.remove_empty_masks:
|
if self.remove_empty_masks:
|
||||||
logger.info("Removing images with empty masks.")
|
logger.info("Removing images with empty masks.")
|
||||||
|
# pyre-ignore[16]
|
||||||
old_len = len(self.frame_annots)
|
old_len = len(self.frame_annots)
|
||||||
|
|
||||||
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
|
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
|
||||||
@ -531,6 +541,7 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
|
|
||||||
if len(self.limit_category_to) > 0:
|
if len(self.limit_category_to) > 0:
|
||||||
logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
|
logger.info(f"Limiting dataset to categories: {self.limit_category_to}")
|
||||||
|
# pyre-ignore[16]
|
||||||
self.seq_annots = {
|
self.seq_annots = {
|
||||||
name: entry
|
name: entry
|
||||||
for name, entry in self.seq_annots.items()
|
for name, entry in self.seq_annots.items()
|
||||||
@ -568,6 +579,7 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
if self.n_frames_per_sequence > 0:
|
if self.n_frames_per_sequence > 0:
|
||||||
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
|
logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.")
|
||||||
keep_idx = []
|
keep_idx = []
|
||||||
|
# pyre-ignore[16]
|
||||||
for seq, seq_indices in self._seq_to_idx.items():
|
for seq, seq_indices in self._seq_to_idx.items():
|
||||||
# infer the seed from the sequence name, this is reproducible
|
# infer the seed from the sequence name, this is reproducible
|
||||||
# and makes the selection differ for different sequences
|
# and makes the selection differ for different sequences
|
||||||
@ -597,14 +609,20 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
self._invalidate_seq_to_idx()
|
self._invalidate_seq_to_idx()
|
||||||
|
|
||||||
if filter_seq_annots:
|
if filter_seq_annots:
|
||||||
|
# pyre-ignore[16]
|
||||||
self.seq_annots = {
|
self.seq_annots = {
|
||||||
k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx
|
k: v
|
||||||
|
for k, v in self.seq_annots.items()
|
||||||
|
# pyre-ignore[16]
|
||||||
|
if k in self._seq_to_idx
|
||||||
}
|
}
|
||||||
|
|
||||||
def _invalidate_seq_to_idx(self) -> None:
|
def _invalidate_seq_to_idx(self) -> None:
|
||||||
seq_to_idx = defaultdict(list)
|
seq_to_idx = defaultdict(list)
|
||||||
|
# pyre-ignore[16]
|
||||||
for idx, entry in enumerate(self.frame_annots):
|
for idx, entry in enumerate(self.frame_annots):
|
||||||
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
|
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
|
||||||
|
# pyre-ignore[16]
|
||||||
self._seq_to_idx = seq_to_idx
|
self._seq_to_idx = seq_to_idx
|
||||||
|
|
||||||
def _resize_image(
|
def _resize_image(
|
||||||
@ -644,6 +662,7 @@ class JsonIndexDataset(DatasetBase):
|
|||||||
) -> List[Tuple[int, float]]:
|
) -> List[Tuple[int, float]]:
|
||||||
out: List[Tuple[int, float]] = []
|
out: List[Tuple[int, float]] = []
|
||||||
for idx in idxs:
|
for idx in idxs:
|
||||||
|
# pyre-ignore[16]
|
||||||
frame_annotation = self.frame_annots[idx]["frame_annotation"]
|
frame_annotation = self.frame_annots[idx]["frame_annotation"]
|
||||||
out.append(
|
out.append(
|
||||||
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
|
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
|
||||||
|
@ -44,6 +44,7 @@ def get_implicitron_sequence_pointcloud(
|
|||||||
sequence_entries = [
|
sequence_entries = [
|
||||||
ei
|
ei
|
||||||
for ei in sequence_entries
|
for ei in sequence_entries
|
||||||
|
# pyre-ignore[16]
|
||||||
if dataset.frame_annots[ei]["frame_annotation"].sequence_name
|
if dataset.frame_annots[ei]["frame_annotation"].sequence_name
|
||||||
== sequence_name
|
== sequence_name
|
||||||
]
|
]
|
||||||
|
@ -9,6 +9,7 @@ import unittest
|
|||||||
|
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||||
|
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||||
from pytorch3d.implicitron.tools.config import get_default_args
|
from pytorch3d.implicitron.tools.config import get_default_args
|
||||||
from tests.common_testing import get_tests_dir
|
from tests.common_testing import get_tests_dir
|
||||||
|
|
||||||
@ -20,6 +21,33 @@ class TestDataSource(unittest.TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.maxDiff = None
|
self.maxDiff = None
|
||||||
|
|
||||||
|
def _test_omegaconf_generic_failure(self):
|
||||||
|
# OmegaConf possible bug - this is why we need _GenericWorkaround
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class D(torch.utils.data.Dataset[int]):
|
||||||
|
a: int = 3
|
||||||
|
|
||||||
|
OmegaConf.structured(D)
|
||||||
|
|
||||||
|
def _test_omegaconf_ListList(self):
|
||||||
|
# Demo that OmegaConf doesn't support nested lists
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class A:
|
||||||
|
a: Sequence[Sequence[int]] = ((32,),)
|
||||||
|
|
||||||
|
OmegaConf.structured(A)
|
||||||
|
|
||||||
|
def test_JsonIndexDataset_args(self):
|
||||||
|
# test that JsonIndexDataset works with get_default_args
|
||||||
|
get_default_args(JsonIndexDataset)
|
||||||
|
|
||||||
def test_one(self):
|
def test_one(self):
|
||||||
with unittest.mock.patch.dict(os.environ, {"CO3D_DATASET_ROOT": ""}):
|
with unittest.mock.patch.dict(os.environ, {"CO3D_DATASET_ROOT": ""}):
|
||||||
cfg = get_default_args(ImplicitronDataSource)
|
cfg = get_default_args(ImplicitronDataSource)
|
||||||
|
@ -51,6 +51,7 @@ class TestEvaluation(unittest.TestCase):
|
|||||||
image_height=self.image_size,
|
image_height=self.image_size,
|
||||||
image_width=self.image_size,
|
image_width=self.image_size,
|
||||||
box_crop=True,
|
box_crop=True,
|
||||||
|
remove_empty_masks=False,
|
||||||
path_manager=path_manager,
|
path_manager=path_manager,
|
||||||
)
|
)
|
||||||
self.bg_color = (0.0, 0.0, 0.0)
|
self.bg_color = (0.0, 0.0, 0.0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user