mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 06:40:35 +08:00
more padding options in Dataloader
Summary: Add facilities for dataloading non-sequential scenes. Reviewed By: shapovalov Differential Revision: D37291277 fbshipit-source-id: 0a33e3727b44c4f0cba3a2abe9b12f40d2a20447
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0dce883241
commit
771cf8a328
@@ -8,6 +8,11 @@
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from itertools import product
|
||||
|
||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
||||
DoublePoolBatchSampler,
|
||||
)
|
||||
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
|
||||
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
||||
@@ -214,3 +219,30 @@ def _count_by_quotient(indices, divisor):
|
||||
counter[i // divisor] += 1
|
||||
|
||||
return counter
|
||||
|
||||
|
||||
class TestRandomSampling(unittest.TestCase):
|
||||
def test_double_pool_batch_sampler(self):
|
||||
unknown_idxs = [2, 3, 4, 5, 8]
|
||||
known_idxs = [2, 9, 10, 11, 12, 13, 14, 15, 16, 17]
|
||||
for replacement, num_batches in product([True, False], [None, 4, 5, 6, 30]):
|
||||
with self.subTest(f"{replacement}, {num_batches}"):
|
||||
sampler = DoublePoolBatchSampler(
|
||||
first_indices=unknown_idxs,
|
||||
rest_indices=known_idxs,
|
||||
batch_size=4,
|
||||
replacement=replacement,
|
||||
num_batches=num_batches,
|
||||
)
|
||||
for _ in range(6):
|
||||
epoch = list(sampler)
|
||||
self.assertEqual(len(epoch), num_batches or len(unknown_idxs))
|
||||
for batch in epoch:
|
||||
self.assertEqual(len(batch), 4)
|
||||
self.assertIn(batch[0], unknown_idxs)
|
||||
for i in batch[1:]:
|
||||
self.assertIn(i, known_idxs)
|
||||
if not replacement and 4 != num_batches:
|
||||
self.assertEqual(
|
||||
{batch[0] for batch in epoch}, set(unknown_idxs)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user