mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Fixes to CO3Dv2 provider.
Summary: 1. Random sampling of num batches without replacement not supported. 2.Providers should implement the interface for the training loop to work. Reviewed By: bottler, davnov134 Differential Revision: D37815388 fbshipit-source-id: 8a2795b524e733f07346ffdb20a9c0eb1a2b8190
This commit is contained in:
parent
b95ec190af
commit
36ba079bef
@ -335,6 +335,47 @@ data_source_args:
|
|||||||
sort_frames: false
|
sort_frames: false
|
||||||
path_manager_factory_PathManagerFactory_args:
|
path_manager_factory_PathManagerFactory_args:
|
||||||
silence_logs: true
|
silence_logs: true
|
||||||
|
dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
|
||||||
|
category: ???
|
||||||
|
subset_name: ???
|
||||||
|
dataset_root: ''
|
||||||
|
test_on_train: false
|
||||||
|
only_test_set: false
|
||||||
|
load_eval_batches: true
|
||||||
|
dataset_class_type: JsonIndexDataset
|
||||||
|
path_manager_factory_class_type: PathManagerFactory
|
||||||
|
dataset_JsonIndexDataset_args:
|
||||||
|
path_manager: null
|
||||||
|
frame_annotations_file: ''
|
||||||
|
sequence_annotations_file: ''
|
||||||
|
subset_lists_file: ''
|
||||||
|
subsets: null
|
||||||
|
limit_to: 0
|
||||||
|
limit_sequences_to: 0
|
||||||
|
pick_sequence: []
|
||||||
|
exclude_sequence: []
|
||||||
|
limit_category_to: []
|
||||||
|
dataset_root: ''
|
||||||
|
load_images: true
|
||||||
|
load_depths: true
|
||||||
|
load_depth_masks: true
|
||||||
|
load_masks: true
|
||||||
|
load_point_clouds: false
|
||||||
|
max_points: 0
|
||||||
|
mask_images: false
|
||||||
|
mask_depths: false
|
||||||
|
image_height: 800
|
||||||
|
image_width: 800
|
||||||
|
box_crop: true
|
||||||
|
box_crop_mask_thr: 0.4
|
||||||
|
box_crop_context: 0.3
|
||||||
|
remove_empty_masks: true
|
||||||
|
n_frames_per_sequence: -1
|
||||||
|
seed: 0
|
||||||
|
sort_frames: false
|
||||||
|
eval_batches: null
|
||||||
|
path_manager_factory_PathManagerFactory_args:
|
||||||
|
silence_logs: true
|
||||||
dataset_map_provider_LlffDatasetMapProvider_args:
|
dataset_map_provider_LlffDatasetMapProvider_args:
|
||||||
base_dir: ???
|
base_dir: ???
|
||||||
object_name: ???
|
object_name: ???
|
||||||
|
@ -354,9 +354,13 @@ class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
|
|||||||
"""
|
"""
|
||||||
if num_batches > 0:
|
if num_batches > 0:
|
||||||
num_samples = self.batch_size * num_batches
|
num_samples = self.batch_size * num_batches
|
||||||
|
replacement = True
|
||||||
else:
|
else:
|
||||||
num_samples = None
|
num_samples = None
|
||||||
sampler = RandomSampler(dataset, replacement=False, num_samples=num_samples)
|
replacement = False
|
||||||
|
sampler = RandomSampler(
|
||||||
|
dataset, replacement=replacement, num_samples=num_samples
|
||||||
|
)
|
||||||
batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
|
batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
@ -17,6 +17,7 @@ from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
|
|||||||
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
|
||||||
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
|
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
|
||||||
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
|
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
|
||||||
|
from .json_index_dataset_map_provider_v2 import JsonIndexDatasetMapProviderV2 # noqa
|
||||||
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
|
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,12 +9,13 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, List, Type
|
from typing import Dict, List, Optional, Type
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||||
DatasetMap,
|
DatasetMap,
|
||||||
DatasetMapProviderBase,
|
DatasetMapProviderBase,
|
||||||
PathManagerFactory,
|
PathManagerFactory,
|
||||||
|
Task,
|
||||||
)
|
)
|
||||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||||
from pytorch3d.implicitron.tools.config import (
|
from pytorch3d.implicitron.tools.config import (
|
||||||
@ -23,6 +24,8 @@ from pytorch3d.implicitron.tools.config import (
|
|||||||
run_auto_creation,
|
run_auto_creation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
|
|
||||||
_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
|
_CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
|
||||||
|
|
||||||
@ -296,6 +299,18 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
)
|
)
|
||||||
return category_to_subset_name_list
|
return category_to_subset_name_list
|
||||||
|
|
||||||
|
def get_task(self) -> Task: # TODO: we plan to get rid of tasks
|
||||||
|
return {
|
||||||
|
"manyview": Task.SINGLE_SEQUENCE,
|
||||||
|
"fewview": Task.MULTI_SEQUENCE,
|
||||||
|
}[self.subset_name.split("_")[0]]
|
||||||
|
|
||||||
|
def get_all_train_cameras(self) -> Optional[CamerasBase]:
|
||||||
|
# pyre-ignore[16]
|
||||||
|
train_dataset = self.dataset_map.train
|
||||||
|
assert isinstance(train_dataset, JsonIndexDataset)
|
||||||
|
return train_dataset.get_all_train_cameras()
|
||||||
|
|
||||||
def _load_annotation_json(self, json_filename: str):
|
def _load_annotation_json(self, json_filename: str):
|
||||||
full_path = os.path.join(
|
full_path = os.path.join(
|
||||||
self.dataset_root,
|
self.dataset_root,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user