mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 23:30:35 +08:00
more padding options in Dataloader
Summary: Add facilities for dataloading non-sequential scenes. Reviewed By: shapovalov Differential Revision: D37291277 fbshipit-source-id: 0a33e3727b44c4f0cba3a2abe9b12f40d2a20447
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0dce883241
commit
771cf8a328
@@ -10,11 +10,12 @@ import unittest
|
||||
from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
|
||||
BlenderDatasetMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.dataset_base import FrameData
|
||||
from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
|
||||
LlffDatasetMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
|
||||
@@ -102,3 +103,23 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen")
|
||||
for i in batch[1:]:
|
||||
self.assertEqual(dataset_map.test[i].frame_type, "known")
|
||||
|
||||
def test_loaders(self):
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "BlenderDatasetMapProvider"
|
||||
args.data_loader_map_provider_class_type = "RandomDataLoaderMapProvider"
|
||||
dataset_args = args.dataset_map_provider_BlenderDatasetMapProvider_args
|
||||
dataset_args.object_name = "lego"
|
||||
dataset_args.base_dir = "manifold://co3d/tree/nerf_data/nerf_synthetic/lego"
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||
for i in data_loaders.train:
|
||||
self.assertEqual(i.frame_type, ["known"])
|
||||
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
|
||||
for i in data_loaders.val:
|
||||
self.assertEqual(i.frame_type, ["unseen"])
|
||||
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
|
||||
for i in data_loaders.test:
|
||||
self.assertEqual(i.frame_type, ["unseen"])
|
||||
self.assertEqual(i.image_rgb.shape, (1, 3, 800, 800))
|
||||
|
||||
Reference in New Issue
Block a user