Fixes to CO3Dv2 provider.

Summary:
1. Random sampling of num batches without replacement not supported.
2.Providers should implement the interface for the training loop to work.

Reviewed By: bottler, davnov134

Differential Revision: D37815388

fbshipit-source-id: 8a2795b524e733f07346ffdb20a9c0eb1a2b8190
This commit is contained in:
Roman Shapovalov 2022-07-13 09:45:29 -07:00 committed by Facebook GitHub Bot
parent b95ec190af
commit 36ba079bef
4 changed files with 63 additions and 2 deletions

View File

@ -335,6 +335,47 @@ data_source_args:
sort_frames: false
path_manager_factory_PathManagerFactory_args:
silence_logs: true
dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
category: ???
subset_name: ???
dataset_root: ''
test_on_train: false
only_test_set: false
load_eval_batches: true
dataset_class_type: JsonIndexDataset
path_manager_factory_class_type: PathManagerFactory
dataset_JsonIndexDataset_args:
path_manager: null
frame_annotations_file: ''
sequence_annotations_file: ''
subset_lists_file: ''
subsets: null
limit_to: 0
limit_sequences_to: 0
pick_sequence: []
exclude_sequence: []
limit_category_to: []
dataset_root: ''
load_images: true
load_depths: true
load_depth_masks: true
load_masks: true
load_point_clouds: false
max_points: 0
mask_images: false
mask_depths: false
image_height: 800
image_width: 800
box_crop: true
box_crop_mask_thr: 0.4
box_crop_context: 0.3
remove_empty_masks: true
n_frames_per_sequence: -1
seed: 0
sort_frames: false
eval_batches: null
path_manager_factory_PathManagerFactory_args:
silence_logs: true
dataset_map_provider_LlffDatasetMapProvider_args:
base_dir: ???
object_name: ???

View File

@ -354,9 +354,13 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
"""
if num_batches > 0:
num_samples = self.batch_size * num_batches
replacement = True
else:
num_samples = None
sampler = RandomSampler(dataset, replacement=False, num_samples=num_samples)
replacement = False
sampler = RandomSampler(
dataset, replacement=replacement, num_samples=num_samples
)
batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
return DataLoader(
dataset,

View File

@ -17,6 +17,7 @@ from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa

View File

@ -9,12 +9,13 @@ import json
import logging
import os
import warnings
from typing import Dict, List, Type
from typing import Dict, List, Optional, Type
from pytorch3d.implicitron.dataset.dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import (
@ -23,6 +24,8 @@ from pytorch3d.implicitron.tools.config import (
run_auto_creation,
)
from pytorch3d.renderer.cameras import CamerasBase
_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
@ -296,6 +299,18 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
)
return category_to_subset_name_list
def get_task(self) -> Task: # TODO: we plan to get rid of tasks
return {
"manyview": Task.SINGLE_SEQUENCE,
"fewview": Task.MULTI_SEQUENCE,
}[self.subset_name.split("_")[0]]
def get_all_train_cameras(self) -> Optional[CamerasBase]:
# pyre-ignore[16]
train_dataset = self.dataset_map.train
assert isinstance(train_dataset, JsonIndexDataset)
return train_dataset.get_all_train_cameras()
def _load_annotation_json(self, json_filename: str):
full_path = os.path.join(
self.dataset_root,