# pytorch3d/tests/implicitron/test_batch_sampler.py
# Jeremy Reizenstein cdd2142dd5
# implicitron v0 (#1133)
# Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
# 2022-03-21 13:20:10 -07:00
#
# 216 lines
# 7.7 KiB
# Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from collections import defaultdict
from dataclasses import dataclass
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@dataclass
class MockFrameAnnotation:
    """Tiny frame annotation used by MockDataset.

    Only the two fields that ``MockDataset.get_frame_numbers_and_timestamps``
    reports are modelled. Fields are positional in declaration order, so
    ``MockFrameAnnotation(3)`` sets ``frame_number`` and keeps the default
    ``frame_timestamp``.
    """

    # Frame number of this frame within its sequence.
    frame_number: int
    # Timestamp of the frame; the mock dataset never overrides the default.
    frame_timestamp: float = 0.0
class MockDataset:
    """Fake dataset made of ``num_seq`` sequences with 10 frames each.

    It exposes ``seq_annots``, ``seq_to_idx``, ``frame_annots`` and
    ``get_frame_numbers_and_timestamps``, which is what this test module
    hands to the sampler.
    """

    def __init__(self, num_seq, max_frame_gap=1):
        """Build the sequences, inserting a jump of ``max_frame_gap`` frame
        numbers halfway through every sequence.
        """
        seq_names = [f"seq_{i}" for i in range(num_seq)]
        self.seq_annots = dict.fromkeys(seq_names)
        # Sequence i owns the contiguous global indices [10 * i, 10 * i + 10).
        self.seq_to_idx = {
            name: list(range(10 * i, 10 * (i + 1)))
            for i, name in enumerate(seq_names)
        }
        # Frame numbers inside one sequence are [0, ..., 4] followed by
        # [n, ..., n + 4] with n = 4 + max_frame_gap, so the jump between
        # the two halves equals max_frame_gap.
        first_half = list(range(5))
        second_half = [4 + max_frame_gap + k for k in range(5)]
        per_sequence = first_half + second_half
        # Every sequence reuses the same frame-number pattern, in order.
        self.frame_annots = [
            {"frame_annotation": MockFrameAnnotation(no)}
            for _ in range(num_seq)
            for no in per_sequence
        ]

    def get_frame_numbers_and_timestamps(self, idxs):
        """Return ``(frame_number, frame_timestamp)`` for every global index in ``idxs``."""
        annotations = (self.frame_annots[i]["frame_annotation"] for i in idxs)
        return [(ann.frame_number, ann.frame_timestamp) for ann in annotations]
