dataset_map_provider

Summary: replace dataset_zoo with a pluggable DatasetMapProvider. The logic is now in annotated_file_dataset_map_provider.

Reviewed By: shapovalov

Differential Revision: D36443965

fbshipit-source-id: 9087649802810055e150b2fbfcc3c197a761f28a
This commit is contained in:
Jeremy Reizenstein
2022-05-20 07:50:30 -07:00
committed by Facebook GitHub Bot
parent 69c6d06ed8
commit 79c61a2d86
15 changed files with 305 additions and 175 deletions

View File

@@ -66,7 +66,7 @@ If you have a custom `experiment.py` script (as in the Option 2 above), replace
To run training, pass a yaml config file, followed by a list of overridden arguments.
For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run:
```shell
dataset_args=data_source_args.dataset_args
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
```
@@ -85,7 +85,7 @@ To run evaluation on the latest checkpoint after (or during) training, simply ad
E.g. for executing the evaluation on the NeRF skateboard sequence, you can run:
```shell
dataset_args=data_source_args.dataset_args
dataset_args=data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf $dataset_args.dataset_root=<CO3D_DATASET_ROOT> $dataset_args.category='skateboard' $dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
```
Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`.
@@ -236,7 +236,7 @@ generic_model_args: GenericModel
╘== ReductionFeatureAggregator
solver_args: init_optimizer
data_source_args: ImplicitronDataSource
└-- dataset_args
└-- dataset_map_provider_*_args
└-- dataloader_args
```

View File

@@ -6,6 +6,7 @@ architecture: generic
visualize_interval: 0
visdom_port: 8097
data_source_args:
dataset_provider_class_type: JsonIndexDatasetMapProvider
dataloader_args:
batch_size: 10
dataset_len: 1000
@@ -21,7 +22,7 @@ data_source_args:
- 8
- 9
- 10
dataset_args:
dataset_map_provider_JsonIndexDatasetMapProvider_args:
dataset_root: ${oc.env:CO3D_DATASET_ROOT}
load_point_clouds: false
mask_depths: false

View File

@@ -17,9 +17,9 @@ data_source_args:
- 8
- 9
- 10
dataset_args:
dataset_map_provider_JsonIndexDatasetMapProvider_args:
assert_single_seq: false
dataset_name: co3d_multisequence
task_str: multisequence
load_point_clouds: false
mask_depths: false
mask_images: false

View File

@@ -9,8 +9,8 @@ data_source_args:
num_workers: 8
images_per_seq_options:
- 2
dataset_args:
dataset_name: co3d_singlesequence
dataset_map_provider_JsonIndexDatasetMapProvider_args:
dataset_name: singlesequence
assert_single_seq: true
n_frames_per_sequence: -1
test_restrict_sequence_id: 0

View File

@@ -67,7 +67,7 @@ from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.dataset_zoo import Datasets
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
@@ -552,7 +552,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
def _eval_and_dump(
cfg,
task: Task,
datasets: Datasets,
datasets: DatasetMap,
dataloaders: Dataloaders,
model,
stats,

View File

@@ -24,8 +24,7 @@ import torch.nn.functional as Fu
from experiment import init_model
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import EvaluationMode
@@ -296,7 +295,7 @@ def export_scenes(
output_directory: Optional[str] = None,
render_size: Tuple[int, int] = (512, 512),
video_size: Optional[Tuple[int, int]] = None,
split: str = "train", # train | test
split: str = "train", # train | val | test
n_source_views: int = 9,
n_eval_cameras: int = 40,
visdom_server="http://127.0.0.1",
@@ -324,14 +323,15 @@ def export_scenes(
config.gpu_idx = gpu_idx
config.exp_dir = exp_dir
# important so that the CO3D dataset gets loaded in full
config.data_source_args.dataset_args.test_on_train = False
dataset_args = (
config.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
)
dataset_args.test_on_train = False
# Set the rendering image size
config.generic_model_args.render_image_width = render_size[0]
config.generic_model_args.render_image_height = render_size[1]
if restrict_sequence_name is not None:
config.data_source_args.dataset_args.restrict_sequence_name = (
restrict_sequence_name
)
dataset_args.restrict_sequence_name = restrict_sequence_name
# Set up the CUDA env for the visualization
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -344,8 +344,8 @@ def export_scenes(
# Setup the dataset
datasource = ImplicitronDataSource(**config.data_source_args)
datasets = dataset_zoo(**datasource.dataset_args)
dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None)
dataset_map = datasource.dataset_map_provider.get_dataset_map()
dataset = dataset_map[split]
if dataset is None:
raise ValueError(f"{split} dataset not provided")