mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Fixes to CO3Dv2 provider.
Summary: 1. Random sampling of num batches without replacement not supported. 2.Providers should implement the interface for the training loop to work. Reviewed By: bottler, davnov134 Differential Revision: D37815388 fbshipit-source-id: 8a2795b524e733f07346ffdb20a9c0eb1a2b8190
This commit is contained in:
parent
b95ec190af
commit
36ba079bef
@ -335,6 +335,47 @@ data_source_args:
|
||||
sort_frames: false
|
||||
path_manager_factory_PathManagerFactory_args:
|
||||
silence_logs: true
|
||||
dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
|
||||
category: ???
|
||||
subset_name: ???
|
||||
dataset_root: ''
|
||||
test_on_train: false
|
||||
only_test_set: false
|
||||
load_eval_batches: true
|
||||
dataset_class_type: JsonIndexDataset
|
||||
path_manager_factory_class_type: PathManagerFactory
|
||||
dataset_JsonIndexDataset_args:
|
||||
path_manager: null
|
||||
frame_annotations_file: ''
|
||||
sequence_annotations_file: ''
|
||||
subset_lists_file: ''
|
||||
subsets: null
|
||||
limit_to: 0
|
||||
limit_sequences_to: 0
|
||||
pick_sequence: []
|
||||
exclude_sequence: []
|
||||
limit_category_to: []
|
||||
dataset_root: ''
|
||||
load_images: true
|
||||
load_depths: true
|
||||
load_depth_masks: true
|
||||
load_masks: true
|
||||
load_point_clouds: false
|
||||
max_points: 0
|
||||
mask_images: false
|
||||
mask_depths: false
|
||||
image_height: 800
|
||||
image_width: 800
|
||||
box_crop: true
|
||||
box_crop_mask_thr: 0.4
|
||||
box_crop_context: 0.3
|
||||
remove_empty_masks: true
|
||||
n_frames_per_sequence: -1
|
||||
seed: 0
|
||||
sort_frames: false
|
||||
eval_batches: null
|
||||
path_manager_factory_PathManagerFactory_args:
|
||||
silence_logs: true
|
||||
dataset_map_provider_LlffDatasetMapProvider_args:
|
||||
base_dir: ???
|
||||
object_name: ???
|
||||
|
@ -354,9 +354,13 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
|
||||
"""
|
||||
if num_batches > 0:
|
||||
num_samples = self.batch_size * num_batches
|
||||
replacement = True
|
||||
else:
|
||||
num_samples = None
|
||||
sampler = RandomSampler(dataset, replacement=False, num_samples=num_samples)
|
||||
replacement = False
|
||||
sampler = RandomSampler(
|
||||
dataset, replacement=replacement, num_samples=num_samples
|
||||
)
|
||||
batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
|
@ -17,6 +17,7 @@ from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
|
||||
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
|
||||
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
|
||||
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
|
||||
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
|
||||
|
||||
|
||||
|
@ -9,12 +9,13 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Dict, List, Type
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||
DatasetMap,
|
||||
DatasetMapProviderBase,
|
||||
PathManagerFactory,
|
||||
Task,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
@ -23,6 +24,8 @@ from pytorch3d.implicitron.tools.config import (
|
||||
run_auto_creation,
|
||||
)
|
||||
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
|
||||
_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
|
||||
|
||||
@ -296,6 +299,18 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
)
|
||||
return category_to_subset_name_list
|
||||
|
||||
def get_task(self) -> Task: # TODO: we plan to get rid of tasks
|
||||
return {
|
||||
"manyview": Task.SINGLE_SEQUENCE,
|
||||
"fewview": Task.MULTI_SEQUENCE,
|
||||
}[self.subset_name.split("_")[0]]
|
||||
|
||||
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||
# pyre-ignore[16]
|
||||
train_dataset = self.dataset_map.train
|
||||
assert isinstance(train_dataset, JsonIndexDataset)
|
||||
return train_dataset.get_all_train_cameras()
|
||||
|
||||
def _load_annotation_json(self, json_filename: str):
|
||||
full_path = os.path.join(
|
||||
self.dataset_root,
|
||||
|
Loading…
x
Reference in New Issue
Block a user