mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-04-30 02:28:56 +08:00
Enable additional test-time source views for json dataset provider v2
Summary: Adds additional source views to the eval batches for evaluating many-view models on CO3D Challenge Reviewed By: bottler Differential Revision: D38705904 fbshipit-source-id: cf7d00dc7db926fbd1656dd97a729674e9ff5adb
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e8616cc8ba
commit
2ff2c7c836
@@ -5,11 +5,15 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Tuple, Type
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from omegaconf import DictConfig
|
||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||
@@ -152,6 +156,9 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
only_test_set: Load only the test set. Incompatible with `test_on_train`.
|
||||
load_eval_batches: Load the file containing eval batches pointing to the
|
||||
test dataset.
|
||||
n_known_frames_for_test: Add a certain number of known frames to each
|
||||
eval batch. Useful for evaluating models that require
|
||||
source views as input (e.g. NeRF-WCE / PixelNeRF).
|
||||
dataset_args: Specifies additional arguments to the
|
||||
JsonIndexDataset constructor call.
|
||||
path_manager_factory: (Optional) An object that generates an instance of
|
||||
@@ -167,6 +174,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
only_test_set: bool = False
|
||||
load_eval_batches: bool = True
|
||||
|
||||
n_known_frames_for_test: int = 0
|
||||
|
||||
dataset_class_type: str = "JsonIndexDataset"
|
||||
dataset: JsonIndexDataset
|
||||
|
||||
@@ -264,6 +273,18 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
|
||||
logger.info(f"Val dataset: {str(val_dataset)}")
|
||||
logger.debug("Extracting test dataset.")
|
||||
|
||||
if (self.n_known_frames_for_test > 0) and self.load_eval_batches:
|
||||
# extend the test subset mapping and the dataset with additional
|
||||
# known views from the train dataset
|
||||
(
|
||||
eval_batch_index,
|
||||
subset_mapping["test"],
|
||||
) = self._extend_test_data_with_known_views(
|
||||
subset_mapping,
|
||||
eval_batch_index,
|
||||
)
|
||||
|
||||
test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
|
||||
logger.info(f"Test dataset: {str(test_dataset)}")
|
||||
if self.load_eval_batches:
|
||||
@@ -369,6 +390,40 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
dataset_root = self.dataset_root
|
||||
return get_available_subset_names(dataset_root, self.category)
|
||||
|
||||
def _extend_test_data_with_known_views(
|
||||
self,
|
||||
subset_mapping: Dict[str, List[Union[Tuple[str, int], Tuple[str, int, str]]]],
|
||||
eval_batch_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
|
||||
):
|
||||
# convert the train subset mapping to a dict:
|
||||
# sequence_to_train_frames: {sequence_name: frame_index}
|
||||
sequence_to_train_frames = defaultdict(list)
|
||||
for frame_entry in subset_mapping["train"]:
|
||||
sequence_name = frame_entry[0]
|
||||
sequence_to_train_frames[sequence_name].append(frame_entry)
|
||||
sequence_to_train_frames = dict(sequence_to_train_frames)
|
||||
test_subset_mapping_set = {tuple(s) for s in subset_mapping["test"]}
|
||||
|
||||
# extend the eval batches / subset mapping with the additional examples
|
||||
eval_batch_index_out = copy.deepcopy(eval_batch_index)
|
||||
generator = np.random.default_rng(seed=0)
|
||||
for batch in eval_batch_index_out:
|
||||
sequence_name = batch[0][0]
|
||||
sequence_known_entries = sequence_to_train_frames[sequence_name]
|
||||
idx_to_add = generator.permutation(len(sequence_known_entries))[
|
||||
: self.n_known_frames_for_test
|
||||
]
|
||||
entries_to_add = [sequence_known_entries[a] for a in idx_to_add]
|
||||
assert all(e in subset_mapping["train"] for e in entries_to_add)
|
||||
|
||||
# extend the eval batch with the known views
|
||||
batch.extend(entries_to_add)
|
||||
|
||||
# also add these new entries to the test subset mapping
|
||||
test_subset_mapping_set.update(tuple(e) for e in entries_to_add)
|
||||
|
||||
return eval_batch_index_out, list(test_subset_mapping_set)
|
||||
|
||||
|
||||
def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user