mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-11-04 18:02:14 +08:00

JsonIndexDatasetProviderV2

Summary: A new version of the json index dataset provider, supporting CO3Dv2.

Reviewed By: shapovalov

Differential Revision: D37690918

fbshipit-source-id: bf2d5fc9d0f1220259e08661dafc69cdbe6b7f94

This commit is contained in:
parent 4300030d7a
commit e8390d3500
pytorch3d/implicitron/dataset/json_index_dataset.py

@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import functools
import gzip
import hashlib
@@ -255,6 +256,47 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):

        return dataset_idx

    def subset_from_frame_index(
        self,
        frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
        allow_missing_indices: bool = True,
    ) -> "JsonIndexDataset":
        # Get the indices into the frame annots.
        dataset_indices = self.seq_frame_index_to_dataset_index(
            [frame_index],
            allow_missing_indices=self.is_filtered() and allow_missing_indices,
        )[0]
        valid_dataset_indices = [i for i in dataset_indices if i is not None]

        # Deep-copy the whole dataset except frame_annots, which are large, so we
        # deep-copy only the requested subset of frame_annots.
        memo = {id(self.frame_annots): None}  # pyre-ignore[16]
        dataset_new = copy.deepcopy(self, memo)
        dataset_new.frame_annots = copy.deepcopy(
            [self.frame_annots[i] for i in valid_dataset_indices]
        )

        # This will kill all unneeded sequence annotations.
        dataset_new._invalidate_indexes(filter_seq_annots=True)

        # Finally, annotate the frame annotations with the name of the subset
        # stored in meta.
        for frame_annot in dataset_new.frame_annots:
            frame_annotation = frame_annot["frame_annotation"]
            if frame_annotation.meta is not None:
                frame_annot["subset"] = frame_annotation.meta.get("frame_type", None)

        # A sanity check - this will crash if some entries from frame_index are
        # missing in dataset_new.
        valid_frame_index = [
            fi for fi, di in zip(frame_index, dataset_indices) if di is not None
        ]
        dataset_new.seq_frame_index_to_dataset_index(
            [valid_frame_index], allow_missing_indices=False
        )

        return dataset_new

    def __str__(self) -> str:
        # pyre-ignore[16]
        return f"JsonIndexDataset #frames={len(self.frame_annots)}"
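The new `subset_from_frame_index` method above clips a dataset to an explicit list of frames. A minimal usage sketch (not part of the diff; the `dataset` instance, sequence names, and paths are hypothetical):
```
# `dataset` is assumed to be an already-constructed JsonIndexDataset;
# all names below are illustrative.
frame_index = [
    ("sequence_A", 2),  # (sequence_name, frame_number)
    # a third element, the image path, may optionally be included:
    ("sequence_B", 5, "category/sequence_B/images/frame00005.jpg"),
]
subset = dataset.subset_from_frame_index(frame_index)
print(subset)  # e.g. "JsonIndexDataset #frames=2"
```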
pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py

@@ -92,6 +92,9 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
                            to use for the dataset.
        dataset_X_args (e.g. dataset_JsonIndexDataset_args): arguments passed
            to all the dataset constructors.
        path_manager_factory: (Optional) An object that generates an instance of
            PathManager that can translate provided file paths.
        path_manager_factory_class_type: The class type of `path_manager_factory`.
    """

    category: str
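The `dataset_X_args` mechanism documented above forwards nested arguments to the chosen dataset constructor. A hedged sketch of overriding one of them through the implicitron config tools (`get_default_args` is assumed to be available from `pytorch3d.implicitron.tools.config`; the overridden value is illustrative):
```
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider import (
    JsonIndexDatasetMapProvider,
)
from pytorch3d.implicitron.tools.config import get_default_args

# arguments nested under dataset_JsonIndexDataset_args are passed through
# to the JsonIndexDataset constructor
args = get_default_args(JsonIndexDatasetMapProvider)
args.dataset_JsonIndexDataset_args.image_height = 400  # illustrative override
```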
pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py (new file)

@@ -0,0 +1,343 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import json
import logging
import os
import warnings
from typing import Dict, List, Type

from pytorch3d.implicitron.dataset.dataset_map_provider import (
    DatasetMap,
    DatasetMapProviderBase,
    PathManagerFactory,
)
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import (
    expand_args_fields,
    registry,
    run_auto_creation,
)


_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")


logger = logging.getLogger(__name__)


@registry.register
class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
    """
    Generates the training, validation, and testing dataset objects for
    a dataset laid out on disk like CO3Dv2, with annotations in gzipped json files.

    The dataset is organized in the filesystem as follows:
        ```
        self.dataset_root
            ├── <category_0>
            │   ├── <sequence_name_0>
            │   │   ├── depth_masks
            │   │   ├── depths
            │   │   ├── images
            │   │   ├── masks
            │   │   └── pointcloud.ply
            │   ├── <sequence_name_1>
            │   │   ├── depth_masks
            │   │   ├── depths
            │   │   ├── images
            │   │   ├── masks
            │   │   └── pointcloud.ply
            │   ├── ...
            │   ├── <sequence_name_N>
            │   ├── set_lists
            │   │   ├── set_lists_<subset_name_0>.json
            │   │   ├── set_lists_<subset_name_1>.json
            │   │   ├── ...
            │   │   ├── set_lists_<subset_name_M>.json
            │   ├── eval_batches
            │   │   ├── eval_batches_<subset_name_0>.json
            │   │   ├── eval_batches_<subset_name_1>.json
            │   │   ├── ...
            │   │   ├── eval_batches_<subset_name_M>.json
            │   ├── frame_annotations.jgz
            │   ├── sequence_annotations.jgz
            ├── <category_1>
            ├── ...
            ├── <category_K>
        ```

    The dataset contains sequences named `<sequence_name_i>` from `K` categories with
    names `<category_j>`. Each category comprises sequence folders
    `<category_k>/<sequence_name_i>` containing the lists of sequence images, depth
    maps, foreground masks, and valid-depth masks (`images`, `depths`, `masks`, and
    `depth_masks` respectively). Furthermore,
    `<category_k>/<sequence_name_i>/set_lists/` stores `M` json files
    `set_lists_<subset_name_l>.json`, each describing a certain sequence subset.

    Users specify the loaded dataset subset by setting `self.subset_name` to one of
    the available subset names `<subset_name_l>`.

    `frame_annotations.jgz` and `sequence_annotations.jgz` are gzipped json files
    containing the lists of all frames and sequences of the given category, stored as
    lists of `FrameAnnotation` and `SequenceAnnotation` objects respectively.

    Each `set_lists_<subset_name_l>.json` file contains the following dictionary:
        ```
        {
            "train": [
                (sequence_name: str, frame_number: int, image_path: str),
                ...
            ],
            "val": [
                (sequence_name: str, frame_number: int, image_path: str),
                ...
            ],
            "test": [
                (sequence_name: str, frame_number: int, image_path: str),
                ...
            ],
        }
        ```
    defining the lists of frames (identified with their `sequence_name` and
    `frame_number`) in the "train", "val", and "test" subsets of the dataset.
    Note that `frame_number` can be obtained only from `frame_annotations.jgz` and
    does not necessarily correspond to the numeric suffix of the corresponding image
    file name (e.g. a file `<category_0>/<sequence_name_0>/images/frame00005.jpg` can
    have its frame number set to `20`, not 5).

    Each `eval_batches_<subset_name_l>.json` file contains a list of evaluation
    examples in the following form:
        ```
        [
            [  # batch 1
                (sequence_name: str, frame_number: int, image_path: str),
                ...
            ],
            [  # batch 2
                (sequence_name: str, frame_number: int, image_path: str),
                ...
            ],
        ]
        ```
    Note that the evaluation examples always come from the `"test"` subset of the
    dataset (test frames can repeat across batches).

    Args:
        category: The object category of the dataset.
        subset_name: The name of the dataset subset. For CO3Dv2, these include
            e.g. "manyview_dev_0", "fewview_test", ...
        dataset_root: The root folder of the dataset.
        test_on_train: Construct validation and test datasets from
            the training subset.
        only_test_set: Load only the test set. Incompatible with `test_on_train`.
        load_eval_batches: Load the file containing eval batches pointing to the
            test dataset.
        dataset_args: Specifies additional arguments to the
            JsonIndexDataset constructor call.
        path_manager_factory: (Optional) An object that generates an instance of
            PathManager that can translate provided file paths.
        path_manager_factory_class_type: The class type of `path_manager_factory`.
    """

    category: str
    subset_name: str
    dataset_root: str = _CO3DV2_DATASET_ROOT

    test_on_train: bool = False
    only_test_set: bool = False
    load_eval_batches: bool = True

    dataset_class_type: str = "JsonIndexDataset"
    dataset: JsonIndexDataset

    path_manager_factory: PathManagerFactory
    path_manager_factory_class_type: str = "PathManagerFactory"

    def __post_init__(self):
        super().__init__()
        run_auto_creation(self)

        if self.only_test_set and self.test_on_train:
            raise ValueError("Cannot have only_test_set and test_on_train")

        frame_file = os.path.join(
            self.dataset_root, self.category, "frame_annotations.jgz"
        )
        sequence_file = os.path.join(
            self.dataset_root, self.category, "sequence_annotations.jgz"
        )

        path_manager = self.path_manager_factory.get()

        # set up the common dataset arguments
        common_dataset_kwargs = getattr(self, f"dataset_{self.dataset_class_type}_args")
        common_dataset_kwargs = {
            **common_dataset_kwargs,
            "dataset_root": self.dataset_root,
            "frame_annotations_file": frame_file,
            "sequence_annotations_file": sequence_file,
            "subsets": None,
            "subset_lists_file": "",
            "path_manager": path_manager,
        }

        # get the used dataset type
        dataset_type: Type[JsonIndexDataset] = registry.get(
            JsonIndexDataset, self.dataset_class_type
        )
        expand_args_fields(dataset_type)

        dataset = dataset_type(**common_dataset_kwargs)

        available_subset_names = self._get_available_subset_names()
        logger.debug(f"Available subset names: {str(available_subset_names)}.")
        if self.subset_name not in available_subset_names:
            raise ValueError(
                f"Unknown subset name {self.subset_name}."
                + f" Choose one of available subsets: {str(available_subset_names)}."
            )

        # load the list of train/val/test frames
        subset_mapping = self._load_annotation_json(
            os.path.join(
                self.category, "set_lists", f"set_lists_{self.subset_name}.json"
            )
        )

        # load the evaluation batches
        if self.load_eval_batches:
            eval_batch_index = self._load_annotation_json(
                os.path.join(
                    self.category,
                    "eval_batches",
                    f"eval_batches_{self.subset_name}.json",
                )
            )

        train_dataset = None
        if not self.only_test_set:
            # load the training set
            logger.debug("Extracting train dataset.")
            train_dataset = dataset.subset_from_frame_index(subset_mapping["train"])
            logger.info(f"Train dataset: {str(train_dataset)}")

        if self.test_on_train:
            assert train_dataset is not None
            val_dataset = test_dataset = train_dataset
        else:
            # load the val and test sets
            logger.debug("Extracting val dataset.")
            val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
            logger.info(f"Val dataset: {str(val_dataset)}")
            logger.debug("Extracting test dataset.")
            test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
            logger.info(f"Test dataset: {str(test_dataset)}")
            if self.load_eval_batches:
                # load the eval batches
                logger.debug("Extracting eval batches.")
                try:
                    test_dataset.eval_batches = (
                        test_dataset.seq_frame_index_to_dataset_index(
                            eval_batch_index,
                        )
                    )
                except IndexError:
                    warnings.warn(
                        "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n"
                        + "Some eval batches are missing from the test dataset.\n"
                        + "The evaluation results will be incomparable to the\n"
                        + "evaluation results calculated on the original dataset.\n"
                        + "@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@"
                    )
                    test_dataset.eval_batches = (
                        test_dataset.seq_frame_index_to_dataset_index(
                            eval_batch_index,
                            allow_missing_indices=True,
                            remove_missing_indices=True,
                        )
                    )
                logger.info(f"# eval batches: {len(test_dataset.eval_batches)}")

        self.dataset_map = DatasetMap(
            train=train_dataset, val=val_dataset, test=test_dataset
        )

    def create_dataset(self):
        # The dataset object is created inside `self.get_dataset_map`
        pass

    def get_dataset_map(self) -> DatasetMap:
        return self.dataset_map  # pyre-ignore [16]

    def get_category_to_subset_name_list(self) -> Dict[str, List[str]]:
        """
        Returns a global dataset index containing the available subset names per
        category as a dictionary.

        Returns:
            category_to_subset_name_list: A dictionary containing subset names available
                per category of the following form:
                    ```
                    {
                        category_0: [category_0_subset_name_0, category_0_subset_name_1, ...],
                        category_1: [category_1_subset_name_0, category_1_subset_name_1, ...],
                        ...
                    }
                    ```
        """
        category_to_subset_name_list_json = "category_to_subset_name_list.json"
        category_to_subset_name_list = self._load_annotation_json(
            category_to_subset_name_list_json
        )
        return category_to_subset_name_list

    def _load_annotation_json(self, json_filename: str):
        full_path = os.path.join(
            self.dataset_root,
            json_filename,
        )
        logger.info(f"Loading frame index json from {full_path}.")
        path_manager = self.path_manager_factory.get()
        if path_manager is not None:
            full_path = path_manager.get_local_path(full_path)
        if not os.path.isfile(full_path):
            # The batch indices file does not exist.
            # Most probably the user has not specified the root folder.
            raise ValueError(
                f"Looking for dataset json file in {full_path}. "
                + "Please specify a correct dataset_root folder."
            )
        with open(full_path, "r") as f:
            data = json.load(f)
        return data

    def _get_available_subset_names(self):
        path_manager = self.path_manager_factory.get()
        if path_manager is not None:
            dataset_root = path_manager.get_local_path(self.dataset_root)
        else:
            dataset_root = self.dataset_root
        return get_available_subset_names(dataset_root, self.category)


def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
    """
    Get the available subset names for a given category folder inside a root dataset
    folder `dataset_root`.
    """
    category_dir = os.path.join(dataset_root, category)
    if not os.path.isdir(category_dir):
        raise ValueError(
            f"Looking for dataset files in {category_dir}. "
            + "Please specify a correct dataset_root folder."
        )
    set_list_jsons = os.listdir(os.path.join(category_dir, "set_lists"))
    return [
        json_file.replace("set_lists_", "").replace(".json", "")
        for json_file in set_list_jsons
    ]
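An end-to-end sketch of driving the new provider directly from Python, mirroring the unit test added below (the category, subset name, and root path are illustrative; CO3Dv2 subset names include e.g. "manyview_dev_0" per the docstring above):
```
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# expand_args_fields generates the dataclass constructor arguments
expand_args_fields(JsonIndexDatasetMapProviderV2)
provider = JsonIndexDatasetMapProviderV2(
    category="apple",                 # illustrative category
    subset_name="manyview_dev_0",     # illustrative subset
    dataset_root="/datasets/co3dv2",  # or set CO3DV2_DATASET_ROOT before import
)
dataset_map = provider.get_dataset_map()
print(dataset_map.train, dataset_map.val, dataset_map.test)
# enumerate all subsets available per category:
print(provider.get_category_to_subset_name_list())
```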
@@ -42,6 +42,47 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args:
    sort_frames: false
  path_manager_factory_PathManagerFactory_args:
    silence_logs: true
dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
  category: ???
  subset_name: ???
  dataset_root: ''
  test_on_train: false
  only_test_set: false
  load_eval_batches: true
  dataset_class_type: JsonIndexDataset
  path_manager_factory_class_type: PathManagerFactory
  dataset_JsonIndexDataset_args:
    path_manager: null
    frame_annotations_file: ''
    sequence_annotations_file: ''
    subset_lists_file: ''
    subsets: null
    limit_to: 0
    limit_sequences_to: 0
    pick_sequence: []
    exclude_sequence: []
    limit_category_to: []
    dataset_root: ''
    load_images: true
    load_depths: true
    load_depth_masks: true
    load_masks: true
    load_point_clouds: false
    max_points: 0
    mask_images: false
    mask_depths: false
    image_height: 800
    image_width: 800
    box_crop: true
    box_crop_mask_thr: 0.4
    box_crop_context: 0.3
    remove_empty_masks: true
    n_frames_per_sequence: -1
    seed: 0
    sort_frames: false
    eval_batches: null
  path_manager_factory_PathManagerFactory_args:
    silence_logs: true
dataset_map_provider_LlffDatasetMapProvider_args:
  base_dir: ???
  object_name: ???
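For orientation, a hedged sketch of how the yaml block above is resolved: the `dataset_map_provider_JsonIndexDatasetMapProviderV2_args` entry configures the implementation that the implicitron registry looks up by class type (the `registry.get` call follows the pattern used inside the provider itself; the registration-by-import detail is an assumption):
```
# importing the module runs the @registry.register decorator (assumption)
import pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2  # noqa
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMapProviderBase
from pytorch3d.implicitron.tools.config import registry

provider_cls = registry.get(DatasetMapProviderBase, "JsonIndexDatasetMapProviderV2")
assert provider_cls.__name__ == "JsonIndexDatasetMapProviderV2"
```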

tests/implicitron/test_json_index_dataset_provider_v2.py (new file, 155 lines)

@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import json
import os
import random
import tempfile
import unittest
from typing import List

import numpy as np
import torch
import torchvision
from PIL import Image
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)
from pytorch3d.implicitron.dataset.types import (
    dump_dataclass_jgzip,
    FrameAnnotation,
    ImageAnnotation,
    MaskAnnotation,
    SequenceAnnotation,
)
from pytorch3d.implicitron.tools.config import expand_args_fields


class TestJsonIndexDatasetProviderV2(unittest.TestCase):
    def test_random_dataset(self):
        # store random frame annotations
        expand_args_fields(JsonIndexDatasetMapProviderV2)
        categories = ["A", "B"]
        subset_name = "test"
        with tempfile.TemporaryDirectory() as tmpd:
            _make_random_json_dataset_map_provider_v2_data(tmpd, categories)
            for category in categories:
                dataset_provider = JsonIndexDatasetMapProviderV2(
                    category=category,
                    subset_name="test",
                    dataset_root=tmpd,
                )
                dataset_map = dataset_provider.get_dataset_map()
                for set_ in ["train", "val", "test"]:
                    dataloader = torch.utils.data.DataLoader(
                        getattr(dataset_map, set_),
                        batch_size=3,
                        shuffle=True,
                        collate_fn=FrameData.collate,
                    )
                    for _ in dataloader:
                        pass
                category_to_subset_list = (
                    dataset_provider.get_category_to_subset_name_list()
                )
                category_to_subset_list_ = {c: [subset_name] for c in categories}
                self.assertTrue(category_to_subset_list == category_to_subset_list_)


def _make_random_json_dataset_map_provider_v2_data(
    root: str,
    categories: List[str],
    n_frames: int = 8,
    n_sequences: int = 5,
    H: int = 50,
    W: int = 30,
    subset_name: str = "test",
):
    os.makedirs(root, exist_ok=True)
    category_to_subset_list = {}
    for category in categories:
        frame_annotations = []
        sequence_annotations = []
        frame_index = []
        for seq_i in range(n_sequences):
            seq_name = str(seq_i)
            for i in range(n_frames):
                # generate and store image
                imdir = os.path.join(root, category, seq_name, "images")
                os.makedirs(imdir, exist_ok=True)
                img_path = os.path.join(imdir, f"frame{i:05d}.jpg")
                img = torch.rand(3, H, W)
                torchvision.utils.save_image(img, img_path)

                # generate and store mask
                maskdir = os.path.join(root, category, seq_name, "masks")
                os.makedirs(maskdir, exist_ok=True)
                mask_path = os.path.join(maskdir, f"frame{i:05d}.png")
                mask = np.zeros((H, W))
                mask[H // 2 :, W // 2 :] = 1
                Image.fromarray((mask * 255.0).astype(np.uint8), mode="L").save(
                    mask_path
                )

                fa = FrameAnnotation(
                    sequence_name=seq_name,
                    frame_number=i,
                    frame_timestamp=float(i),
                    image=ImageAnnotation(
                        path=img_path.replace(os.path.normpath(root) + "/", ""),
                        size=list(img.shape[-2:]),
                    ),
                    mask=MaskAnnotation(
                        path=mask_path.replace(os.path.normpath(root) + "/", ""),
                        mass=mask.sum().item(),
                    ),
                )
                frame_annotations.append(fa)
                frame_index.append((seq_name, i, fa.image.path))

            sequence_annotations.append(
                SequenceAnnotation(
                    sequence_name=seq_name,
                    category=category,
                )
            )

        dump_dataclass_jgzip(
            os.path.join(root, category, "frame_annotations.jgz"),
            frame_annotations,
        )
        dump_dataclass_jgzip(
            os.path.join(root, category, "sequence_annotations.jgz"),
            sequence_annotations,
        )

        test_frame_index = frame_index[2::3]

        set_list = {
            "train": frame_index[0::3],
            "val": frame_index[1::3],
            "test": test_frame_index,
        }
        set_lists_dir = os.path.join(root, category, "set_lists")
        os.makedirs(set_lists_dir, exist_ok=True)
        set_list_file = os.path.join(set_lists_dir, f"set_lists_{subset_name}.json")
        with open(set_list_file, "w") as f:
            json.dump(set_list, f)

        eval_batches = [random.sample(test_frame_index, 5) for _ in range(10)]
        eval_b_dir = os.path.join(root, category, "eval_batches")
        os.makedirs(eval_b_dir, exist_ok=True)
        eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")
        with open(eval_b_file, "w") as f:
            json.dump(eval_batches, f)

        category_to_subset_list[category] = [subset_name]

    with open(os.path.join(root, "category_to_subset_name_list.json"), "w") as f:
        json.dump(category_to_subset_list, f)
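To make the on-disk layout written by the helper above concrete, a small reading sketch (the root path is hypothetical; "A" is one of the test categories and "test" is the subset name used above):
```
import gzip
import json
import os

root = "/tmp/random_dataset"  # wherever the helper wrote its data
category = "A"

# set lists: {"train"/"val"/"test": [(sequence_name, frame_number, image_path), ...]}
with open(os.path.join(root, category, "set_lists", "set_lists_test.json")) as f:
    set_list = json.load(f)
sequence_name, frame_number, image_path = set_list["train"][0]

# eval batches: a list of batches of the same triplets, drawn from "test"
with open(os.path.join(root, category, "eval_batches", "eval_batches_test.json")) as f:
    eval_batches = json.load(f)

# frame/sequence annotations are gzipped json
with gzip.open(os.path.join(root, category, "frame_annotations.jgz"), "rt") as f:
    frame_annotations = json.load(f)
```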