mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Enable additional test-time source views for json dataset provider v2
Summary: Adds additional source views to the eval batches for evaluating many-view models on CO3D Challenge Reviewed By: bottler Differential Revision: D38705904 fbshipit-source-id: cf7d00dc7db926fbd1656dd97a729674e9ff5adb
This commit is contained in:
parent
e8616cc8ba
commit
2ff2c7c836
@ -61,6 +61,7 @@ data_source_ImplicitronDataSource_args:
|
|||||||
test_on_train: false
|
test_on_train: false
|
||||||
only_test_set: false
|
only_test_set: false
|
||||||
load_eval_batches: true
|
load_eval_batches: true
|
||||||
|
n_known_frames_for_test: 0
|
||||||
dataset_class_type: JsonIndexDataset
|
dataset_class_type: JsonIndexDataset
|
||||||
path_manager_factory_class_type: PathManagerFactory
|
path_manager_factory_class_type: PathManagerFactory
|
||||||
dataset_JsonIndexDataset_args:
|
dataset_JsonIndexDataset_args:
|
||||||
|
@ -5,11 +5,15 @@
|
|||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import copy
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, List, Optional, Tuple, Type
|
from collections import defaultdict
|
||||||
|
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||||
@ -152,6 +156,9 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
only_test_set: Load only the test set. Incompatible with `test_on_train`.
|
only_test_set: Load only the test set. Incompatible with `test_on_train`.
|
||||||
load_eval_batches: Load the file containing eval batches pointing to the
|
load_eval_batches: Load the file containing eval batches pointing to the
|
||||||
test dataset.
|
test dataset.
|
||||||
|
n_known_frames_for_test: Add a certain number of known frames to each
|
||||||
|
eval batch. Useful for evaluating models that require
|
||||||
|
source views as input (e.g. NeRF-WCE / PixelNeRF).
|
||||||
dataset_args: Specifies additional arguments to the
|
dataset_args: Specifies additional arguments to the
|
||||||
JsonIndexDataset constructor call.
|
JsonIndexDataset constructor call.
|
||||||
path_manager_factory: (Optional) An object that generates an instance of
|
path_manager_factory: (Optional) An object that generates an instance of
|
||||||
@ -167,6 +174,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
only_test_set: bool = False
|
only_test_set: bool = False
|
||||||
load_eval_batches: bool = True
|
load_eval_batches: bool = True
|
||||||
|
|
||||||
|
n_known_frames_for_test: int = 0
|
||||||
|
|
||||||
dataset_class_type: str = "JsonIndexDataset"
|
dataset_class_type: str = "JsonIndexDataset"
|
||||||
dataset: JsonIndexDataset
|
dataset: JsonIndexDataset
|
||||||
|
|
||||||
@ -264,6 +273,18 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
|
val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
|
||||||
logger.info(f"Val dataset: {str(val_dataset)}")
|
logger.info(f"Val dataset: {str(val_dataset)}")
|
||||||
logger.debug("Extracting test dataset.")
|
logger.debug("Extracting test dataset.")
|
||||||
|
|
||||||
|
if (self.n_known_frames_for_test > 0) and self.load_eval_batches:
|
||||||
|
# extend the test subset mapping and the dataset with additional
|
||||||
|
# known views from the train dataset
|
||||||
|
(
|
||||||
|
eval_batch_index,
|
||||||
|
subset_mapping["test"],
|
||||||
|
) = self._extend_test_data_with_known_views(
|
||||||
|
subset_mapping,
|
||||||
|
eval_batch_index,
|
||||||
|
)
|
||||||
|
|
||||||
test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
|
test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
|
||||||
logger.info(f"Test dataset: {str(test_dataset)}")
|
logger.info(f"Test dataset: {str(test_dataset)}")
|
||||||
if self.load_eval_batches:
|
if self.load_eval_batches:
|
||||||
@ -369,6 +390,40 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
dataset_root = self.dataset_root
|
dataset_root = self.dataset_root
|
||||||
return get_available_subset_names(dataset_root, self.category)
|
return get_available_subset_names(dataset_root, self.category)
|
||||||
|
|
||||||
|
def _extend_test_data_with_known_views(
|
||||||
|
self,
|
||||||
|
subset_mapping: Dict[str, List[Union[Tuple[str, int], Tuple[str, int, str]]]],
|
||||||
|
eval_batch_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
|
||||||
|
):
|
||||||
|
# convert the train subset mapping to a dict:
|
||||||
|
# sequence_to_train_frames: {sequence_name: frame_index}
|
||||||
|
sequence_to_train_frames = defaultdict(list)
|
||||||
|
for frame_entry in subset_mapping["train"]:
|
||||||
|
sequence_name = frame_entry[0]
|
||||||
|
sequence_to_train_frames[sequence_name].append(frame_entry)
|
||||||
|
sequence_to_train_frames = dict(sequence_to_train_frames)
|
||||||
|
test_subset_mapping_set = {tuple(s) for s in subset_mapping["test"]}
|
||||||
|
|
||||||
|
# extend the eval batches / subset mapping with the additional examples
|
||||||
|
eval_batch_index_out = copy.deepcopy(eval_batch_index)
|
||||||
|
generator = np.random.default_rng(seed=0)
|
||||||
|
for batch in eval_batch_index_out:
|
||||||
|
sequence_name = batch[0][0]
|
||||||
|
sequence_known_entries = sequence_to_train_frames[sequence_name]
|
||||||
|
idx_to_add = generator.permutation(len(sequence_known_entries))[
|
||||||
|
: self.n_known_frames_for_test
|
||||||
|
]
|
||||||
|
entries_to_add = [sequence_known_entries[a] for a in idx_to_add]
|
||||||
|
assert all(e in subset_mapping["train"] for e in entries_to_add)
|
||||||
|
|
||||||
|
# extend the eval batch with the known views
|
||||||
|
batch.extend(entries_to_add)
|
||||||
|
|
||||||
|
# also add these new entries to the test subset mapping
|
||||||
|
test_subset_mapping_set.update(tuple(e) for e in entries_to_add)
|
||||||
|
|
||||||
|
return eval_batch_index_out, list(test_subset_mapping_set)
|
||||||
|
|
||||||
|
|
||||||
def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
|
def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
|
||||||
"""
|
"""
|
||||||
|
@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
|
|||||||
test_on_train: false
|
test_on_train: false
|
||||||
only_test_set: false
|
only_test_set: false
|
||||||
load_eval_batches: true
|
load_eval_batches: true
|
||||||
|
n_known_frames_for_test: 0
|
||||||
dataset_class_type: JsonIndexDataset
|
dataset_class_type: JsonIndexDataset
|
||||||
path_manager_factory_class_type: PathManagerFactory
|
path_manager_factory_class_type: PathManagerFactory
|
||||||
dataset_JsonIndexDataset_args:
|
dataset_JsonIndexDataset_args:
|
||||||
|
@ -37,29 +37,47 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
|
|||||||
expand_args_fields(JsonIndexDatasetMapProviderV2)
|
expand_args_fields(JsonIndexDatasetMapProviderV2)
|
||||||
categories = ["A", "B"]
|
categories = ["A", "B"]
|
||||||
subset_name = "test"
|
subset_name = "test"
|
||||||
|
eval_batch_size = 5
|
||||||
with tempfile.TemporaryDirectory() as tmpd:
|
with tempfile.TemporaryDirectory() as tmpd:
|
||||||
_make_random_json_dataset_map_provider_v2_data(tmpd, categories)
|
_make_random_json_dataset_map_provider_v2_data(
|
||||||
for category in categories:
|
tmpd,
|
||||||
dataset_provider = JsonIndexDatasetMapProviderV2(
|
categories,
|
||||||
category=category,
|
eval_batch_size=eval_batch_size,
|
||||||
subset_name="test",
|
)
|
||||||
dataset_root=tmpd,
|
for n_known_frames_for_test in [0, 2]:
|
||||||
)
|
for category in categories:
|
||||||
dataset_map = dataset_provider.get_dataset_map()
|
dataset_provider = JsonIndexDatasetMapProviderV2(
|
||||||
for set_ in ["train", "val", "test"]:
|
category=category,
|
||||||
dataloader = torch.utils.data.DataLoader(
|
subset_name="test",
|
||||||
getattr(dataset_map, set_),
|
dataset_root=tmpd,
|
||||||
batch_size=3,
|
n_known_frames_for_test=n_known_frames_for_test,
|
||||||
shuffle=True,
|
|
||||||
collate_fn=FrameData.collate,
|
|
||||||
)
|
)
|
||||||
for _ in dataloader:
|
dataset_map = dataset_provider.get_dataset_map()
|
||||||
pass
|
for set_ in ["train", "val", "test"]:
|
||||||
category_to_subset_list = (
|
if set_ in ["train", "val"]:
|
||||||
dataset_provider.get_category_to_subset_name_list()
|
dataloader = torch.utils.data.DataLoader(
|
||||||
)
|
getattr(dataset_map, set_),
|
||||||
category_to_subset_list_ = {c: [subset_name] for c in categories}
|
batch_size=3,
|
||||||
self.assertTrue(category_to_subset_list == category_to_subset_list_)
|
shuffle=True,
|
||||||
|
collate_fn=FrameData.collate,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
dataloader = torch.utils.data.DataLoader(
|
||||||
|
getattr(dataset_map, set_),
|
||||||
|
batch_sampler=dataset_map[set_].get_eval_batches(),
|
||||||
|
collate_fn=FrameData.collate,
|
||||||
|
)
|
||||||
|
for batch in dataloader:
|
||||||
|
if set_ == "test":
|
||||||
|
self.assertTrue(
|
||||||
|
batch.image_rgb.shape[0]
|
||||||
|
== n_known_frames_for_test + eval_batch_size
|
||||||
|
)
|
||||||
|
category_to_subset_list = (
|
||||||
|
dataset_provider.get_category_to_subset_name_list()
|
||||||
|
)
|
||||||
|
category_to_subset_list_ = {c: [subset_name] for c in categories}
|
||||||
|
self.assertTrue(category_to_subset_list == category_to_subset_list_)
|
||||||
|
|
||||||
|
|
||||||
def _make_random_json_dataset_map_provider_v2_data(
|
def _make_random_json_dataset_map_provider_v2_data(
|
||||||
@ -70,6 +88,7 @@ def _make_random_json_dataset_map_provider_v2_data(
|
|||||||
H: int = 50,
|
H: int = 50,
|
||||||
W: int = 30,
|
W: int = 30,
|
||||||
subset_name: str = "test",
|
subset_name: str = "test",
|
||||||
|
eval_batch_size: int = 5,
|
||||||
):
|
):
|
||||||
os.makedirs(root, exist_ok=True)
|
os.makedirs(root, exist_ok=True)
|
||||||
category_to_subset_list = {}
|
category_to_subset_list = {}
|
||||||
@ -142,7 +161,10 @@ def _make_random_json_dataset_map_provider_v2_data(
|
|||||||
with open(set_list_file, "w") as f:
|
with open(set_list_file, "w") as f:
|
||||||
json.dump(set_list, f)
|
json.dump(set_list, f)
|
||||||
|
|
||||||
eval_batches = [random.sample(test_frame_index, 5) for _ in range(10)]
|
eval_batches = [
|
||||||
|
random.sample(test_frame_index, eval_batch_size) for _ in range(10)
|
||||||
|
]
|
||||||
|
|
||||||
eval_b_dir = os.path.join(root, category, "eval_batches")
|
eval_b_dir = os.path.join(root, category, "eval_batches")
|
||||||
os.makedirs(eval_b_dir, exist_ok=True)
|
os.makedirs(eval_b_dir, exist_ok=True)
|
||||||
eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")
|
eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user