Enable additional test-time source views for json dataset provider v2

Summary: Adds additional source views to the eval batches for evaluating many-view models on CO3D Challenge

Reviewed By: bottler

Differential Revision: D38705904

fbshipit-source-id: cf7d00dc7db926fbd1656dd97a729674e9ff5adb
This commit is contained in:
David Novotny 2022-08-15 08:40:09 -07:00 committed by Facebook GitHub Bot
parent e8616cc8ba
commit 2ff2c7c836
4 changed files with 102 additions and 23 deletions

View File

@ -61,6 +61,7 @@ data_source_ImplicitronDataSource_args:
test_on_train: false
only_test_set: false
load_eval_batches: true
n_known_frames_for_test: 0
dataset_class_type: JsonIndexDataset
path_manager_factory_class_type: PathManagerFactory
dataset_JsonIndexDataset_args:

View File

@ -5,11 +5,15 @@
# LICENSE file in the root directory of this source tree.

import copy
import json
import logging
import os
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Type, Union

import numpy as np
from omegaconf import DictConfig
from pytorch3d.implicitron.dataset.dataset_map_provider import (
@ -152,6 +156,9 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
only_test_set: Load only the test set. Incompatible with `test_on_train`.
load_eval_batches: Load the file containing eval batches pointing to the
    test dataset.
n_known_frames_for_test: Add a certain number of known frames to each
    eval batch. Useful for evaluating models that require
    source views as input (e.g. NeRF-WCE / PixelNeRF).
dataset_args: Specifies additional arguments to the
    JsonIndexDataset constructor call.
path_manager_factory: (Optional) An object that generates an instance of
@ -167,6 +174,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
only_test_set: bool = False
load_eval_batches: bool = True
n_known_frames_for_test: int = 0
dataset_class_type: str = "JsonIndexDataset"

dataset: JsonIndexDataset
@ -264,6 +273,18 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
logger.info(f"Val dataset: {str(val_dataset)}")

logger.debug("Extracting test dataset.")
if (self.n_known_frames_for_test > 0) and self.load_eval_batches:
    # extend the test subset mapping and the dataset with additional
    # known views from the train dataset
    (
        eval_batch_index,
        subset_mapping["test"],
    ) = self._extend_test_data_with_known_views(
        subset_mapping,
        eval_batch_index,
    )
test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
logger.info(f"Test dataset: {str(test_dataset)}")
if self.load_eval_batches:
@ -369,6 +390,40 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
dataset_root = self.dataset_root dataset_root = self.dataset_root
return get_available_subset_names(dataset_root, self.category) return get_available_subset_names(dataset_root, self.category)
def _extend_test_data_with_known_views(
self,
subset_mapping: Dict[str, List[Union[Tuple[str, int], Tuple[str, int, str]]]],
eval_batch_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
):
# convert the train subset mapping to a dict:
# sequence_to_train_frames: {sequence_name: frame_index}
sequence_to_train_frames = defaultdict(list)
for frame_entry in subset_mapping["train"]:
sequence_name = frame_entry[0]
sequence_to_train_frames[sequence_name].append(frame_entry)
sequence_to_train_frames = dict(sequence_to_train_frames)
test_subset_mapping_set = {tuple(s) for s in subset_mapping["test"]}
# extend the eval batches / subset mapping with the additional examples
eval_batch_index_out = copy.deepcopy(eval_batch_index)
generator = np.random.default_rng(seed=0)
for batch in eval_batch_index_out:
sequence_name = batch[0][0]
sequence_known_entries = sequence_to_train_frames[sequence_name]
idx_to_add = generator.permutation(len(sequence_known_entries))[
: self.n_known_frames_for_test
]
entries_to_add = [sequence_known_entries[a] for a in idx_to_add]
assert all(e in subset_mapping["train"] for e in entries_to_add)
# extend the eval batch with the known views
batch.extend(entries_to_add)
# also add these new entries to the test subset mapping
test_subset_mapping_set.update(tuple(e) for e in entries_to_add)
return eval_batch_index_out, list(test_subset_mapping_set)
def get_available_subset_names(dataset_root: str, category: str) -> List[str]: def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
""" """

View File

@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
test_on_train: false
only_test_set: false
load_eval_batches: true
n_known_frames_for_test: 0
dataset_class_type: JsonIndexDataset
path_manager_factory_class_type: PathManagerFactory
dataset_JsonIndexDataset_args:

View File

@ -37,29 +37,47 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
expand_args_fields(JsonIndexDatasetMapProviderV2)
categories = ["A", "B"]
subset_name = "test"
eval_batch_size = 5
with tempfile.TemporaryDirectory() as tmpd:
    _make_random_json_dataset_map_provider_v2_data(
        tmpd,
        categories,
        eval_batch_size=eval_batch_size,
    )
    for n_known_frames_for_test in [0, 2]:
        for category in categories:
            dataset_provider = JsonIndexDatasetMapProviderV2(
                category=category,
                subset_name="test",
                dataset_root=tmpd,
                n_known_frames_for_test=n_known_frames_for_test,
            )
            dataset_map = dataset_provider.get_dataset_map()
            for set_ in ["train", "val", "test"]:
                if set_ in ["train", "val"]:
                    dataloader = torch.utils.data.DataLoader(
                        getattr(dataset_map, set_),
                        batch_size=3,
                        shuffle=True,
                        collate_fn=FrameData.collate,
                    )
                else:
                    dataloader = torch.utils.data.DataLoader(
                        getattr(dataset_map, set_),
                        batch_sampler=dataset_map[set_].get_eval_batches(),
                        collate_fn=FrameData.collate,
                    )
                for batch in dataloader:
                    if set_ == "test":
                        self.assertTrue(
                            batch.image_rgb.shape[0]
                            == n_known_frames_for_test + eval_batch_size
                        )
            category_to_subset_list = (
                dataset_provider.get_category_to_subset_name_list()
            )
            category_to_subset_list_ = {c: [subset_name] for c in categories}
            self.assertTrue(category_to_subset_list == category_to_subset_list_)
def _make_random_json_dataset_map_provider_v2_data( def _make_random_json_dataset_map_provider_v2_data(
@ -70,6 +88,7 @@ def _make_random_json_dataset_map_provider_v2_data(
    H: int = 50,
    W: int = 30,
    subset_name: str = "test",
    eval_batch_size: int = 5,
):
    os.makedirs(root, exist_ok=True)
    category_to_subset_list = {}
@ -142,7 +161,10 @@ def _make_random_json_dataset_map_provider_v2_data(
    with open(set_list_file, "w") as f:
        json.dump(set_list, f)

    eval_batches = [
        random.sample(test_frame_index, eval_batch_size) for _ in range(10)
    ]
    eval_b_dir = os.path.join(root, category, "eval_batches")
    os.makedirs(eval_b_dir, exist_ok=True)
    eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")