mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-12-20 06:10:34 +08:00)

commit 2ff2c7c836 (parent e8616cc8ba), committed by Facebook GitHub Bot

Enable additional test-time source views for json dataset provider v2

Summary: Adds additional source views to the eval batches for evaluating many-view models on the CO3D Challenge.

Reviewed By: bottler

Differential Revision: D38705904

fbshipit-source-id: cf7d00dc7db926fbd1656dd97a729674e9ff5adb
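The change is easiest to read as batch arithmetic: when n_known_frames_for_test > 0, every evaluation batch of eval_batch_size target frames is padded with that many extra source views drawn from the known (train) frames, so a collated test batch contains eval_batch_size + n_known_frames_for_test frames. The sketch below illustrates only that idea; extend_eval_batches and the (sequence, frame-number) index representation are hypothetical stand-ins for illustration, not the provider's internal API.

import random

# Hypothetical illustration of the test-set batch extension.
# Frame indices are represented here as (sequence_name, frame_number) tuples;
# the real provider works with its own index structures.
def extend_eval_batches(eval_batches, known_frames, n_known_frames_for_test):
    """Append n extra 'known' source views to every eval batch."""
    if n_known_frames_for_test == 0:
        return eval_batches
    return [
        batch + random.sample(known_frames, n_known_frames_for_test)
        for batch in eval_batches
    ]

eval_batches = [[("seq_0", i) for i in range(5)]]          # eval_batch_size == 5
known_frames = [("seq_0", i) for i in range(100, 120)]     # frames from the known/train split
extended = extend_eval_batches(eval_batches, known_frames, n_known_frames_for_test=2)
assert len(extended[0]) == 5 + 2  # matches the shape check in the updated test

The diff touches the generated dataset-provider config (new field with default 0) and the provider-v2 test, shown below.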
@@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
   test_on_train: false
   only_test_set: false
   load_eval_batches: true
+  n_known_frames_for_test: 0
   dataset_class_type: JsonIndexDataset
   path_manager_factory_class_type: PathManagerFactory
   dataset_JsonIndexDataset_args:
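The new n_known_frames_for_test field defaults to 0, which keeps the previous behaviour. A minimal sketch of enabling it directly in Python; the category, subset name, and dataset root are placeholders assuming a CO3D-style dataset layout, and the construction mirrors the updated test further down:

from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(JsonIndexDatasetMapProviderV2)

# "DATASET_ROOT", "teddybear" and "fewview_dev" are placeholder values.
dataset_provider = JsonIndexDatasetMapProviderV2(
    category="teddybear",
    subset_name="fewview_dev",
    dataset_root="DATASET_ROOT",
    n_known_frames_for_test=2,  # new field; 0 keeps the previous behaviour
)
dataset_map = dataset_provider.get_dataset_map()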
@@ -37,29 +37,47 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
         expand_args_fields(JsonIndexDatasetMapProviderV2)
         categories = ["A", "B"]
         subset_name = "test"
+        eval_batch_size = 5
         with tempfile.TemporaryDirectory() as tmpd:
-            _make_random_json_dataset_map_provider_v2_data(tmpd, categories)
-            for category in categories:
-                dataset_provider = JsonIndexDatasetMapProviderV2(
-                    category=category,
-                    subset_name="test",
-                    dataset_root=tmpd,
-                )
-                dataset_map = dataset_provider.get_dataset_map()
-                for set_ in ["train", "val", "test"]:
-                    dataloader = torch.utils.data.DataLoader(
-                        getattr(dataset_map, set_),
-                        batch_size=3,
-                        shuffle=True,
-                        collate_fn=FrameData.collate,
-                    )
-                    for _ in dataloader:
-                        pass
-                category_to_subset_list = (
-                    dataset_provider.get_category_to_subset_name_list()
-                )
-                category_to_subset_list_ = {c: [subset_name] for c in categories}
-                self.assertTrue(category_to_subset_list == category_to_subset_list_)
+            _make_random_json_dataset_map_provider_v2_data(
+                tmpd,
+                categories,
+                eval_batch_size=eval_batch_size,
+            )
+            for n_known_frames_for_test in [0, 2]:
+                for category in categories:
+                    dataset_provider = JsonIndexDatasetMapProviderV2(
+                        category=category,
+                        subset_name="test",
+                        dataset_root=tmpd,
+                        n_known_frames_for_test=n_known_frames_for_test,
+                    )
+                    dataset_map = dataset_provider.get_dataset_map()
+                    for set_ in ["train", "val", "test"]:
+                        if set_ in ["train", "val"]:
+                            dataloader = torch.utils.data.DataLoader(
+                                getattr(dataset_map, set_),
+                                batch_size=3,
+                                shuffle=True,
+                                collate_fn=FrameData.collate,
+                            )
+                        else:
+                            dataloader = torch.utils.data.DataLoader(
+                                getattr(dataset_map, set_),
+                                batch_sampler=dataset_map[set_].get_eval_batches(),
+                                collate_fn=FrameData.collate,
+                            )
+                        for batch in dataloader:
+                            if set_ == "test":
+                                self.assertTrue(
+                                    batch.image_rgb.shape[0]
+                                    == n_known_frames_for_test + eval_batch_size
+                                )
+                    category_to_subset_list = (
+                        dataset_provider.get_category_to_subset_name_list()
+                    )
+                    category_to_subset_list_ = {c: [subset_name] for c in categories}
+                    self.assertTrue(category_to_subset_list == category_to_subset_list_)


 def _make_random_json_dataset_map_provider_v2_data(
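The shape check above is plain arithmetic: with eval_batch_size = 5 and n_known_frames_for_test = 2, each collated test batch holds 7 frames. If one wants to separate the original eval frames from the extra source views, something like the sketch below works; it assumes (as this sketch does, not the source) that the added known frames come after the original eval batch:

import torch

def split_eval_batch(batch_image_rgb, eval_batch_size):
    # batch_image_rgb has shape (eval_batch_size + n_known_frames_for_test, 3, H, W).
    eval_frames = batch_image_rgb[:eval_batch_size]    # frames evaluation targets
    known_frames = batch_image_rgb[eval_batch_size:]   # appended source views (may be empty)
    return eval_frames, known_frames

# With eval_batch_size=5 and n_known_frames_for_test=2 the collated batch has 7 frames.
images = torch.zeros(7, 3, 50, 30)
eval_frames, known_frames = split_eval_batch(images, eval_batch_size=5)
assert eval_frames.shape[0] == 5 and known_frames.shape[0] == 2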
@@ -70,6 +88,7 @@ def _make_random_json_dataset_map_provider_v2_data(
     H: int = 50,
     W: int = 30,
     subset_name: str = "test",
+    eval_batch_size: int = 5,
 ):
     os.makedirs(root, exist_ok=True)
     category_to_subset_list = {}
@@ -142,7 +161,10 @@ def _make_random_json_dataset_map_provider_v2_data(
         with open(set_list_file, "w") as f:
             json.dump(set_list, f)

-        eval_batches = [random.sample(test_frame_index, 5) for _ in range(10)]
+        eval_batches = [
+            random.sample(test_frame_index, eval_batch_size) for _ in range(10)
+        ]
+
         eval_b_dir = os.path.join(root, category, "eval_batches")
         os.makedirs(eval_b_dir, exist_ok=True)
         eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")
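For context, the eval-batches file the test helper produces is plain JSON written to <root>/<category>/eval_batches/eval_batches_<subset_name>.json: a list of batches, each a list of frame identifiers sampled from the test frames, with the batch size now driven by eval_batch_size. A standalone sketch of producing such a file; the ["sequence_name", frame_number] identifier format and the placeholder paths are assumptions for illustration, the helper samples whatever entries its own test_frame_index holds:

import json
import os
import random

root, category, subset_name = "DATASET_ROOT", "A", "test"   # placeholders
eval_batch_size = 5

# Hypothetical test-frame index: one identifier per test frame.
test_frame_index = [["seq_0", frame_number] for frame_number in range(20)]

# Ten random eval batches of eval_batch_size frames each.
eval_batches = [
    random.sample(test_frame_index, eval_batch_size) for _ in range(10)
]

eval_b_dir = os.path.join(root, category, "eval_batches")
os.makedirs(eval_b_dir, exist_ok=True)
with open(os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json"), "w") as f:
    json.dump(eval_batches, f)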