Mirror of https://github.com/facebookresearch/pytorch3d.git
Enable additional test-time source views for json dataset provider v2
Summary: Adds additional source views to the eval batches for evaluating many-view models on the CO3D Challenge.

Reviewed By: bottler

Differential Revision: D38705904

fbshipit-source-id: cf7d00dc7db926fbd1656dd97a729674e9ff5adb
commit 2ff2c7c836 (parent e8616cc8ba)
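The sketch below shows how the new option is meant to be used, based on the provider and the test in this diff; it is a minimal sketch, and the import paths, category, subset name, and dataset root are placeholders/assumptions rather than values this commit specifies.

from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Build the configurable dataclass fields first (the test below does the same).
expand_args_fields(JsonIndexDatasetMapProviderV2)

dataset_provider = JsonIndexDatasetMapProviderV2(
    category="teddybear",          # placeholder CO3D category
    subset_name="test",            # placeholder subset name
    dataset_root="/path/to/co3d",  # placeholder dataset root
    n_known_frames_for_test=2,     # new option: append 2 known (train) views to every eval batch
)
dataset_map = dataset_provider.get_dataset_map()

# Each eval batch now carries its original frames plus 2 known source views,
# so a collated test batch has eval_batch_size + 2 frames (see the test assertion below).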
@@ -61,6 +61,7 @@ data_source_ImplicitronDataSource_args:
     test_on_train: false
     only_test_set: false
     load_eval_batches: true
+    n_known_frames_for_test: 0
     dataset_class_type: JsonIndexDataset
     path_manager_factory_class_type: PathManagerFactory
     dataset_JsonIndexDataset_args:
@@ -5,11 +5,15 @@
 # LICENSE file in the root directory of this source tree.


+import copy
 import json
 import logging
 import os
 import warnings
-from typing import Dict, List, Optional, Tuple, Type
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Type, Union

+import numpy as np
+
 from omegaconf import DictConfig
 from pytorch3d.implicitron.dataset.dataset_map_provider import (
@@ -152,6 +156,9 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
         only_test_set: Load only the test set. Incompatible with `test_on_train`.
         load_eval_batches: Load the file containing eval batches pointing to the
             test dataset.
+        n_known_frames_for_test: Add a certain number of known frames to each
+            eval batch. Useful for evaluating models that require
+            source views as input (e.g. NeRF-WCE / PixelNeRF).
         dataset_args: Specifies additional arguments to the
             JsonIndexDataset constructor call.
         path_manager_factory: (Optional) An object that generates an instance of
@@ -167,6 +174,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
     only_test_set: bool = False
     load_eval_batches: bool = True

+    n_known_frames_for_test: int = 0
+
     dataset_class_type: str = "JsonIndexDataset"
     dataset: JsonIndexDataset

@@ -264,6 +273,18 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
         val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
         logger.info(f"Val dataset: {str(val_dataset)}")
         logger.debug("Extracting test dataset.")
+
+        if (self.n_known_frames_for_test > 0) and self.load_eval_batches:
+            # extend the test subset mapping and the dataset with additional
+            # known views from the train dataset
+            (
+                eval_batch_index,
+                subset_mapping["test"],
+            ) = self._extend_test_data_with_known_views(
+                subset_mapping,
+                eval_batch_index,
+            )
+
         test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
         logger.info(f"Test dataset: {str(test_dataset)}")
         if self.load_eval_batches:
@@ -369,6 +390,40 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
             dataset_root = self.dataset_root
         return get_available_subset_names(dataset_root, self.category)

+    def _extend_test_data_with_known_views(
+        self,
+        subset_mapping: Dict[str, List[Union[Tuple[str, int], Tuple[str, int, str]]]],
+        eval_batch_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
+    ):
+        # convert the train subset mapping to a dict:
+        # sequence_to_train_frames: {sequence_name: frame_index}
+        sequence_to_train_frames = defaultdict(list)
+        for frame_entry in subset_mapping["train"]:
+            sequence_name = frame_entry[0]
+            sequence_to_train_frames[sequence_name].append(frame_entry)
+        sequence_to_train_frames = dict(sequence_to_train_frames)
+        test_subset_mapping_set = {tuple(s) for s in subset_mapping["test"]}
+
+        # extend the eval batches / subset mapping with the additional examples
+        eval_batch_index_out = copy.deepcopy(eval_batch_index)
+        generator = np.random.default_rng(seed=0)
+        for batch in eval_batch_index_out:
+            sequence_name = batch[0][0]
+            sequence_known_entries = sequence_to_train_frames[sequence_name]
+            idx_to_add = generator.permutation(len(sequence_known_entries))[
+                : self.n_known_frames_for_test
+            ]
+            entries_to_add = [sequence_known_entries[a] for a in idx_to_add]
+            assert all(e in subset_mapping["train"] for e in entries_to_add)
+
+            # extend the eval batch with the known views
+            batch.extend(entries_to_add)
+
+            # also add these new entries to the test subset mapping
+            test_subset_mapping_set.update(tuple(e) for e in entries_to_add)
+
+        return eval_batch_index_out, list(test_subset_mapping_set)
+

 def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
     """
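For illustration, here is a self-contained sketch of the selection logic above on toy data; the sequence names and entries are made up, but the grouping, seeded permutation, and batch extension mirror the new method.

import copy
from collections import defaultdict

import numpy as np

# Toy set lists: (sequence_name, frame_number, image_path) entries, matching the type hints above.
subset_mapping = {
    "train": [("seq0", 0, "f0.png"), ("seq0", 1, "f1.png"), ("seq0", 2, "f2.png")],
    "test": [("seq0", 10, "f10.png")],
}
eval_batch_index = [[("seq0", 10, "f10.png")]]
n_known_frames_for_test = 2

# Group the train entries by sequence name.
sequence_to_train_frames = defaultdict(list)
for frame_entry in subset_mapping["train"]:
    sequence_to_train_frames[frame_entry[0]].append(frame_entry)

# Append a fixed, seeded random subset of known (train) views to each eval batch.
eval_batch_index_out = copy.deepcopy(eval_batch_index)
generator = np.random.default_rng(seed=0)
for batch in eval_batch_index_out:
    known = sequence_to_train_frames[batch[0][0]]
    idx_to_add = generator.permutation(len(known))[:n_known_frames_for_test]
    batch.extend(known[i] for i in idx_to_add)

print(eval_batch_index_out)  # each eval batch now ends with 2 known views from the same sequence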
@@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
   test_on_train: false
   only_test_set: false
   load_eval_batches: true
+  n_known_frames_for_test: 0
   dataset_class_type: JsonIndexDataset
   path_manager_factory_class_type: PathManagerFactory
   dataset_JsonIndexDataset_args:
@@ -37,24 +37,42 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
         expand_args_fields(JsonIndexDatasetMapProviderV2)
         categories = ["A", "B"]
         subset_name = "test"
+        eval_batch_size = 5
         with tempfile.TemporaryDirectory() as tmpd:
-            _make_random_json_dataset_map_provider_v2_data(tmpd, categories)
+            _make_random_json_dataset_map_provider_v2_data(
+                tmpd,
+                categories,
+                eval_batch_size=eval_batch_size,
+            )
+            for n_known_frames_for_test in [0, 2]:
                 for category in categories:
                     dataset_provider = JsonIndexDatasetMapProviderV2(
                         category=category,
                         subset_name="test",
                         dataset_root=tmpd,
+                        n_known_frames_for_test=n_known_frames_for_test,
                     )
                     dataset_map = dataset_provider.get_dataset_map()
                     for set_ in ["train", "val", "test"]:
+                        if set_ in ["train", "val"]:
                             dataloader = torch.utils.data.DataLoader(
                                 getattr(dataset_map, set_),
                                 batch_size=3,
                                 shuffle=True,
                                 collate_fn=FrameData.collate,
                             )
                             for _ in dataloader:
                                 pass
+                        else:
+                            dataloader = torch.utils.data.DataLoader(
+                                getattr(dataset_map, set_),
+                                batch_sampler=dataset_map[set_].get_eval_batches(),
+                                collate_fn=FrameData.collate,
+                            )
+                            for batch in dataloader:
+                                if set_ == "test":
+                                    self.assertTrue(
+                                        batch.image_rgb.shape[0]
+                                        == n_known_frames_for_test + eval_batch_size
+                                    )
                     category_to_subset_list = (
                         dataset_provider.get_category_to_subset_name_list()
                     )
@@ -70,6 +88,7 @@ def _make_random_json_dataset_map_provider_v2_data(
     H: int = 50,
     W: int = 30,
     subset_name: str = "test",
+    eval_batch_size: int = 5,
 ):
     os.makedirs(root, exist_ok=True)
     category_to_subset_list = {}
@@ -142,7 +161,10 @@ def _make_random_json_dataset_map_provider_v2_data(
         with open(set_list_file, "w") as f:
             json.dump(set_list, f)

-        eval_batches = [random.sample(test_frame_index, 5) for _ in range(10)]
+        eval_batches = [
+            random.sample(test_frame_index, eval_batch_size) for _ in range(10)
+        ]

         eval_b_dir = os.path.join(root, category, "eval_batches")
         os.makedirs(eval_b_dir, exist_ok=True)
         eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")
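As a rough illustration of the file layout the helper above produces and the provider consumes when load_eval_batches is true, the following writes a tiny eval-batches file; the directory and file names follow the paths in this diff, while the concrete entries and image paths are made up.

import json
import os
import tempfile

root = tempfile.mkdtemp()  # stand-in for a CO3D-style dataset root
category, subset_name = "A", "test"  # as in the test above

# Each eval batch is a list of (sequence_name, frame_number, image_path) entries.
eval_batches = [
    [["seq0", 10, "A/seq0/images/frame000010.jpg"]],
    [["seq1", 3, "A/seq1/images/frame000003.jpg"]],
]

eval_b_dir = os.path.join(root, category, "eval_batches")
os.makedirs(eval_b_dir, exist_ok=True)
eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")
with open(eval_b_file, "w") as f:
    json.dump(eval_batches, f)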