From e8390d35002650c84ac3c3ae60eb7136a272916f Mon Sep 17 00:00:00 2001
From: David Novotny
Date: Sat, 9 Jul 2022 17:16:24 -0700
Subject: [PATCH] JsonIndexDatasetProviderV2

Summary: A new version of the json index dataset provider, supporting CO3Dv2.

Reviewed By: shapovalov

Differential Revision: D37690918

fbshipit-source-id: bf2d5fc9d0f1220259e08661dafc69cdbe6b7f94
---
 .../implicitron/dataset/json_index_dataset.py |  42 +++
 .../json_index_dataset_map_provider.py        |   3 +
 .../json_index_dataset_map_provider_v2.py     | 343 ++++++++++++++++++
 tests/implicitron/data/data_source.yaml       |  41 +++
 .../test_json_index_dataset_provider_v2.py    | 155 ++++++++
 5 files changed, 584 insertions(+)
 create mode 100644 pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py
 create mode 100644 tests/implicitron/test_json_index_dataset_provider_v2.py

diff --git a/pytorch3d/implicitron/dataset/json_index_dataset.py b/pytorch3d/implicitron/dataset/json_index_dataset.py
index 17ba4cec..b1640a13 100644
--- a/pytorch3d/implicitron/dataset/json_index_dataset.py
+++ b/pytorch3d/implicitron/dataset/json_index_dataset.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
 import functools
 import gzip
 import hashlib
@@ -255,6 +256,47 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
 
         return dataset_idx
 
+    def subset_from_frame_index(
+        self,
+        frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
+        allow_missing_indices: bool = True,
+    ) -> "JsonIndexDataset":
+        # Get the indices into the frame annots.
+        dataset_indices = self.seq_frame_index_to_dataset_index(
+            [frame_index],
+            allow_missing_indices=self.is_filtered() and allow_missing_indices,
+        )[0]
+        valid_dataset_indices = [i for i in dataset_indices if i is not None]
+
+        # Deep-copy the whole dataset except frame_annots, which are large;
+        # only the requested subset of frame_annots is copied.
+        memo = {id(self.frame_annots): None}  # pyre-ignore[16]
+        dataset_new = copy.deepcopy(self, memo)
+        dataset_new.frame_annots = copy.deepcopy(
+            [self.frame_annots[i] for i in valid_dataset_indices]
+        )
+
+        # This drops all sequence annotations that are no longer referenced.
+        dataset_new._invalidate_indexes(filter_seq_annots=True)
+
+        # Finally, annotate each frame annotation with the name of the subset
+        # stored in its meta field.
+        for frame_annot in dataset_new.frame_annots:
+            frame_annotation = frame_annot["frame_annotation"]
+            if frame_annotation.meta is not None:
+                frame_annot["subset"] = frame_annotation.meta.get("frame_type", None)
+
+        # A sanity check: this will crash if any entries from frame_index are
+        # missing in dataset_new.
+        valid_frame_index = [
+            fi for fi, di in zip(frame_index, dataset_indices) if di is not None
+        ]
+        dataset_new.seq_frame_index_to_dataset_index(
+            [valid_frame_index], allow_missing_indices=False
+        )
+
+        return dataset_new
+
     def __str__(self) -> str:
         # pyre-ignore[16]
         return f"JsonIndexDataset #frames={len(self.frame_annots)}"
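The new `subset_from_frame_index` method is the building block the V2 provider relies on: it produces a copy of the dataset restricted to an explicit frame list. Below is a minimal usage sketch; the paths and the sequence name are hypothetical, and the construction pattern (`expand_args_fields` followed by keyword arguments) mirrors what the provider added later in this patch does.

```python
# Hypothetical usage sketch of subset_from_frame_index; paths and the sequence
# name are made up for illustration.
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(JsonIndexDataset)
dataset = JsonIndexDataset(
    dataset_root="/data/co3dv2",
    frame_annotations_file="/data/co3dv2/apple/frame_annotations.jgz",
    sequence_annotations_file="/data/co3dv2/apple/sequence_annotations.jgz",
)

# Entries are (sequence_name, frame_number) or
# (sequence_name, frame_number, image_path) tuples.
frame_index = [("110_13051_23361", 0), ("110_13051_23361", 10)]

# Returns a new JsonIndexDataset holding only the listed frames; the deepcopy
# memo skips the full frame_annots list, so the copy stays cheap.
subset = dataset.subset_from_frame_index(frame_index)
print(subset)  # JsonIndexDataset #frames=2
```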
diff --git a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py
index f6155baa..06f70231 100644
--- a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py
+++ b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py
@@ -92,6 +92,9 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
         to use for the dataset.
         dataset_X_args (e.g. dataset_JsonIndexDataset_args): arguments passed
             to all the dataset constructors.
+        path_manager_factory: (Optional) An object that generates an instance of
+            PathManager that can translate provided file paths.
+        path_manager_factory_class_type: The class type of `path_manager_factory`.
     """
 
     category: str
diff --git a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py
new file mode 100644
index 00000000..4b1b29a0
--- /dev/null
+++ b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py
@@ -0,0 +1,343 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import json
+import logging
+import os
+import warnings
+from typing import Dict, List, Type
+
+from pytorch3d.implicitron.dataset.dataset_map_provider import (
+    DatasetMap,
+    DatasetMapProviderBase,
+    PathManagerFactory,
+)
+from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
+from pytorch3d.implicitron.tools.config import (
+    expand_args_fields,
+    registry,
+    run_auto_creation,
+)
+
+
+_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
+
+
+logger = logging.getLogger(__name__)
+
+
+@registry.register
+class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
+    """
+    Generates the training, validation, and testing dataset objects for
+    a dataset laid out on disk like CO3Dv2, with annotations in gzipped json
+    files.
+
+    The dataset is organized in the filesystem as follows (angle brackets
+    denote placeholders):
+    ```
+    self.dataset_root
+        ├── <category_0>
+        │   ├── <sequence_name_0>
+        │   │   ├── depth_masks
+        │   │   ├── depths
+        │   │   ├── images
+        │   │   ├── masks
+        │   │   └── pointcloud.ply
+        │   ├── <sequence_name_1>
+        │   │   ├── depth_masks
+        │   │   ├── depths
+        │   │   ├── images
+        │   │   ├── masks
+        │   │   └── pointcloud.ply
+        │   ├── ...
+        │   ├── <sequence_name_N>
+        │   ├── set_lists
+        │   │   ├── set_lists_<subset_name_0>.json
+        │   │   ├── set_lists_<subset_name_1>.json
+        │   │   ├── ...
+        │   │   ├── set_lists_<subset_name_M>.json
+        │   ├── eval_batches
+        │   │   ├── eval_batches_<subset_name_0>.json
+        │   │   ├── eval_batches_<subset_name_1>.json
+        │   │   ├── ...
+        │   │   ├── eval_batches_<subset_name_M>.json
+        │   ├── frame_annotations.jgz
+        │   ├── sequence_annotations.jgz
+        ├── <category_1>
+        ├── ...
+        ├── <category_K>
+    ```
+
+    The dataset contains sequences named `<sequence_name_i>` from `K` categories
+    with names `<category_j>`. Each category comprises sequence folders
+    `<category_j>/<sequence_name_i>` containing the sequence images, depth maps,
+    foreground masks, and valid-depth masks in `images`, `depths`, `masks`, and
+    `depth_masks` respectively. Furthermore, `<category_j>/set_lists/` stores `M`
+    json files `set_lists_<subset_name_m>.json`, each describing a certain
+    sequence subset.
+
+    Users specify the loaded dataset subset by setting `self.subset_name` to one
+    of the available subset names `<subset_name_m>`.
+
+    `frame_annotations.jgz` and `sequence_annotations.jgz` are gzipped json files
+    containing the list of all frames and sequences of the given category, stored
+    as lists of `FrameAnnotation` and `SequenceAnnotation` objects respectively.
+
+    Each `set_lists_<subset_name_m>.json` file contains the following dictionary:
+    ```
+    {
+        "train": [
+            (sequence_name: str, frame_number: int, image_path: str),
+            ...
+        ],
+        "val": [
+            (sequence_name: str, frame_number: int, image_path: str),
+            ...
+        ],
+        "test": [
+            (sequence_name: str, frame_number: int, image_path: str),
+            ...
+        ],
+    }
+    ```
+    defining the list of frames (identified by their `sequence_name` and
+    `frame_number`) in the "train", "val", and "test" subsets of the dataset.
+    Note that `frame_number` can be obtained only from `frame_annotations.jgz`
+    and does not necessarily correspond to the numeric suffix of the
+    corresponding image file name (e.g. a file
+    `<category>/<sequence_name>/images/frame00005.jpg` can have its frame
+    number set to `20`, not 5).
+
+    Each `eval_batches_<subset_name_m>.json` file contains a list of evaluation
+    examples in the following form:
+    ```
+    [
+        [  # batch 1
+            (sequence_name: str, frame_number: int, image_path: str),
+            ...
+        ],
+        [  # batch 2
+            (sequence_name: str, frame_number: int, image_path: str),
+            ...
+        ],
+    ]
+    ```
+    Note that the evaluation examples always come from the `"test"` subset of
+    the dataset (test frames can repeat across batches).
+
+    Args:
+        category: The object category of the dataset.
+        subset_name: The name of the dataset subset. For CO3Dv2, these include
+            e.g. "manyview_dev_0", "fewview_test", ...
+        dataset_root: The root folder of the dataset.
+        test_on_train: Construct validation and test datasets from
+            the training subset.
+        only_test_set: Load only the test set. Incompatible with `test_on_train`.
+        load_eval_batches: Load the file containing eval batches pointing to the
+            test dataset.
+        dataset_X_args (e.g. dataset_JsonIndexDataset_args): Specifies additional
+            arguments to the JsonIndexDataset constructor call.
+        path_manager_factory: (Optional) An object that generates an instance of
+            PathManager that can translate provided file paths.
+        path_manager_factory_class_type: The class type of `path_manager_factory`.
+    """
+
+    category: str
+    subset_name: str
+    dataset_root: str = _CO3DV2_DATASET_ROOT
+
+    test_on_train: bool = False
+    only_test_set: bool = False
+    load_eval_batches: bool = True
+
+    dataset_class_type: str = "JsonIndexDataset"
+    dataset: JsonIndexDataset
+
+    path_manager_factory: PathManagerFactory
+    path_manager_factory_class_type: str = "PathManagerFactory"
+
+    def __post_init__(self):
+        super().__init__()
+        run_auto_creation(self)
+
+        if self.only_test_set and self.test_on_train:
+            raise ValueError("Cannot set both only_test_set and test_on_train.")
+
+        frame_file = os.path.join(
+            self.dataset_root, self.category, "frame_annotations.jgz"
+        )
+        sequence_file = os.path.join(
+            self.dataset_root, self.category, "sequence_annotations.jgz"
+        )
+
+        path_manager = self.path_manager_factory.get()
+
+        # set up the common dataset arguments
+        common_dataset_kwargs = getattr(self, f"dataset_{self.dataset_class_type}_args")
+        common_dataset_kwargs = {
+            **common_dataset_kwargs,
+            "dataset_root": self.dataset_root,
+            "frame_annotations_file": frame_file,
+            "sequence_annotations_file": sequence_file,
+            "subsets": None,
+            "subset_lists_file": "",
+            "path_manager": path_manager,
+        }
+
+        # get the used dataset type
+        dataset_type: Type[JsonIndexDataset] = registry.get(
+            JsonIndexDataset, self.dataset_class_type
+        )
+        expand_args_fields(dataset_type)
+
+        dataset = dataset_type(**common_dataset_kwargs)
+
+        available_subset_names = self._get_available_subset_names()
+        logger.debug(f"Available subset names: {str(available_subset_names)}.")
+        if self.subset_name not in available_subset_names:
+            raise ValueError(
+                f"Unknown subset name {self.subset_name}."
+                + f" Choose one of available subsets: {str(available_subset_names)}."
+            )
+
+        # load the list of train/val/test frames
+        subset_mapping = self._load_annotation_json(
+            os.path.join(
+                self.category, "set_lists", f"set_lists_{self.subset_name}.json"
+            )
+        )
+
+        # load the evaluation batches
+        if self.load_eval_batches:
+            eval_batch_index = self._load_annotation_json(
+                os.path.join(
+                    self.category,
+                    "eval_batches",
+                    f"eval_batches_{self.subset_name}.json",
+                )
+            )
+
+        train_dataset = None
+        if not self.only_test_set:
+            # load the training set
+            logger.debug("Extracting train dataset.")
+            train_dataset = dataset.subset_from_frame_index(subset_mapping["train"])
+            logger.info(f"Train dataset: {str(train_dataset)}")
+
+        if self.test_on_train:
+            assert train_dataset is not None
+            val_dataset = test_dataset = train_dataset
+        else:
+            # load the val and test sets
+            logger.debug("Extracting val dataset.")
+            val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
+            logger.info(f"Val dataset: {str(val_dataset)}")
+            logger.debug("Extracting test dataset.")
+            test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
+            logger.info(f"Test dataset: {str(test_dataset)}")
+            if self.load_eval_batches:
+                # load the eval batches
+                logger.debug("Extracting eval batches.")
+                try:
+                    test_dataset.eval_batches = (
+                        test_dataset.seq_frame_index_to_dataset_index(
+                            eval_batch_index,
+                        )
+                    )
+                except IndexError:
+                    warnings.warn(
+                        "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n"
+                        + "Some eval batches are missing from the test dataset.\n"
+                        + "The evaluation results will be incomparable to the\n"
+                        + "evaluation results calculated on the original dataset.\n"
+                        + "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
+                    )
+                    test_dataset.eval_batches = (
+                        test_dataset.seq_frame_index_to_dataset_index(
+                            eval_batch_index,
+                            allow_missing_indices=True,
+                            remove_missing_indices=True,
+                        )
+                    )
+                logger.info(f"# eval batches: {len(test_dataset.eval_batches)}")
+
+        self.dataset_map = DatasetMap(
+            train=train_dataset, val=val_dataset, test=test_dataset
+        )
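Since `__post_init__` resolves the concrete dataset class through the registry (`registry.get(JsonIndexDataset, self.dataset_class_type)`) and pulls its arguments from the matching `dataset_{class_type}_args` field, a custom dataset subclass can in principle be swapped in without touching the provider. A hedged sketch follows; the subclass name and its extra field are hypothetical, and it assumes the subclass is registered before `expand_args_fields` processes the provider so that the corresponding `_args` config field gets generated.

```python
# Hedged sketch: swapping in a custom dataset class via the registry.
# MyJsonIndexDataset and my_extra_option are hypothetical. Registration must
# happen before expand_args_fields processes the provider, so that a
# dataset_MyJsonIndexDataset_args config field can be generated for it.
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import registry

@registry.register
class MyJsonIndexDataset(JsonIndexDataset):
    my_extra_option: bool = False

# Setting dataset_class_type="MyJsonIndexDataset" on the provider would then
# route construction through this class.
```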
+ ) + + # load the list of train/val/test frames + subset_mapping = self._load_annotation_json( + os.path.join( + self.category, "set_lists", f"set_lists_{self.subset_name}.json" + ) + ) + + # load the evaluation batches + if self.load_eval_batches: + eval_batch_index = self._load_annotation_json( + os.path.join( + self.category, + "eval_batches", + f"eval_batches_{self.subset_name}.json", + ) + ) + + train_dataset = None + if not self.only_test_set: + # load the training set + logger.debug("Extracting train dataset.") + train_dataset = dataset.subset_from_frame_index(subset_mapping["train"]) + logger.info(f"Train dataset: {str(train_dataset)}") + + if self.test_on_train: + assert train_dataset is not None + val_dataset = test_dataset = train_dataset + else: + # load the val and test sets + logger.debug("Extracting val dataset.") + val_dataset = dataset.subset_from_frame_index(subset_mapping["val"]) + logger.info(f"Val dataset: {str(val_dataset)}") + logger.debug("Extracting test dataset.") + test_dataset = dataset.subset_from_frame_index(subset_mapping["test"]) + logger.info(f"Test dataset: {str(test_dataset)}") + if self.load_eval_batches: + # load the eval batches + logger.debug("Extracting eval batches.") + try: + test_dataset.eval_batches = ( + test_dataset.seq_frame_index_to_dataset_index( + eval_batch_index, + ) + ) + except IndexError: + warnings.warn( + "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n" + + "Some eval batches are missing from the test dataset.\n" + + "The evaluation results will be incomparable to the\n" + + "evaluation results calculated on the original dataset.\n" + + "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@" + ) + test_dataset.eval_batches = ( + test_dataset.seq_frame_index_to_dataset_index( + eval_batch_index, + allow_missing_indices=True, + remove_missing_indices=True, + ) + ) + logger.info(f"# eval batches: {len(test_dataset.eval_batches)}") + + self.dataset_map = DatasetMap( + train=train_dataset, val=val_dataset, test=test_dataset + ) + + def create_dataset(self): + # The dataset object is created inside `self.get_dataset_map` + pass + + def get_dataset_map(self) -> DatasetMap: + return self.dataset_map # pyre-ignore [16] + + def get_category_to_subset_name_list(self) -> Dict[str, List[str]]: + """ + Returns a global dataset index containing the available subset names per category + as a dictionary. + + Returns: + category_to_subset_name_list: A dictionary containing subset names available + per category of the following form: + ``` + { + category_0: [category_0_subset_name_0, category_0_subset_name_1, ...], + category_1: [category_1_subset_name_0, category_1_subset_name_1, ...], + ... + } + ``` + + """ + category_to_subset_name_list_json = "category_to_subset_name_list.json" + category_to_subset_name_list = self._load_annotation_json( + category_to_subset_name_list_json + ) + return category_to_subset_name_list + + def _load_annotation_json(self, json_filename: str): + full_path = os.path.join( + self.dataset_root, + json_filename, + ) + logger.info(f"Loading frame index json from {full_path}.") + path_manager = self.path_manager_factory.get() + if path_manager is not None: + full_path = path_manager.get_local_path(full_path) + if not os.path.isfile(full_path): + # The batch indices file does not exist. + # Most probably the user has not specified the root folder. + raise ValueError( + f"Looking for dataset json file in {full_path}. " + + "Please specify a correct dataset_root folder." 
+ ) + with open(full_path, "r") as f: + data = json.load(f) + return data + + def _get_available_subset_names(self): + path_manager = self.path_manager_factory.get() + if path_manager is not None: + dataset_root = path_manager.get_local_path(self.dataset_root) + else: + dataset_root = self.dataset_root + return get_available_subset_names(dataset_root, self.category) + + +def get_available_subset_names(dataset_root: str, category: str) -> List[str]: + """ + Get the available subset names for a given category folder inside a root dataset + folder `dataset_root`. + """ + category_dir = os.path.join(dataset_root, category) + if not os.path.isdir(category_dir): + raise ValueError( + f"Looking for dataset files in {category_dir}. " + + "Please specify a correct dataset_root folder." + ) + set_list_jsons = os.listdir(os.path.join(category_dir, "set_lists")) + return [ + json_file.replace("set_lists_", "").replace(".json", "") + for json_file in set_list_jsons + ] diff --git a/tests/implicitron/data/data_source.yaml b/tests/implicitron/data/data_source.yaml index 2c84545e..3760f944 100644 --- a/tests/implicitron/data/data_source.yaml +++ b/tests/implicitron/data/data_source.yaml @@ -42,6 +42,47 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args: sort_frames: false path_manager_factory_PathManagerFactory_args: silence_logs: true +dataset_map_provider_JsonIndexDatasetMapProviderV2_args: + category: ??? + subset_name: ??? + dataset_root: '' + test_on_train: false + only_test_set: false + load_eval_batches: true + dataset_class_type: JsonIndexDataset + path_manager_factory_class_type: PathManagerFactory + dataset_JsonIndexDataset_args: + path_manager: null + frame_annotations_file: '' + sequence_annotations_file: '' + subset_lists_file: '' + subsets: null + limit_to: 0 + limit_sequences_to: 0 + pick_sequence: [] + exclude_sequence: [] + limit_category_to: [] + dataset_root: '' + load_images: true + load_depths: true + load_depth_masks: true + load_masks: true + load_point_clouds: false + max_points: 0 + mask_images: false + mask_depths: false + image_height: 800 + image_width: 800 + box_crop: true + box_crop_mask_thr: 0.4 + box_crop_context: 0.3 + remove_empty_masks: true + n_frames_per_sequence: -1 + seed: 0 + sort_frames: false + eval_batches: null + path_manager_factory_PathManagerFactory_args: + silence_logs: true dataset_map_provider_LlffDatasetMapProvider_args: base_dir: ??? object_name: ??? diff --git a/tests/implicitron/test_json_index_dataset_provider_v2.py b/tests/implicitron/test_json_index_dataset_provider_v2.py new file mode 100644 index 00000000..bd698692 --- /dev/null +++ b/tests/implicitron/test_json_index_dataset_provider_v2.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/tests/implicitron/test_json_index_dataset_provider_v2.py b/tests/implicitron/test_json_index_dataset_provider_v2.py
new file mode 100644
index 00000000..bd698692
--- /dev/null
+++ b/tests/implicitron/test_json_index_dataset_provider_v2.py
@@ -0,0 +1,155 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import json
+import os
+import random
+import tempfile
+import unittest
+from typing import List
+
+import numpy as np
+
+import torch
+import torchvision
+from PIL import Image
+from pytorch3d.implicitron.dataset.dataset_base import FrameData
+from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
+    JsonIndexDatasetMapProviderV2,
+)
+from pytorch3d.implicitron.dataset.types import (
+    dump_dataclass_jgzip,
+    FrameAnnotation,
+    ImageAnnotation,
+    MaskAnnotation,
+    SequenceAnnotation,
+)
+from pytorch3d.implicitron.tools.config import expand_args_fields
+
+
+class TestJsonIndexDatasetProviderV2(unittest.TestCase):
+    def test_random_dataset(self):
+        # generate a random dataset on disk and check that it loads correctly
+        expand_args_fields(JsonIndexDatasetMapProviderV2)
+        categories = ["A", "B"]
+        subset_name = "test"
+        with tempfile.TemporaryDirectory() as tmpd:
+            _make_random_json_dataset_map_provider_v2_data(tmpd, categories)
+            for category in categories:
+                dataset_provider = JsonIndexDatasetMapProviderV2(
+                    category=category,
+                    subset_name=subset_name,
+                    dataset_root=tmpd,
+                )
+                dataset_map = dataset_provider.get_dataset_map()
+                for set_ in ["train", "val", "test"]:
+                    dataloader = torch.utils.data.DataLoader(
+                        getattr(dataset_map, set_),
+                        batch_size=3,
+                        shuffle=True,
+                        collate_fn=FrameData.collate,
+                    )
+                    for _ in dataloader:
+                        pass
+                category_to_subset_list = (
+                    dataset_provider.get_category_to_subset_name_list()
+                )
+                category_to_subset_list_ = {c: [subset_name] for c in categories}
+                self.assertEqual(category_to_subset_list, category_to_subset_list_)
+
+
+def _make_random_json_dataset_map_provider_v2_data(
+    root: str,
+    categories: List[str],
+    n_frames: int = 8,
+    n_sequences: int = 5,
+    H: int = 50,
+    W: int = 30,
+    subset_name: str = "test",
+):
+    os.makedirs(root, exist_ok=True)
+    category_to_subset_list = {}
+    for category in categories:
+        frame_annotations = []
+        sequence_annotations = []
+        frame_index = []
+        for seq_i in range(n_sequences):
+            seq_name = str(seq_i)
+            for i in range(n_frames):
+                # generate and store an image
+                imdir = os.path.join(root, category, seq_name, "images")
+                os.makedirs(imdir, exist_ok=True)
+                img_path = os.path.join(imdir, f"frame{i:05d}.jpg")
+                img = torch.rand(3, H, W)
+                torchvision.utils.save_image(img, img_path)
+
+                # generate and store a mask
+                maskdir = os.path.join(root, category, seq_name, "masks")
+                os.makedirs(maskdir, exist_ok=True)
+                mask_path = os.path.join(maskdir, f"frame{i:05d}.png")
+                mask = np.zeros((H, W))
+                mask[H // 2 :, W // 2 :] = 1
+                Image.fromarray((mask * 255.0).astype(np.uint8), mode="L").save(
+                    mask_path
+                )
+
+                fa = FrameAnnotation(
+                    sequence_name=seq_name,
+                    frame_number=i,
+                    frame_timestamp=float(i),
+                    image=ImageAnnotation(
+                        path=img_path.replace(os.path.normpath(root) + "/", ""),
+                        size=list(img.shape[-2:]),
+                    ),
+                    mask=MaskAnnotation(
+                        path=mask_path.replace(os.path.normpath(root) + "/", ""),
+                        mass=mask.sum().item(),
+                    ),
+                )
+                frame_annotations.append(fa)
+                frame_index.append((seq_name, i, fa.image.path))
+
+            sequence_annotations.append(
+                SequenceAnnotation(
+                    sequence_name=seq_name,
+                    category=category,
+                )
+            )
+
+        dump_dataclass_jgzip(
+            os.path.join(root, category, "frame_annotations.jgz"),
+            frame_annotations,
+        )
+        dump_dataclass_jgzip(
+            os.path.join(root, category, "sequence_annotations.jgz"),
+            sequence_annotations,
+        )
+
+        test_frame_index = frame_index[2::3]
+
+        set_list = {
+            "train": frame_index[0::3],
+            "val": frame_index[1::3],
+            "test": test_frame_index,
+        }
+        set_lists_dir = os.path.join(root, category, "set_lists")
+        os.makedirs(set_lists_dir, exist_ok=True)
+        set_list_file = os.path.join(set_lists_dir, f"set_lists_{subset_name}.json")
+        with open(set_list_file, "w") as f:
+            json.dump(set_list, f)
+
+        eval_batches = [random.sample(test_frame_index, 5) for _ in range(10)]
+        eval_b_dir = os.path.join(root, category, "eval_batches")
+        os.makedirs(eval_b_dir, exist_ok=True)
+        eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")
+        with open(eval_b_file, "w") as f:
+            json.dump(eval_batches, f)
+
+        category_to_subset_list[category] = [subset_name]
+
+    with open(os.path.join(root, "category_to_subset_name_list.json"), "w") as f:
+        json.dump(category_to_subset_list, f)
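For reference, an end-to-end sketch distilled from the test above: build the provider, fetch the dataset map, and iterate one split with `FrameData.collate`. The `dataset_root`, category, and subset name below are hypothetical stand-ins for a real CO3Dv2 checkout; alternatively, export `CO3DV2_DATASET_ROOT` and omit `dataset_root`.

```python
# End-to-end sketch mirroring the unit test; paths and names are hypothetical.
import torch
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(JsonIndexDatasetMapProviderV2)
provider = JsonIndexDatasetMapProviderV2(
    category="apple",
    subset_name="manyview_dev_0",
    dataset_root="/data/co3dv2",  # or set the CO3DV2_DATASET_ROOT env var
)
dataset_map = provider.get_dataset_map()

dataloader = torch.utils.data.DataLoader(
    dataset_map.train,
    batch_size=3,
    shuffle=True,
    collate_fn=FrameData.collate,  # FrameData batches need a custom collate
)
for frame_data in dataloader:
    pass  # frame_data is a collated FrameData batch
```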