JsonIndexDatasetProviderV2

Summary: A new version of the json index dataset provider supporting CO3Dv2.

Reviewed By: shapovalov

Differential Revision: D37690918

fbshipit-source-id: bf2d5fc9d0f1220259e08661dafc69cdbe6b7f94

Commit e8390d3500 (parent 4300030d7a)
pytorch3d/implicitron/dataset/json_index_dataset.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
 import functools
 import gzip
 import hashlib
@@ -255,6 +256,47 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
 
         return dataset_idx
 
+    def subset_from_frame_index(
+        self,
+        frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
+        allow_missing_indices: bool = True,
+    ) -> "JsonIndexDataset":
+        # Get the indices into the frame annots.
+        dataset_indices = self.seq_frame_index_to_dataset_index(
+            [frame_index],
+            allow_missing_indices=self.is_filtered() and allow_missing_indices,
+        )[0]
+        valid_dataset_indices = [i for i in dataset_indices if i is not None]
+
+        # Deep copy the whole dataset except frame_annots, which are large so we
+        # deep copy only the requested subset of frame_annots.
+        memo = {id(self.frame_annots): None}  # pyre-ignore[16]
+        dataset_new = copy.deepcopy(self, memo)
+        dataset_new.frame_annots = copy.deepcopy(
+            [self.frame_annots[i] for i in valid_dataset_indices]
+        )
+
+        # This will kill all unneeded sequence annotations.
+        dataset_new._invalidate_indexes(filter_seq_annots=True)
+
+        # Finally annotate the frame annotations with the name of the subset
+        # stored in meta.
+        for frame_annot in dataset_new.frame_annots:
+            frame_annotation = frame_annot["frame_annotation"]
+            if frame_annotation.meta is not None:
+                frame_annot["subset"] = frame_annotation.meta.get("frame_type", None)
+
+        # A sanity check - this will crash in case some entries from frame_index
+        # are missing in dataset_new.
+        valid_frame_index = [
+            fi for fi, di in zip(frame_index, dataset_indices) if di is not None
+        ]
+        dataset_new.seq_frame_index_to_dataset_index(
+            [valid_frame_index], allow_missing_indices=False
+        )
+
+        return dataset_new
+
     def __str__(self) -> str:
         # pyre-ignore[16]
         return f"JsonIndexDataset #frames={len(self.frame_annots)}"
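The `memo = {id(self.frame_annots): None}` line above uses a less common feature of `copy.deepcopy`: pre-seeding the memo dict with an `id(obj) -> replacement` entry makes the copy substitute the replacement instead of recursing into `obj`. A minimal sketch of the pattern; the `Holder` class here is illustrative and not part of PyTorch3D:

```python
import copy
from dataclasses import dataclass, field


@dataclass
class Holder:
    # stands in for JsonIndexDataset: one huge member plus small ones
    big: list = field(default_factory=lambda: list(range(1_000_000)))
    small: dict = field(default_factory=lambda: {"k": "v"})


h = Holder()
# Pre-seeding deepcopy's memo maps the big member's id to None, so the copy
# substitutes None instead of copying a million elements; the wanted subset
# can then be deep-copied separately, as subset_from_frame_index does.
memo = {id(h.big): None}
h2 = copy.deepcopy(h, memo)
assert h2.big is None  # placeholder, not a copy
assert h2.small == h.small and h2.small is not h.small  # still deep-copied
```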
@@ -92,6 +92,9 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
             to use for the dataset.
         dataset_X_args (e.g. dataset_JsonIndexDataset_args): arguments passed
             to all the dataset constructors.
+        path_manager_factory: (Optional) An object that generates an instance of
+            PathManager that can translate provided file paths.
+        path_manager_factory_class_type: The class type of `path_manager_factory`.
     """
 
     category: str
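The `dataset_X_args` convention documented here is resolved by name at runtime; the V2 provider added below does the same lookup via `getattr`. A one-line illustration:

```python
# With dataset_class_type = "JsonIndexDataset", the provider reads the field
# dataset_JsonIndexDataset_args and forwards its contents to the constructor.
dataset_class_type = "JsonIndexDataset"
assert f"dataset_{dataset_class_type}_args" == "dataset_JsonIndexDataset_args"
```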
pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py (new file)
@@ -0,0 +1,343 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import json
+import logging
+import os
+import warnings
+from typing import Dict, List, Type
+
+from pytorch3d.implicitron.dataset.dataset_map_provider import (
+    DatasetMap,
+    DatasetMapProviderBase,
+    PathManagerFactory,
+)
+from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
+from pytorch3d.implicitron.tools.config import (
+    expand_args_fields,
+    registry,
+    run_auto_creation,
+)
+
+
+_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
+
+
+logger = logging.getLogger(__name__)
+
+
+@registry.register
+class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
+    """
+    Generates the training, validation, and testing dataset objects for
+    a dataset laid out on disk like CO3Dv2, with annotations in gzipped json files.
+
+    The dataset is organized in the filesystem as follows:
+    ```
+    self.dataset_root
+        ├── <category_0>
+        │   ├── <sequence_name_0>
+        │   │   ├── depth_masks
+        │   │   ├── depths
+        │   │   ├── images
+        │   │   ├── masks
+        │   │   └── pointcloud.ply
+        │   ├── <sequence_name_1>
+        │   │   ├── depth_masks
+        │   │   ├── depths
+        │   │   ├── images
+        │   │   ├── masks
+        │   │   └── pointcloud.ply
+        │   ├── ...
+        │   ├── <sequence_name_N>
+        │   ├── set_lists
+        │       ├── set_lists_<subset_name_0>.json
+        │       ├── set_lists_<subset_name_1>.json
+        │       ├── ...
+        │       ├── set_lists_<subset_name_M>.json
+        │   ├── eval_batches
+        │   │   ├── eval_batches_<subset_name_0>.json
+        │   │   ├── eval_batches_<subset_name_1>.json
+        │   │   ├── ...
+        │   │   ├── eval_batches_<subset_name_M>.json
+        │   ├── frame_annotations.jgz
+        │   ├── sequence_annotations.jgz
+        ├── <category_1>
+        ├── ...
+        ├── <category_K>
+    ```
+
+    The dataset contains sequences named `<sequence_name_i>` from `K` categories with
+    names `<category_j>`. Each category comprises sequence folders
+    `<category_k>/<sequence_name_i>` containing the list of sequence images, depth maps,
+    foreground masks, and valid-depth masks `images`, `depths`, `masks`, and `depth_masks`
+    respectively. Furthermore, `<category_k>/<sequence_name_i>/set_lists/` stores `M`
+    json files `set_lists_<subset_name_l>.json`, each describing a certain sequence subset.
+
+    Users specify the loaded dataset subset by setting `self.subset_name` to one of the
+    available subset names `<subset_name_l>`.
+
+    `frame_annotations.jgz` and `sequence_annotations.jgz` are gzipped json files containing
+    the list of all frames and sequences of the given category stored as lists of
+    `FrameAnnotation` and `SequenceAnnotation` objects respectively.
+
+    Each `set_lists_<subset_name_l>.json` file contains the following dictionary:
+    ```
+    {
+        "train": [
+            (sequence_name: str, frame_number: int, image_path: str),
+            ...
+        ],
+        "val": [
+            (sequence_name: str, frame_number: int, image_path: str),
+            ...
+        ],
+        "test": [
+            (sequence_name: str, frame_number: int, image_path: str),
+            ...
+        ],
+    }
+    ```
+    defining the list of frames (identified with their `sequence_name` and `frame_number`)
+    in the "train", "val", and "test" subsets of the dataset.
+    Note that `frame_number` can be obtained only from `frame_annotations.jgz` and
+    does not necessarily correspond to the numeric suffix of the corresponding image
+    file name (e.g. a file `<category_0>/<sequence_name_0>/images/frame00005.jpg` can
+    have its frame number set to `20`, not 5).
+
+    Each `eval_batches_<subset_name_l>.json` file contains a list of evaluation examples
+    in the following form:
+    ```
+    [
+        [  # batch 1
+            (sequence_name: str, frame_number: int, image_path: str),
+            ...
+        ],
+        [  # batch 2
+            (sequence_name: str, frame_number: int, image_path: str),
+            ...
+        ],
+    ]
+    ```
+    Note that the evaluation examples always come from the `"test"` subset of the
+    dataset (test frames can repeat across batches).
+
+    Args:
+        category: The object category of the dataset.
+        subset_name: The name of the dataset subset. For CO3Dv2, these include
+            e.g. "manyview_dev_0", "fewview_test", ...
+        dataset_root: The root folder of the dataset.
+        test_on_train: Construct validation and test datasets from
+            the training subset.
+        only_test_set: Load only the test set. Incompatible with `test_on_train`.
+        load_eval_batches: Load the file containing eval batches pointing to the
+            test dataset.
+        dataset_args: Specifies additional arguments to the
+            JsonIndexDataset constructor call.
+        path_manager_factory: (Optional) An object that generates an instance of
+            PathManager that can translate provided file paths.
+        path_manager_factory_class_type: The class type of `path_manager_factory`.
+    """
+
+    category: str
+    subset_name: str
+    dataset_root: str = _CO3DV2_DATASET_ROOT
+
+    test_on_train: bool = False
+    only_test_set: bool = False
+    load_eval_batches: bool = True
+
+    dataset_class_type: str = "JsonIndexDataset"
+    dataset: JsonIndexDataset
+
+    path_manager_factory: PathManagerFactory
+    path_manager_factory_class_type: str = "PathManagerFactory"
+
+    def __post_init__(self):
+        super().__init__()
+        run_auto_creation(self)
+
+        if self.only_test_set and self.test_on_train:
+            raise ValueError("Cannot have only_test_set and test_on_train")
+
+        frame_file = os.path.join(
+            self.dataset_root, self.category, "frame_annotations.jgz"
+        )
+        sequence_file = os.path.join(
+            self.dataset_root, self.category, "sequence_annotations.jgz"
+        )
+
+        path_manager = self.path_manager_factory.get()
+
+        # setup the common dataset arguments
+        common_dataset_kwargs = getattr(self, f"dataset_{self.dataset_class_type}_args")
+        common_dataset_kwargs = {
+            **common_dataset_kwargs,
+            "dataset_root": self.dataset_root,
+            "frame_annotations_file": frame_file,
+            "sequence_annotations_file": sequence_file,
+            "subsets": None,
+            "subset_lists_file": "",
+            "path_manager": path_manager,
+        }
+
+        # get the used dataset type
+        dataset_type: Type[JsonIndexDataset] = registry.get(
+            JsonIndexDataset, self.dataset_class_type
+        )
+        expand_args_fields(dataset_type)
+
+        dataset = dataset_type(**common_dataset_kwargs)
+
+        available_subset_names = self._get_available_subset_names()
+        logger.debug(f"Available subset names: {str(available_subset_names)}.")
+        if self.subset_name not in available_subset_names:
+            raise ValueError(
+                f"Unknown subset name {self.subset_name}."
+                + f" Choose one of available subsets: {str(available_subset_names)}."
+            )
+
+        # load the list of train/val/test frames
+        subset_mapping = self._load_annotation_json(
+            os.path.join(
+                self.category, "set_lists", f"set_lists_{self.subset_name}.json"
+            )
+        )
+
+        # load the evaluation batches
+        if self.load_eval_batches:
+            eval_batch_index = self._load_annotation_json(
+                os.path.join(
+                    self.category,
+                    "eval_batches",
+                    f"eval_batches_{self.subset_name}.json",
+                )
+            )
+
+        train_dataset = None
+        if not self.only_test_set:
+            # load the training set
+            logger.debug("Extracting train dataset.")
+            train_dataset = dataset.subset_from_frame_index(subset_mapping["train"])
+            logger.info(f"Train dataset: {str(train_dataset)}")
+
+        if self.test_on_train:
+            assert train_dataset is not None
+            val_dataset = test_dataset = train_dataset
+        else:
+            # load the val and test sets
+            logger.debug("Extracting val dataset.")
+            val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
+            logger.info(f"Val dataset: {str(val_dataset)}")
+            logger.debug("Extracting test dataset.")
+            test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
+            logger.info(f"Test dataset: {str(test_dataset)}")
+            if self.load_eval_batches:
+                # load the eval batches
+                logger.debug("Extracting eval batches.")
+                try:
+                    test_dataset.eval_batches = (
+                        test_dataset.seq_frame_index_to_dataset_index(
+                            eval_batch_index,
+                        )
+                    )
+                except IndexError:
+                    warnings.warn(
+                        "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n"
+                        + "Some eval batches are missing from the test dataset.\n"
+                        + "The evaluation results will be incomparable to the\n"
+                        + "evaluation results calculated on the original dataset.\n"
+                        + "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
+                    )
+                    test_dataset.eval_batches = (
+                        test_dataset.seq_frame_index_to_dataset_index(
+                            eval_batch_index,
+                            allow_missing_indices=True,
+                            remove_missing_indices=True,
+                        )
+                    )
+                logger.info(f"# eval batches: {len(test_dataset.eval_batches)}")
+
+        self.dataset_map = DatasetMap(
+            train=train_dataset, val=val_dataset, test=test_dataset
+        )
+
+    def create_dataset(self):
+        # The dataset object is created inside `self.get_dataset_map`
+        pass
+
+    def get_dataset_map(self) -> DatasetMap:
+        return self.dataset_map  # pyre-ignore [16]
+
+    def get_category_to_subset_name_list(self) -> Dict[str, List[str]]:
+        """
+        Returns a global dataset index containing the available subset names per category
+        as a dictionary.
+
+        Returns:
+            category_to_subset_name_list: A dictionary containing subset names available
+                per category of the following form:
+                ```
+                {
+                    category_0: [category_0_subset_name_0, category_0_subset_name_1, ...],
+                    category_1: [category_1_subset_name_0, category_1_subset_name_1, ...],
+                    ...
+                }
+                ```
+
+        """
+        category_to_subset_name_list_json = "category_to_subset_name_list.json"
+        category_to_subset_name_list = self._load_annotation_json(
+            category_to_subset_name_list_json
+        )
+        return category_to_subset_name_list
+
+    def _load_annotation_json(self, json_filename: str):
+        full_path = os.path.join(
+            self.dataset_root,
+            json_filename,
+        )
+        logger.info(f"Loading frame index json from {full_path}.")
+        path_manager = self.path_manager_factory.get()
+        if path_manager is not None:
+            full_path = path_manager.get_local_path(full_path)
+        if not os.path.isfile(full_path):
+            # The batch indices file does not exist.
+            # Most probably the user has not specified the root folder.
+            raise ValueError(
+                f"Looking for dataset json file in {full_path}. "
+                + "Please specify a correct dataset_root folder."
+            )
+        with open(full_path, "r") as f:
+            data = json.load(f)
+        return data
+
+    def _get_available_subset_names(self):
+        path_manager = self.path_manager_factory.get()
+        if path_manager is not None:
+            dataset_root = path_manager.get_local_path(self.dataset_root)
+        else:
+            dataset_root = self.dataset_root
+        return get_available_subset_names(dataset_root, self.category)
+
+
+def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
+    """
+    Get the available subset names for a given category folder inside a root dataset
+    folder `dataset_root`.
+    """
+    category_dir = os.path.join(dataset_root, category)
+    if not os.path.isdir(category_dir):
+        raise ValueError(
+            f"Looking for dataset files in {category_dir}. "
+            + "Please specify a correct dataset_root folder."
+        )
+    set_list_jsons = os.listdir(os.path.join(category_dir, "set_lists"))
+    return [
+        json_file.replace("set_lists_", "").replace(".json", "")
+        for json_file in set_list_jsons
+    ]
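A minimal usage sketch of the new provider, modeled on the unit test further below; the category, subset name, and dataset root are placeholders that assume a CO3Dv2-style tree on disk:

```python
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Process the configurable class so it can be instantiated directly.
expand_args_fields(JsonIndexDatasetMapProviderV2)
provider = JsonIndexDatasetMapProviderV2(
    category="apple",                # assumed CO3Dv2 category
    subset_name="manyview_dev_0",    # one of the subset names described above
    dataset_root="/path/to/co3dv2",  # placeholder root folder
)
dataset_map = provider.get_dataset_map()
train, val, test = dataset_map.train, dataset_map.val, dataset_map.test
print(test)  # e.g. "JsonIndexDataset #frames=..."
```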
@@ -42,6 +42,47 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args:
   sort_frames: false
   path_manager_factory_PathManagerFactory_args:
     silence_logs: true
+dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
+  category: ???
+  subset_name: ???
+  dataset_root: ''
+  test_on_train: false
+  only_test_set: false
+  load_eval_batches: true
+  dataset_class_type: JsonIndexDataset
+  path_manager_factory_class_type: PathManagerFactory
+  dataset_JsonIndexDataset_args:
+    path_manager: null
+    frame_annotations_file: ''
+    sequence_annotations_file: ''
+    subset_lists_file: ''
+    subsets: null
+    limit_to: 0
+    limit_sequences_to: 0
+    pick_sequence: []
+    exclude_sequence: []
+    limit_category_to: []
+    dataset_root: ''
+    load_images: true
+    load_depths: true
+    load_depth_masks: true
+    load_masks: true
+    load_point_clouds: false
+    max_points: 0
+    mask_images: false
+    mask_depths: false
+    image_height: 800
+    image_width: 800
+    box_crop: true
+    box_crop_mask_thr: 0.4
+    box_crop_context: 0.3
+    remove_empty_masks: true
+    n_frames_per_sequence: -1
+    seed: 0
+    sort_frames: false
+    eval_batches: null
+  path_manager_factory_PathManagerFactory_args:
+    silence_logs: true
 dataset_map_provider_LlffDatasetMapProvider_args:
   base_dir: ???
   object_name: ???
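The `???` markers denote required fields in the OmegaConf-style config above. A hedged sketch of filling them programmatically, assuming OmegaConf is the config backend and leaving the merge into the full experiment config as a comment:

```python
from omegaconf import OmegaConf

# Only category and subset_name are required (marked ??? above); the other
# fields fall back to the defaults shown in the yaml.
override = OmegaConf.create(
    {
        "dataset_map_provider_JsonIndexDatasetMapProviderV2_args": {
            "category": "apple",              # placeholder category
            "subset_name": "manyview_dev_0",  # placeholder subset
        }
    }
)
# cfg = OmegaConf.merge(cfg, override)  # cfg is the loaded experiment config
```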
tests/implicitron/test_json_index_dataset_provider_v2.py (new file, 155 lines)
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import json
+import os
+import random
+import tempfile
+import unittest
+from typing import List
+
+import numpy as np
+
+import torch
+import torchvision
+from PIL import Image
+from pytorch3d.implicitron.dataset.dataset_base import FrameData
+from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
+    JsonIndexDatasetMapProviderV2,
+)
+from pytorch3d.implicitron.dataset.types import (
+    dump_dataclass_jgzip,
+    FrameAnnotation,
+    ImageAnnotation,
+    MaskAnnotation,
+    SequenceAnnotation,
+)
+from pytorch3d.implicitron.tools.config import expand_args_fields
+
+
+class TestJsonIndexDatasetProviderV2(unittest.TestCase):
+    def test_random_dataset(self):
+        # store random frame annotations
+        expand_args_fields(JsonIndexDatasetMapProviderV2)
+        categories = ["A", "B"]
+        subset_name = "test"
+        with tempfile.TemporaryDirectory() as tmpd:
+            _make_random_json_dataset_map_provider_v2_data(tmpd, categories)
+            for category in categories:
+                dataset_provider = JsonIndexDatasetMapProviderV2(
+                    category=category,
+                    subset_name="test",
+                    dataset_root=tmpd,
+                )
+                dataset_map = dataset_provider.get_dataset_map()
+                for set_ in ["train", "val", "test"]:
+                    dataloader = torch.utils.data.DataLoader(
+                        getattr(dataset_map, set_),
+                        batch_size=3,
+                        shuffle=True,
+                        collate_fn=FrameData.collate,
+                    )
+                    for _ in dataloader:
+                        pass
+                category_to_subset_list = (
+                    dataset_provider.get_category_to_subset_name_list()
+                )
+                category_to_subset_list_ = {c: [subset_name] for c in categories}
+                self.assertTrue(category_to_subset_list == category_to_subset_list_)
+
+
+def _make_random_json_dataset_map_provider_v2_data(
+    root: str,
+    categories: List[str],
+    n_frames: int = 8,
+    n_sequences: int = 5,
+    H: int = 50,
+    W: int = 30,
+    subset_name: str = "test",
+):
+    os.makedirs(root, exist_ok=True)
+    category_to_subset_list = {}
+    for category in categories:
+        frame_annotations = []
+        sequence_annotations = []
+        frame_index = []
+        for seq_i in range(n_sequences):
+            seq_name = str(seq_i)
+            for i in range(n_frames):
+                # generate and store image
+                imdir = os.path.join(root, category, seq_name, "images")
+                os.makedirs(imdir, exist_ok=True)
+                img_path = os.path.join(imdir, f"frame{i:05d}.jpg")
+                img = torch.rand(3, H, W)
+                torchvision.utils.save_image(img, img_path)
+
+                # generate and store mask
+                maskdir = os.path.join(root, category, seq_name, "masks")
+                os.makedirs(maskdir, exist_ok=True)
+                mask_path = os.path.join(maskdir, f"frame{i:05d}.png")
+                mask = np.zeros((H, W))
+                mask[H // 2 :, W // 2 :] = 1
+                Image.fromarray((mask * 255.0).astype(np.uint8), mode="L").convert(
+                    "L"
+                ).save(mask_path)
+
+                fa = FrameAnnotation(
+                    sequence_name=seq_name,
+                    frame_number=i,
+                    frame_timestamp=float(i),
+                    image=ImageAnnotation(
+                        path=img_path.replace(os.path.normpath(root) + "/", ""),
+                        size=list(img.shape[-2:]),
+                    ),
+                    mask=MaskAnnotation(
+                        path=mask_path.replace(os.path.normpath(root) + "/", ""),
+                        mass=mask.sum().item(),
+                    ),
+                )
+                frame_annotations.append(fa)
+                frame_index.append((seq_name, i, fa.image.path))
+
+            sequence_annotations.append(
+                SequenceAnnotation(
+                    sequence_name=seq_name,
+                    category=category,
+                )
+            )
+
+        dump_dataclass_jgzip(
+            os.path.join(root, category, "frame_annotations.jgz"),
+            frame_annotations,
+        )
+        dump_dataclass_jgzip(
+            os.path.join(root, category, "sequence_annotations.jgz"),
+            sequence_annotations,
+        )
+
+        test_frame_index = frame_index[2::3]
+
+        set_list = {
+            "train": frame_index[0::3],
+            "val": frame_index[1::3],
+            "test": test_frame_index,
+        }
+        set_lists_dir = os.path.join(root, category, "set_lists")
+        os.makedirs(set_lists_dir, exist_ok=True)
+        set_list_file = os.path.join(set_lists_dir, f"set_lists_{subset_name}.json")
+        with open(set_list_file, "w") as f:
+            json.dump(set_list, f)
+
+        eval_batches = [random.sample(test_frame_index, 5) for _ in range(10)]
+        eval_b_dir = os.path.join(root, category, "eval_batches")
+        os.makedirs(eval_b_dir, exist_ok=True)
+        eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")
+        with open(eval_b_file, "w") as f:
+            json.dump(eval_batches, f)
+
+        category_to_subset_list[category] = [subset_name]
+
+    with open(os.path.join(root, "category_to_subset_name_list.json"), "w") as f:
+        json.dump(category_to_subset_list, f)
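The new test follows standard `unittest` conventions; a sketch of running just this module programmatically, assuming the working directory is a PyTorch3D checkout root so the dotted module path resolves:

```python
import unittest

# Equivalent to:
#   python -m unittest tests.implicitron.test_json_index_dataset_provider_v2
unittest.main(
    module="tests.implicitron.test_json_index_dataset_provider_v2",
    argv=["prog"],  # ignore the caller's command-line arguments
    exit=False,     # do not sys.exit() so this can run inside a session
)
```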