Fixes to CO3Dv2 provider.

Summary:
1. Random sampling of num batches without replacement not supported.
2.Providers should implement the interface for the training loop to work.

Reviewed By: bottler, davnov134

Differential Revision: D37815388

fbshipit-source-id: 8a2795b524e733f07346ffdb20a9c0eb1a2b8190
This commit is contained in:
Roman Shapovalov
2022-07-13 09:45:29 -07:00
committed by Facebook GitHub Bot
parent b95ec190af
commit 36ba079bef
4 changed files with 63 additions and 2 deletions

View File

@@ -354,9 +354,13 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
"""
if num_batches > 0:
num_samples = self.batch_size * num_batches
replacement = True
else:
num_samples = None
sampler = RandomSampler(dataset, replacement=False, num_samples=num_samples)
replacement = False
sampler = RandomSampler(
dataset, replacement=replacement, num_samples=num_samples
)
batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
return DataLoader(
dataset,

View File

@@ -17,6 +17,7 @@ from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa

View File

@@ -9,12 +9,13 @@ import json
import logging
import os
import warnings
from typing import Dict, List, Type
from typing import Dict, List, Optional, Type
from pytorch3d.implicitron.dataset.dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import (
@@ -23,6 +24,8 @@ from pytorch3d.implicitron.tools.config import (
run_auto_creation,
)
from pytorch3d.renderer.cameras import CamerasBase
_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
@@ -296,6 +299,18 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
)
return category_to_subset_name_list
def get_task(self) -> Task: # TODO: we plan to get rid of tasks
return {
"manyview": Task.SINGLE_SEQUENCE,
"fewview": Task.MULTI_SEQUENCE,
}[self.subset_name.split("_")[0]]
def get_all_train_cameras(self) -> Optional[CamerasBase]:
# pyre-ignore[16]
train_dataset = self.dataset_map.train
assert isinstance(train_dataset, JsonIndexDataset)
return train_dataset.get_all_train_cameras()
def _load_annotation_json(self, json_filename: str):
full_path = os.path.join(
self.dataset_root,