data_source

Summary:
Move dataset_args and dataloader_args from ExperimentConfig into a new member called datasource so that it can contain replaceables.

Also add enum Task for task type.

Reviewed By: shapovalov

Differential Revision: D36201719

fbshipit-source-id: 47d6967bfea3b7b146b6bbd1572e0457c9365871
This commit is contained in:
Jeremy Reizenstein
2022-05-20 07:50:30 -07:00
committed by Facebook GitHub Bot
parent 9ec9d057cc
commit 73dc109dba
10 changed files with 194 additions and 124 deletions

View File

@@ -23,6 +23,7 @@ import torch
import torch.nn.functional as Fu
from experiment import init_model
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
@@ -326,12 +327,14 @@ def export_scenes(
config.gpu_idx = gpu_idx
config.exp_dir = exp_dir
# important so that the CO3D dataset gets loaded in full
config.dataset_args.test_on_train = False
config.data_source_args.dataset_args.test_on_train = False
# Set the rendering image size
config.generic_model_args.render_image_width = render_size[0]
config.generic_model_args.render_image_height = render_size[1]
if restrict_sequence_name is not None:
config.dataset_args.restrict_sequence_name = restrict_sequence_name
config.data_source_args.dataset_args.restrict_sequence_name = (
restrict_sequence_name
)
# Set up the CUDA env for the visualization
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
@@ -343,7 +346,8 @@ def export_scenes(
model.eval()
# Setup the dataset
datasets = dataset_zoo(**config.dataset_args)
datasource = ImplicitronDataSource(**config.data_source_args)
datasets = dataset_zoo(**datasource.dataset_args)
dataset: Optional[ImplicitronDatasetBase] = getattr(datasets, split, None)
if dataset is None:
raise ValueError(f"{split} dataset not provided")