mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Enable additional test-time source views for json dataset provider v2
Summary: Adds additional source views to the eval batches for evaluating many-view models on CO3D Challenge Reviewed By: bottler Differential Revision: D38705904 fbshipit-source-id: cf7d00dc7db926fbd1656dd97a729674e9ff5adb
This commit is contained in:
		
							parent
							
								
									e8616cc8ba
								
							
						
					
					
						commit
						2ff2c7c836
					
				@ -61,6 +61,7 @@ data_source_ImplicitronDataSource_args:
 | 
				
			|||||||
    test_on_train: false
 | 
					    test_on_train: false
 | 
				
			||||||
    only_test_set: false
 | 
					    only_test_set: false
 | 
				
			||||||
    load_eval_batches: true
 | 
					    load_eval_batches: true
 | 
				
			||||||
 | 
					    n_known_frames_for_test: 0
 | 
				
			||||||
    dataset_class_type: JsonIndexDataset
 | 
					    dataset_class_type: JsonIndexDataset
 | 
				
			||||||
    path_manager_factory_class_type: PathManagerFactory
 | 
					    path_manager_factory_class_type: PathManagerFactory
 | 
				
			||||||
    dataset_JsonIndexDataset_args:
 | 
					    dataset_JsonIndexDataset_args:
 | 
				
			||||||
 | 
				
			|||||||
@ -5,11 +5,15 @@
 | 
				
			|||||||
# LICENSE file in the root directory of this source tree.
 | 
					# LICENSE file in the root directory of this source tree.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import copy
 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