class TestSceneBatchSampler(unittest.TestCase):
    """Tests for SceneBatchSampler driven by the MockDataset defined in this file."""

    def setUp(self):
        # A single 10-frame sequence, so every sampled index must be in 0..9.
        self.dataset_overfit = MockDataset(1)

    def test_overfit(self):
        num_batches = 3
        batch_size = 10
        sampler = SceneBatchSampler(
            self.dataset_overfit,
            batch_size=batch_size,
            num_batches=num_batches,
            images_per_seq_options=[10],  # will try to sample batch_size anyway
        )
        self.assertEqual(len(sampler), num_batches)

        it = iter(sampler)
        for _ in range(num_batches):
            batch = next(it)
            self.assertIsNotNone(batch)
            self.assertEqual(len(batch), batch_size)  # true for our examples
            # every index belongs to sequence 0 (global indices 0..9)
            self.assertTrue(all(idx // 10 == 0 for idx in batch))
        # the iterator must stop after exactly num_batches batches
        with self.assertRaises(StopIteration):
            next(it)

    def test_multiseq(self):
        for ips_options in [[10], [2], [3], [2, 3, 4]]:
            for sample_consecutive_frames in [True, False]:
                for consecutive_frames_max_gap in [0, 1, 3]:
                    # subTest tags any failure with the offending combination
                    with self.subTest(
                        ips_options=ips_options,
                        sample_consecutive_frames=sample_consecutive_frames,
                        consecutive_frames_max_gap=consecutive_frames_max_gap,
                    ):
                        self._test_multiseq_flavour(
                            ips_options,
                            sample_consecutive_frames,
                            consecutive_frames_max_gap,
                        )

    def test_multiseq_gaps(self):
        num_batches = 16
        batch_size = 10
        # Frame numbers per sequence are 0..4 then 7..11. With
        # consecutive_frames_max_gap=1 the longest run of consecutive frames
        # is 5, which is why a batch is expected to hold only 5 indices when
        # more than 5 images per sequence are requested.
        dataset_multiseq = MockDataset(5, max_frame_gap=3)
        for ips_options in [[10], [2], [3], [2, 3, 4]]:
            debug_info = f" Images per sequence: {ips_options}."
            with self.subTest(ips_options=ips_options):
                sampler = SceneBatchSampler(
                    dataset_multiseq,
                    batch_size=batch_size,
                    num_batches=num_batches,
                    images_per_seq_options=ips_options,
                    sample_consecutive_frames=True,
                    consecutive_frames_max_gap=1,
                )
                self.assertEqual(len(sampler), num_batches, msg=debug_info)

                it = iter(sampler)
                for _ in range(num_batches):
                    batch = next(it)
                    self.assertIsNotNone(batch, "batch is None in" + debug_info)
                    # true for our examples
                    expected_len = 5 if max(ips_options) > 5 else batch_size
                    self.assertEqual(len(batch), expected_len, msg=debug_info)
                    self._check_frames_are_consecutive(
                        batch, dataset_multiseq.frame_annots, debug_info
                    )

    def _test_multiseq_flavour(
        self,
        ips_options,
        sample_consecutive_frames,
        consecutive_frames_max_gap,
        num_batches=16,
        batch_size=10,
    ):
        """Check one sampler configuration on a 5-sequence MockDataset.

        For each of ``num_batches`` batches, asserts that:
        - the batch has ``batch_size`` indices;
        - the per-sequence frame counts take at most two distinct values;
        - when there are two values, the smaller one occurs for exactly one
          sequence, and the larger (typical) count is in ``ips_options``.
        Across all batches, every value in ``ips_options`` must appear as a
        typical count. When ``sample_consecutive_frames`` is set, frames taken
        from the same sequence must also be consecutive.
        """
        debug_info = (
            f" Images per sequence: {ips_options}, "
            f"sample_consecutive_frames: {sample_consecutive_frames}, "
            f"consecutive_frames_max_gap: {consecutive_frames_max_gap}, "
        )
        # in this test, either consecutive_frames_max_gap == max_frame_gap,
        # or consecutive_frames_max_gap == 0, so segments consist of full sequences
        frame_gap = consecutive_frames_max_gap if consecutive_frames_max_gap > 0 else 3
        dataset_multiseq = MockDataset(5, max_frame_gap=frame_gap)
        sampler = SceneBatchSampler(
            dataset_multiseq,
            batch_size=batch_size,
            num_batches=num_batches,
            images_per_seq_options=ips_options,
            sample_consecutive_frames=sample_consecutive_frames,
            consecutive_frames_max_gap=consecutive_frames_max_gap,
        )
        self.assertEqual(len(sampler), num_batches, msg=debug_info)

        it = iter(sampler)
        typical_counts = set()
        for _ in range(num_batches):
            batch = next(it)
            self.assertIsNotNone(batch, "batch is None in" + debug_info)
            # true for our examples
            self.assertEqual(len(batch), batch_size, msg=debug_info)
            # frames drawn from each sequence (sequence id = idx // 10) ...
            counts = _count_by_quotient(batch, 10)
            # ... and how many sequences share each of those frame counts
            freqs = _count_by_quotient(counts.values(), 1)
            self.assertLessEqual(
                len(freqs),
                2,
                msg="We should have maximum of 2 different "
                "frequencies of sequences in the batch." + debug_info,
            )
            if len(freqs) == 2:
                # The larger count is the typical per-sequence count. The
                # smaller one is expected to come from exactly one sequence.
                most_seq_count = max(freqs)
                odd_seq_count = min(freqs)
                self.assertEqual(
                    freqs[odd_seq_count],
                    1,
                    msg="Only one odd sequence allowed." + debug_info,
                )
            else:
                self.assertEqual(len(freqs), 1)
                most_seq_count = next(iter(freqs))
            self.assertIn(most_seq_count, ips_options)
            typical_counts.add(most_seq_count)
            if sample_consecutive_frames:
                self._check_frames_are_consecutive(
                    batch,
                    dataset_multiseq.frame_annots,
                    debug_info,
                    max_gap=consecutive_frames_max_gap,
                )
        self.assertTrue(
            all(i in typical_counts for i in ips_options),
            "Some of the frequency options did not occur among "
            f"the {num_batches} batches (could be just bad luck)." + debug_info,
        )
        # the iterator must stop after exactly num_batches batches
        with self.assertRaises(StopIteration):
            next(it)

    def _check_frames_are_consecutive(self, batch, annots, debug_info, max_gap=1):
        """Assert that neighbouring batch entries from one sequence are close.

        For every adjacent pair in ``batch`` whose indices fall in the same
        sequence (same ``idx // 10``), the difference ``next - curr`` must not
        exceed the allowed gap:
        - if ``max_gap > 0``, frame numbers read from ``annots`` are compared
          against ``max_gap``;
        - otherwise raw dataset indices are compared against 1.
        """
        for curr_idx, next_idx in zip(batch, batch[1:]):
            if curr_idx // 10 != next_idx // 10:
                continue  # pair straddles two sequences; nothing to check
            if max_gap > 0:
                curr_val, next_val = (
                    annots[idx]["frame_annotation"].frame_number
                    for idx in (curr_idx, next_idx)
                )
                allowed = max_gap
            else:
                # we'll check that raw dataset indices are consecutive
                curr_val, next_val = curr_idx, next_idx
                allowed = 1
            self.assertLessEqual(next_val - curr_val, allowed, msg=debug_info)
def _count_by_quotient(indices, divisor):
    """Histogram ``indices`` by integer quotient.

    Returns a ``defaultdict(int)`` mapping each ``value // divisor`` to the
    number of items in ``indices`` that produce it. In this file,
    ``divisor=10`` counts batch indices per mock sequence and ``divisor=1``
    counts how often each value occurs.
    """
    histogram = defaultdict(int)
    for bucket in (value // divisor for value in indices):
        histogram[bucket] += 1
    return histogram