more padding options in Dataloader

Summary: Add facilities for dataloading non-sequential scenes.

Reviewed By: shapovalov

Differential Revision: D37291277

fbshipit-source-id: 0a33e3727b44c4f0cba3a2abe9b12f40d2a20447
This commit is contained in:
Jeremy Reizenstein
2022-07-06 07:13:41 -07:00
committed by Facebook GitHub Bot
parent 0dce883241
commit 771cf8a328
16 changed files with 449 additions and 70 deletions

View File

@@ -10,11 +10,12 @@ import unittest
from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
BlenderDatasetMapProvider,
)
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
LlffDatasetMapProvider,
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
from tests.common_testing import TestCaseMixin
@@ -102,3 +103,23 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen")
for i in batch[1:]:
self.assertEqual(dataset_map.test[i].frame_type, "known")
def test_loaders(self):
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "BlenderDatasetMapProvider"
args.data_loader_map_provider_class_type = "RandomDataLoaderMapProvider"
dataset_args = args.dataset_map_provider_BlenderDatasetMapProvider_args
dataset_args.object_name = "lego"
dataset_args.base_dir = "manifold://co3d/tree/nerf_data/nerf_synthetic/lego"
data_source = ImplicitronDataSource(**args)
_, data_loaders = data_source.get_datasets_and_dataloaders()
for i in data_loaders.train:
self.assertEqual(i.frame_type, ["known"])
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
for i in data_loaders.val:
self.assertEqual(i.frame_type, ["unseen"])
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
for i in data_loaders.test:
self.assertEqual(i.frame_type, ["unseen"])
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))