from typing import Dict, List, Optional, Tuple, Type
 | 
					from collections import defaultdict
 | 
				
			||||||
 | 
					from typing import Dict, List, Optional, Tuple, Type, Union
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import numpy as np
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from omegaconf import DictConfig
 | 
					from omegaconf import DictConfig
 | 
				
			||||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
 | 
					from pytorch3d.implicitron.dataset.dataset_map_provider import (
 | 
				
			||||||
@ -152,6 +156,9 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
				
			|||||||
        only_test_set: Load only the test set. Incompatible with `test_on_train`.
 | 
					        only_test_set: Load only the test set. Incompatible with `test_on_train`.
 | 
				
			||||||
        load_eval_batches: Load the file containing eval batches pointing to the
 | 
					        load_eval_batches: Load the file containing eval batches pointing to the
 | 
				
			||||||
            test dataset.
 | 
					            test dataset.
 | 
				
			||||||
 | 
					        n_known_frames_for_test: Add a certain number of known frames to each
 | 
				
			||||||
 | 
					            eval batch. Useful for evaluating models that require
 | 
				
			||||||
 | 
					            source views as input (e.g. NeRF-WCE / PixelNeRF).
 | 
				
			||||||
        dataset_args: Specifies additional arguments to the
 | 
					        dataset_args: Specifies additional arguments to the
 | 
				
			||||||
            JsonIndexDataset constructor call.
 | 
					            JsonIndexDataset constructor call.
 | 
				
			||||||
        path_manager_factory: (Optional) An object that generates an instance of
 | 
					        path_manager_factory: (Optional) An object that generates an instance of
 | 
				
			||||||
@ -167,6 +174,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
				
			|||||||
    only_test_set: bool = False
 | 
					    only_test_set: bool = False
 | 
				
			||||||
    load_eval_batches: bool = True
 | 
					    load_eval_batches: bool = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    n_known_frames_for_test: int = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    dataset_class_type: str = "JsonIndexDataset"
 | 
					    dataset_class_type: str = "JsonIndexDataset"
 | 
				
			||||||
    dataset: JsonIndexDataset
 | 
					    dataset: JsonIndexDataset
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -264,6 +273,18 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
				
			|||||||
            val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
 | 
					            val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
 | 
				
			||||||
            logger.info(f"Val dataset: {str(val_dataset)}")
 | 
					            logger.info(f"Val dataset: {str(val_dataset)}")
 | 
				
			||||||
            logger.debug("Extracting test dataset.")
 | 
					            logger.debug("Extracting test dataset.")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if (self.n_known_frames_for_test > 0) and self.load_eval_batches:
 | 
				
			||||||
 | 
					                # extend the test subset mapping and the dataset with additional
 | 
				
			||||||
 | 
					                # known views from the train dataset
 | 
				
			||||||
 | 
					                (
 | 
				
			||||||
 | 
					                    eval_batch_index,
 | 
				
			||||||
 | 
					                    subset_mapping["test"],
 | 
				
			||||||
 | 
					                ) = self._extend_test_data_with_known_views(
 | 
				
			||||||
 | 
					                    subset_mapping,
 | 
				
			||||||
 | 
					                    eval_batch_index,
 | 
				
			||||||
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
 | 
					            test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
 | 
				
			||||||
            logger.info(f"Test dataset: {str(test_dataset)}")
 | 
					            logger.info(f"Test dataset: {str(test_dataset)}")
 | 
				
			||||||
            if self.load_eval_batches:
 | 
					            if self.load_eval_batches:
 | 
				
			||||||
@ -369,6 +390,40 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
				
			|||||||
            dataset_root = self.dataset_root
 | 
					            dataset_root = self.dataset_root
 | 
				
			||||||
        return get_available_subset_names(dataset_root, self.category)
 | 
					        return get_available_subset_names(dataset_root, self.category)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _extend_test_data_with_known_views(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        subset_mapping: Dict[str, List[Union[Tuple[str, int], Tuple[str, int, str]]]],
 | 
				
			||||||
 | 
					        eval_batch_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        # convert the train subset mapping to a dict:
 | 
				
			||||||
 | 
					        #   sequence_to_train_frames: {sequence_name: frame_index}
 | 
				
			||||||
 | 
					        sequence_to_train_frames = defaultdict(list)
 | 
				
			||||||
 | 
					        for frame_entry in subset_mapping["train"]:
 | 
				
			||||||
 | 
					            sequence_name = frame_entry[0]
 | 
				
			||||||
 | 
					            sequence_to_train_frames[sequence_name].append(frame_entry)
 | 
				
			||||||
 | 
					        sequence_to_train_frames = dict(sequence_to_train_frames)
 | 
				
			||||||
 | 
					        test_subset_mapping_set = {tuple(s) for s in subset_mapping["test"]}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # extend the eval batches / subset mapping with the additional examples
 | 
				
			||||||
 | 
					        eval_batch_index_out = copy.deepcopy(eval_batch_index)
 | 
				
			||||||
 | 
					        generator = np.random.default_rng(seed=0)
 | 
				
			||||||
 | 
					        for batch in eval_batch_index_out:
 | 
				
			||||||
 | 
					            sequence_name = batch[0][0]
 | 
				
			||||||
 | 
					            sequence_known_entries = sequence_to_train_frames[sequence_name]
 | 
				
			||||||
 | 
					            idx_to_add = generator.permutation(len(sequence_known_entries))[
 | 
				
			||||||
 | 
					                : self.n_known_frames_for_test
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					            entries_to_add = [sequence_known_entries[a] for a in idx_to_add]
 | 
				
			||||||
 | 
					            assert all(e in subset_mapping["train"] for e in entries_to_add)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # extend the eval batch with the known views
 | 
				
			||||||
 | 
					            batch.extend(entries_to_add)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # also add these new entries to the test subset mapping
 | 
				
			||||||
 | 
					            test_subset_mapping_set.update(tuple(e) for e in entries_to_add)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return eval_batch_index_out, list(test_subset_mapping_set)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
 | 
					def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
				
			|||||||
@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
 | 
				
			|||||||
  test_on_train: false
 | 
					  test_on_train: false
 | 
				
			||||||
  only_test_set: false
 | 
					  only_test_set: false
 | 
				
			||||||
  load_eval_batches: true
 | 
					  load_eval_batches: true
 | 
				
			||||||
 | 
					  n_known_frames_for_test: 0
 | 
				
			||||||
  dataset_class_type: JsonIndexDataset
 | 
					  dataset_class_type: JsonIndexDataset
 | 
				
			||||||
  path_manager_factory_class_type: PathManagerFactory
 | 
					  path_manager_factory_class_type: PathManagerFactory
 | 
				
			||||||
  dataset_JsonIndexDataset_args:
 | 
					  dataset_JsonIndexDataset_args:
 | 
				
			||||||
 | 
				
			|||||||
@ -37,29 +37,47 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
 | 
				
			|||||||
        expand_args_fields(JsonIndexDatasetMapProviderV2)
 | 
					        expand_args_fields(JsonIndexDatasetMapProviderV2)
 | 
				
			||||||
        categories = ["A", "B"]
 | 
					        categories = ["A", "B"]
 | 
				
			||||||
        subset_name = "test"
 | 
					        subset_name = "test"
 | 
				
			||||||
 | 
					        eval_batch_size = 5
 | 
				
			||||||
        with tempfile.TemporaryDirectory() as tmpd:
 | 
					        with tempfile.TemporaryDirectory() as tmpd:
 | 
				
			||||||
            _make_random_json_dataset_map_provider_v2_data(tmpd, categories)
 | 
					            _make_random_json_dataset_map_provider_v2_data(
 | 
				
			||||||
            for category in categories:
 | 
					                tmpd,
 | 
				
			||||||
                dataset_provider = JsonIndexDatasetMapProviderV2(
 | 
					                categories,
 | 
				
			||||||
                    category=category,
 | 
					                eval_batch_size=eval_batch_size,
 | 
				
			||||||
                    subset_name="test",
 | 
					            )
 | 
				
			||||||
                    dataset_root=tmpd,
 | 
					            for n_known_frames_for_test in [0, 2]:
 | 
				
			||||||
                )
 | 
					                for category in categories:
 | 
				
			||||||
                dataset_map = dataset_provider.get_dataset_map()
 | 
					                    dataset_provider = JsonIndexDatasetMapProviderV2(
 | 
				
			||||||
                for set_ in ["train", "val", "test"]:
 | 
					                        category=category,
 | 
				
			||||||
                    dataloader = torch.utils.data.DataLoader(
 | 
					                        subset_name="test",
 | 
				
			||||||
                        getattr(dataset_map, set_),
 | 
					                        dataset_root=tmpd,
 | 
				
			||||||
                        batch_size=3,
 | 
					                        n_known_frames_for_test=n_known_frames_for_test,
 | 
				
			||||||
                        shuffle=True,
 | 
					 | 
				
			||||||
                        collate_fn=FrameData.collate,
 | 
					 | 
				
			||||||
                    )
 | 
					                    )
 | 
				
			||||||
                    for _ in dataloader:
 | 
					                    dataset_map = dataset_provider.get_dataset_map()
 | 
				
			||||||
                        pass
 | 
					                    for set_ in ["train", "val", "test"]:
 | 
				
			||||||
                category_to_subset_list = (
 | 
					                        if set_ in ["train", "val"]:
 | 
				
			||||||
                    dataset_provider.get_category_to_subset_name_list()
 | 
					                            dataloader = torch.utils.data.DataLoader(
 | 
				
			||||||
                )
 | 
					                                getattr(dataset_map, set_),
 | 
				
			||||||
                category_to_subset_list_ = {c: [subset_name] for c in categories}
 | 
					                                batch_size=3,
 | 
				
			||||||
                self.assertTrue(category_to_subset_list == category_to_subset_list_)
 | 
					                                shuffle=True,
 | 
				
			||||||
 | 
					                                collate_fn=FrameData.collate,
 | 
				
			||||||
 | 
					                            )
 | 
				
			||||||
 | 
					                        else:
 | 
				
			||||||
 | 
					                            dataloader = torch.utils.data.DataLoader(
 | 
				
			||||||
 | 
					                                getattr(dataset_map, set_),
 | 
				
			||||||
 | 
					                                batch_sampler=dataset_map[set_].get_eval_batches(),
 | 
				
			||||||
 | 
					                                collate_fn=FrameData.collate,
 | 
				
			||||||
 | 
					                            )
 | 
				
			||||||
 | 
					                        for batch in dataloader:
 | 
				
			||||||
 | 
					                            if set_ == "test":
 | 
				
			||||||
 | 
					                                self.assertTrue(
 | 
				
			||||||
 | 
					                                    batch.image_rgb.shape[0]
 | 
				
			||||||
 | 
					                                    == n_known_frames_for_test + eval_batch_size
 | 
				
			||||||
 | 
					                                )
 | 
				
			||||||
 | 
					                    category_to_subset_list = (
 | 
				
			||||||
 | 
					                        dataset_provider.get_category_to_subset_name_list()
 | 
				
			||||||
 | 
					                    )
 | 
				
			||||||
 | 
					                    category_to_subset_list_ = {c: [subset_name] for c in categories}
 | 
				
			||||||
 | 
					                    self.assertTrue(category_to_subset_list == category_to_subset_list_)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def _make_random_json_dataset_map_provider_v2_data(
 | 
					def _make_random_json_dataset_map_provider_v2_data(
 | 
				
			||||||
@ -70,6 +88,7 @@ def _make_random_json_dataset_map_provider_v2_data(
 | 
				
			|||||||
    H: int = 50,
 | 
					    H: int = 50,
 | 
				
			||||||
    W: int = 30,
 | 
					    W: int = 30,
 | 
				
			||||||
    subset_name: str = "test",
 | 
					    subset_name: str = "test",
 | 
				
			||||||
 | 
					    eval_batch_size: int = 5,
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    os.makedirs(root, exist_ok=True)
 | 
					    os.makedirs(root, exist_ok=True)
 | 
				
			||||||
    category_to_subset_list = {}
 | 
					    category_to_subset_list = {}
 | 
				
			||||||
@ -142,7 +161,10 @@ def _make_random_json_dataset_map_provider_v2_data(
 | 
				
			|||||||
        with open(set_list_file, "w") as f:
 | 
					        with open(set_list_file, "w") as f:
 | 
				
			||||||
            json.dump(set_list, f)
 | 
					            json.dump(set_list, f)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        eval_batches = [random.sample(test_frame_index, 5) for _ in range(10)]
 | 
					        eval_batches = [
 | 
				
			||||||
 | 
					            random.sample(test_frame_index, eval_batch_size) for _ in range(10)
 | 
				
			||||||
 | 
					        ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        eval_b_dir = os.path.join(root, category, "eval_batches")
 | 
					        eval_b_dir = os.path.join(root, category, "eval_batches")
 | 
				
			||||||
        os.makedirs(eval_b_dir, exist_ok=True)
 | 
					        os.makedirs(eval_b_dir, exist_ok=True)
 | 
				
			||||||
        eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")
 | 
					        eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user