# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest
from collections import defaultdict
from dataclasses import dataclass

from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDatasetBase
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler


@dataclass
class MockFrameAnnotation:
    """Bare-bones frame annotation carrying only a frame number and timestamp."""

    frame_number: int
    frame_timestamp: float = 0.0


class MockDataset(ImplicitronDatasetBase):
    """
    In-memory mock dataset with `num_seq` sequences of 10 frames each.

    Sequence `i` is named "seq_<i>" and owns the dataset indices
    [10 * i, 10 * i + 10), so `idx // 10` recovers the sequence of an index.
    """

    def __init__(self, num_seq, max_frame_gap=1):
        """
        Makes a gap of max_frame_gap frame numbers in the middle of each sequence
        """
        self.seq_annots = {f"seq_{i}": None for i in range(num_seq)}
        self.seq_to_idx = {
            f"seq_{i}": list(range(i * 10, i * 10 + 10)) for i in range(num_seq)
        }

        # frame numbers within sequence: [0, ..., 4, n, ..., n+4]
        # where n - 4 == max_frame_gap
        frame_nos = list(range(5)) + list(range(4 + max_frame_gap, 9 + max_frame_gap))
        # the same frame-number pattern is repeated for every sequence
        self.frame_annots = [
            {"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
        ]

    def get_frame_numbers_and_timestamps(self, idxs):
        """
        Return a list of (frame_number, frame_timestamp) tuples, one per
        dataset index in `idxs`, in the same order.
        """
        annotations = (self.frame_annots[idx]["frame_annotation"] for idx in idxs)
        return [(annot.frame_number, annot.frame_timestamp) for annot in annotations]


class TestSceneBatchSampler(unittest.TestCase):
    """Tests of SceneBatchSampler run against MockDataset."""

    def setUp(self):
        # a single-sequence dataset: every batch has to come from "seq_0"
        self.dataset_overfit = MockDataset(1)

    def test_overfit(self):
        """Sampling from a one-sequence dataset yields full single-sequence batches."""
        num_batches = 3
        batch_size = 10
        sampler = SceneBatchSampler(
            self.dataset_overfit,
            batch_size=batch_size,
            num_batches=num_batches,
            images_per_seq_options=[10],  # will try to sample batch_size anyway
        )

        self.assertEqual(len(sampler), num_batches)

        it = iter(sampler)
        for _ in range(num_batches):
            batch = next(it)
            self.assertIsNotNone(batch)
            self.assertEqual(len(batch), batch_size)  # true for our examples
            # all indices belong to the only sequence (indices 0..9)
            self.assertTrue(all(idx // 10 == 0 for idx in batch))

        # the sampler must be exhausted after exactly num_batches batches
        with self.assertRaises(StopIteration):
            next(it)

    def test_multiseq(self):
        """Run the multi-sequence check over a grid of sampler configurations."""
        for ips_options in [[10], [2], [3], [2, 3, 4]]:
            for sample_consecutive_frames in [True, False]:
                for consecutive_frames_max_gap in [0, 1, 3]:
                    # subTest lets runners that support it report each failing
                    # combination separately instead of stopping at the first one
                    with self.subTest(
                        ips_options=ips_options,
                        sample_consecutive_frames=sample_consecutive_frames,
                        consecutive_frames_max_gap=consecutive_frames_max_gap,
                    ):
                        self._test_multiseq_flavour(
                            ips_options,
                            sample_consecutive_frames,
                            consecutive_frames_max_gap,
                        )

    def test_multiseq_gaps(self):
        """
        Consecutive-frame sampling when the gap inside each mock sequence (3)
        exceeds the allowed `consecutive_frames_max_gap` (1).
        """
        num_batches = 16
        batch_size = 10
        dataset_multiseq = MockDataset(5, max_frame_gap=3)
        for ips_options in [[10], [2], [3], [2, 3, 4]]:
            with self.subTest(ips_options=ips_options):
                debug_info = f" Images per sequence: {ips_options}."

                sampler = SceneBatchSampler(
                    dataset_multiseq,
                    batch_size=batch_size,
                    num_batches=num_batches,
                    images_per_seq_options=ips_options,
                    sample_consecutive_frames=True,
                    consecutive_frames_max_gap=1,
                )

                self.assertEqual(len(sampler), num_batches, msg=debug_info)

                # true for our examples: each half of a mock sequence holds only
                # 5 gap-free frames, so asking for more than 5 images per
                # sequence is expected to give a 5-element batch
                expected_len = 5 if max(ips_options) > 5 else batch_size

                it = iter(sampler)
                for _ in range(num_batches):
                    batch = next(it)
                    self.assertIsNotNone(batch, "batch is None in" + debug_info)
                    self.assertEqual(len(batch), expected_len, msg=debug_info)

                    self._check_frames_are_consecutive(
                        batch, dataset_multiseq.frame_annots, debug_info
                    )

    def _test_multiseq_flavour(
        self,
        ips_options,
        sample_consecutive_frames,
        consecutive_frames_max_gap,
        num_batches=16,
        batch_size=10,
    ):
        """
        Check batches drawn from a 5-sequence MockDataset for one configuration.

        Args:
            ips_options: passed to the sampler as `images_per_seq_options`.
            sample_consecutive_frames: passed through to the sampler; when
                true, frames within a sequence are also checked for gaps.
            consecutive_frames_max_gap: passed through to the sampler and used
                as the largest allowed frame-number gap.
            num_batches: number of batches the sampler must yield.
            batch_size: required length of every batch.
        """
        debug_info = (
            f" Images per sequence: {ips_options}, "
            f"sample_consecutive_frames: {sample_consecutive_frames}, "
            f"consecutive_frames_max_gap: {consecutive_frames_max_gap}, "
        )
        # in this test, either consecutive_frames_max_gap == max_frame_gap,
        # or consecutive_frames_max_gap == 0, so segments consist of full sequences
        frame_gap = consecutive_frames_max_gap if consecutive_frames_max_gap > 0 else 3
        dataset_multiseq = MockDataset(5, max_frame_gap=frame_gap)
        sampler = SceneBatchSampler(
            dataset_multiseq,
            batch_size=batch_size,
            num_batches=num_batches,
            images_per_seq_options=ips_options,
            sample_consecutive_frames=sample_consecutive_frames,
            consecutive_frames_max_gap=consecutive_frames_max_gap,
        )

        self.assertEqual(len(sampler), num_batches, msg=debug_info)

        it = iter(sampler)
        # images-per-sequence values actually observed across all batches
        typical_counts = set()
        for _ in range(num_batches):
            batch = next(it)
            self.assertIsNotNone(batch, "batch is None in" + debug_info)
            # true for our examples
            self.assertEqual(len(batch), batch_size, msg=debug_info)
            # find distribution over sequences
            counts = _count_by_quotient(batch, 10)  # sequence -> #frames in batch
            freqs = _count_by_quotient(counts.values(), 1)  # #frames -> #sequences
            self.assertLessEqual(
                len(freqs),
                2,
                msg="We should have maximum of 2 different "
                "frequences of sequences in the batch." + debug_info,
            )
            if len(freqs) == 2:
                # the batch was truncated: one sequence contributes fewer frames
                most_seq_count = max(freqs)
                last_seq = min(freqs)
                self.assertEqual(
                    freqs[last_seq],
                    1,
                    msg="Only one odd sequence allowed." + debug_info,
                )
            else:
                self.assertEqual(len(freqs), 1, msg=debug_info)
                most_seq_count = next(iter(freqs))

            self.assertIn(most_seq_count, ips_options, msg=debug_info)
            typical_counts.add(most_seq_count)

            if sample_consecutive_frames:
                self._check_frames_are_consecutive(
                    batch,
                    dataset_multiseq.frame_annots,
                    debug_info,
                    max_gap=consecutive_frames_max_gap,
                )

        # NOTE: this is a statistical check; it relies on every option being
        # drawn at least once within num_batches batches.
        self.assertTrue(
            all(i in typical_counts for i in ips_options),
            "Some of the frequency options did not occur among "
            f"the {num_batches} batches (could be just bad luck)." + debug_info,
        )

        # the sampler must be exhausted after exactly num_batches batches
        with self.assertRaises(StopIteration):
            next(it)

    def _check_frames_are_consecutive(self, batch, annots, debug_info, max_gap=1):
        """
        Assert that neighbouring batch entries of the same sequence are close.

        For each adjacent pair of dataset indices from the same sequence
        (same `idx // 10`), the difference `next - current` must not exceed
        the allowed gap. With `max_gap > 0` the difference is taken between
        the frame numbers stored in `annots`; otherwise between the raw
        dataset indices with an allowed gap of 1. Only the upper bound of the
        difference is asserted.

        Args:
            batch: sequence of dataset indices produced by the sampler.
            annots: per-index dicts holding a "frame_annotation" entry.
            debug_info: message attached to a failing assertion.
            max_gap: largest allowed frame-number difference.
        """
        # make sure that sampled frames are consecutive
        for curr_idx, next_idx in zip(batch, batch[1:]):
            if curr_idx // 10 != next_idx // 10:
                continue  # different sequences: nothing to compare

            if max_gap > 0:
                curr_pos = annots[curr_idx]["frame_annotation"].frame_number
                next_pos = annots[next_idx]["frame_annotation"].frame_number
                gap = max_gap
            else:
                # we'll check that raw dataset indices are consecutive
                curr_pos, next_pos = curr_idx, next_idx
                gap = 1

            self.assertLessEqual(next_pos - curr_pos, gap, msg=debug_info)


def _count_by_quotient(indices, divisor):
    """
    Count how many of `indices` fall into each bucket `index // divisor`.

    Returns a defaultdict(int) mapping the integer quotient to its count.
    """
    counter = defaultdict(int)
    for i in indices:
        counter[i // divisor] += 1

    return counter