Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-11-04 18:02:14 +08:00
	Enable additional test-time source views for json dataset provider v2
Summary: Adds additional source views to the eval batches for evaluating many-view models on the CO3D Challenge.

Reviewed By: bottler

Differential Revision: D38705904

fbshipit-source-id: cf7d00dc7db926fbd1656dd97a729674e9ff5adb
This commit is contained in:
parent e8616cc8ba
commit 2ff2c7c836
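For context, here is a rough usage sketch of the new option, modeled on the test added below. The category, subset name, and dataset root are placeholders, and the import locations are assumed from the implicitron module layout at the time of this change rather than taken from this diff.

import torch

from pytorch3d.implicitron.dataset.dataset_base import FrameData  # assumed import location
from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(JsonIndexDatasetMapProviderV2)

dataset_provider = JsonIndexDatasetMapProviderV2(
    category="teddybear",          # placeholder CO3D category
    subset_name="test",            # placeholder subset name
    dataset_root="/path/to/co3d",  # placeholder dataset root
    n_known_frames_for_test=2,     # append 2 known (train) views to every eval batch
)
dataset_map = dataset_provider.get_dataset_map()

# Each eval batch now holds the original eval frames plus 2 known source views.
test_loader = torch.utils.data.DataLoader(
    dataset_map.test,
    batch_sampler=dataset_map["test"].get_eval_batches(),
    collate_fn=FrameData.collate,
)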
					
@@ -61,6 +61,7 @@ data_source_ImplicitronDataSource_args:
     test_on_train: false
     only_test_set: false
     load_eval_batches: true
+    n_known_frames_for_test: 0
     dataset_class_type: JsonIndexDataset
     path_manager_factory_class_type: PathManagerFactory
     dataset_JsonIndexDataset_args:
@@ -5,11 +5,15 @@
 # LICENSE file in the root directory of this source tree.


+import copy
 import json
 import logging
 import os
 import warnings
-from typing import Dict, List, Optional, Tuple, Type
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Type, Union
+
+import numpy as np

 from omegaconf import DictConfig
 from pytorch3d.implicitron.dataset.dataset_map_provider import (
@@ -152,6 +156,9 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
         only_test_set: Load only the test set. Incompatible with `test_on_train`.
         load_eval_batches: Load the file containing eval batches pointing to the
             test dataset.
+        n_known_frames_for_test: Add a certain number of known frames to each
+            eval batch. Useful for evaluating models that require
+            source views as input (e.g. NeRF-WCE / PixelNeRF).
         dataset_args: Specifies additional arguments to the
             JsonIndexDataset constructor call.
         path_manager_factory: (Optional) An object that generates an instance of
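To make the new docstring entry concrete, here is a toy illustration of how one eval batch changes; the sequence name and frame numbers are invented, and real entries are the (sequence_name, frame_number) or (sequence_name, frame_number, image_path) tuples handled by the provider.

# an eval batch for one sequence before extension
eval_batch = [("teddybear_001", 0), ("teddybear_001", 7), ("teddybear_001", 13)]

# with n_known_frames_for_test=2, two train-set frames of the same sequence
# are appended, so the collated test batch grows from 3 to 5 frames
known_train_frames = [("teddybear_001", 42), ("teddybear_001", 57)]
extended_eval_batch = [*eval_batch, *known_train_frames]
assert len(extended_eval_batch) == len(eval_batch) + 2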
@@ -167,6 +174,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
     only_test_set: bool = False
     load_eval_batches: bool = True

+    n_known_frames_for_test: int = 0
+
     dataset_class_type: str = "JsonIndexDataset"
     dataset: JsonIndexDataset

@@ -264,6 +273,18 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
             val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
             logger.info(f"Val dataset: {str(val_dataset)}")
             logger.debug("Extracting test dataset.")
+
+            if (self.n_known_frames_for_test > 0) and self.load_eval_batches:
+                # extend the test subset mapping and the dataset with additional
+                # known views from the train dataset
+                (
+                    eval_batch_index,
+                    subset_mapping["test"],
+                ) = self._extend_test_data_with_known_views(
+                    subset_mapping,
+                    eval_batch_index,
+                )
+
             test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
             logger.info(f"Test dataset: {str(test_dataset)}")
             if self.load_eval_batches:
@@ -369,6 +390,40 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
             dataset_root = self.dataset_root
         return get_available_subset_names(dataset_root, self.category)

+    def _extend_test_data_with_known_views(
+        self,
+        subset_mapping: Dict[str, List[Union[Tuple[str, int], Tuple[str, int, str]]]],
+        eval_batch_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
+    ):
+        # convert the train subset mapping to a dict:
+        #   sequence_to_train_frames: {sequence_name: frame_index}
+        sequence_to_train_frames = defaultdict(list)
+        for frame_entry in subset_mapping["train"]:
+            sequence_name = frame_entry[0]
+            sequence_to_train_frames[sequence_name].append(frame_entry)
+        sequence_to_train_frames = dict(sequence_to_train_frames)
+        test_subset_mapping_set = {tuple(s) for s in subset_mapping["test"]}
+
+        # extend the eval batches / subset mapping with the additional examples
+        eval_batch_index_out = copy.deepcopy(eval_batch_index)
+        generator = np.random.default_rng(seed=0)
+        for batch in eval_batch_index_out:
+            sequence_name = batch[0][0]
+            sequence_known_entries = sequence_to_train_frames[sequence_name]
+            idx_to_add = generator.permutation(len(sequence_known_entries))[
+                : self.n_known_frames_for_test
+            ]
+            entries_to_add = [sequence_known_entries[a] for a in idx_to_add]
+            assert all(e in subset_mapping["train"] for e in entries_to_add)
+
+            # extend the eval batch with the known views
+            batch.extend(entries_to_add)
+
+            # also add these new entries to the test subset mapping
+            test_subset_mapping_set.update(tuple(e) for e in entries_to_add)
+
+        return eval_batch_index_out, list(test_subset_mapping_set)
+

 def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
     """
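As a sanity check of the method above, the following self-contained sketch replays its core steps on invented data (the sequence names and frame numbers are made up); it reuses the seeded default_rng so the appended known views are deterministic, and verifies that every eval batch grows by exactly n_known_frames_for_test entries.

import copy
from collections import defaultdict

import numpy as np

# toy subset mapping and eval batch index; entries are (sequence_name, frame_number)
subset_mapping = {
    "train": [("seq_a", i) for i in range(10)] + [("seq_b", i) for i in range(10)],
    "test": [("seq_a", 100), ("seq_a", 101), ("seq_b", 100)],
}
eval_batch_index = [[("seq_a", 100), ("seq_a", 101)], [("seq_b", 100)]]
n_known_frames_for_test = 2

# same core steps as _extend_test_data_with_known_views above
sequence_to_train_frames = defaultdict(list)
for frame_entry in subset_mapping["train"]:
    sequence_to_train_frames[frame_entry[0]].append(frame_entry)

test_subset_mapping_set = {tuple(s) for s in subset_mapping["test"]}
eval_batch_index_out = copy.deepcopy(eval_batch_index)
generator = np.random.default_rng(seed=0)
for batch in eval_batch_index_out:
    known_entries = sequence_to_train_frames[batch[0][0]]
    idx_to_add = generator.permutation(len(known_entries))[:n_known_frames_for_test]
    entries_to_add = [known_entries[a] for a in idx_to_add]
    batch.extend(entries_to_add)  # known views are appended after the eval views
    test_subset_mapping_set.update(tuple(e) for e in entries_to_add)

# every eval batch grew by exactly n_known_frames_for_test known views
assert all(
    len(out) == len(orig) + n_known_frames_for_test
    for orig, out in zip(eval_batch_index, eval_batch_index_out)
)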
@@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
   test_on_train: false
   only_test_set: false
   load_eval_batches: true
+  n_known_frames_for_test: 0
   dataset_class_type: JsonIndexDataset
   path_manager_factory_class_type: PathManagerFactory
   dataset_JsonIndexDataset_args:
@@ -37,24 +37,42 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
         expand_args_fields(JsonIndexDatasetMapProviderV2)
         categories = ["A", "B"]
         subset_name = "test"
+        eval_batch_size = 5
         with tempfile.TemporaryDirectory() as tmpd:
-            _make_random_json_dataset_map_provider_v2_data(tmpd, categories)
+            _make_random_json_dataset_map_provider_v2_data(
+                tmpd,
+                categories,
+                eval_batch_size=eval_batch_size,
+            )
+            for n_known_frames_for_test in [0, 2]:
                 for category in categories:
                     dataset_provider = JsonIndexDatasetMapProviderV2(
                         category=category,
                         subset_name="test",
                         dataset_root=tmpd,
+                        n_known_frames_for_test=n_known_frames_for_test,
                     )
                     dataset_map = dataset_provider.get_dataset_map()
                     for set_ in ["train", "val", "test"]:
+                        if set_ in ["train", "val"]:
                             dataloader = torch.utils.data.DataLoader(
                                 getattr(dataset_map, set_),
                                 batch_size=3,
                                 shuffle=True,
                                 collate_fn=FrameData.collate,
                             )
-                    for _ in dataloader:
-                        pass
+                        else:
+                            dataloader = torch.utils.data.DataLoader(
+                                getattr(dataset_map, set_),
+                                batch_sampler=dataset_map[set_].get_eval_batches(),
+                                collate_fn=FrameData.collate,
+                            )
+                        for batch in dataloader:
+                            if set_ == "test":
+                                self.assertTrue(
+                                    batch.image_rgb.shape[0]
+                                    == n_known_frames_for_test + eval_batch_size
+                                )
                     category_to_subset_list = (
                         dataset_provider.get_category_to_subset_name_list()
                     )
@@ -70,6 +88,7 @@ def _make_random_json_dataset_map_provider_v2_data(
     H: int = 50,
     W: int = 30,
     subset_name: str = "test",
+    eval_batch_size: int = 5,
 ):
     os.makedirs(root, exist_ok=True)
     category_to_subset_list = {}
@@ -142,7 +161,10 @@ def _make_random_json_dataset_map_provider_v2_data(
         with open(set_list_file, "w") as f:
             json.dump(set_list, f)

-        eval_batches = [random.sample(test_frame_index, 5) for _ in range(10)]
+        eval_batches = [
+            random.sample(test_frame_index, eval_batch_size) for _ in range(10)
+        ]
+
         eval_b_dir = os.path.join(root, category, "eval_batches")
         os.makedirs(eval_b_dir, exist_ok=True)
         eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")