Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-12-20 22:30:35 +08:00
SQL Index Dataset
Summary: Moving the SQL dataset to PyTorch3D. It has been extensively tested in pixar_replay. It requires SQLAlchemy 2.0, which is not supported in fbcode, so the sources and tests that depend on it are excluded from the buck TARGETS.

Reviewed By: bottler

Differential Revision: D45086611

fbshipit-source-id: 0285f03e5824c0478c70ad13731525bb5ec7deef
Committed by Facebook GitHub Bot
parent 7aeedd17a4
commit 32e1992924
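For context, this is how the new dataset is driven. A minimal sketch, assuming a metadata SQLite file in the same format as the test fixture below; the path, category, and indices are placeholders, not part of this commit:

    from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset

    dataset = SqlIndexDataset(
        sqlite_metadata_file="metadata.sqlite",  # placeholder path
        remove_empty_masks=True,  # drop frames whose mask has zero mass
        pick_categories=["skateboard"],  # optional category filter
        # per the tests, .meta access requires box_crop to be off
        frame_data_builder_FrameDataBuilder_args={"box_crop": False},
    )

    frame = dataset[0]  # index into the flat frame list
    other = dataset["some_sequence", 4]  # or by (sequence_name, frame_number)
    meta = dataset.meta[0]  # metadata only, skips loading image/depth blobs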
1 tests/implicitron/data/sql_dataset/set_lists_100.json Normal file
File diff suppressed because one or more lines are too long
BIN tests/implicitron/data/sql_dataset/sql_dataset_100.sqlite Normal file
Binary file not shown.
246 tests/implicitron/test_co3d_sql.py Normal file
@@ -0,0 +1,246 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import unittest

import torch

from pytorch3d.implicitron.dataset.data_loader_map_provider import (  # noqa
    SequenceDataLoaderMapProvider,
    SimpleDataLoaderMapProvider,
)
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset  # noqa
from pytorch3d.implicitron.dataset.sql_dataset_provider import (  # noqa
    SqlIndexDatasetMapProvider,
)
from pytorch3d.implicitron.dataset.train_eval_data_loader_provider import (
    TrainEvalDataLoaderMapProvider,
)
from pytorch3d.implicitron.tools.config import get_default_args

logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)

_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "")


@unittest.skipUnless(_CO3D_SQL_DATASET_ROOT, "Run only if CO3D is available")
class TestCo3dSqlDataSource(unittest.TestCase):
    def test_no_subsets(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.ignore_subsets = True

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.pick_categories = ["skateboard"]
        dataset_args.limit_sequences_to = 1

        data_source = ImplicitronDataSource(**args)
        self.assertIsInstance(
            data_source.data_loader_map_provider, TrainEvalDataLoaderMapProvider
        )
        _, data_loaders = data_source.get_datasets_and_dataloaders()
        self.assertEqual(len(data_loaders.train), 202)
        for frame in data_loaders.train:
            self.assertIsNone(frame.frame_type)
            self.assertEqual(frame.image_rgb.shape[-1], 800)  # check loading blobs
            break

    def test_subsets(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_lists_path = (
            "skateboard/set_lists/set_lists_manyview_dev_0.json"
        )
        # this will naturally limit to one sequence (no need to limit by cat/sequence)

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True

        for sampler_type in [
            "SimpleDataLoaderMapProvider",
            "SequenceDataLoaderMapProvider",
            "TrainEvalDataLoaderMapProvider",
        ]:
            args.data_loader_map_provider_class_type = sampler_type
            data_source = ImplicitronDataSource(**args)
            _, data_loaders = data_source.get_datasets_and_dataloaders()
            self.assertEqual(len(data_loaders.train), 102)
            self.assertEqual(len(data_loaders.val), 100)
            self.assertEqual(len(data_loaders.test), 100)
            for split in ["train", "val", "test"]:
                for frame in data_loaders[split]:
                    self.assertEqual(frame.frame_type, [split])
                    # check loading blobs
                    self.assertEqual(frame.image_rgb.shape[-1], 800)
                    break

    def test_sql_subsets(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True
        dataset_args.pick_categories = ["skateboard"]

        for sampler_type in [
            "SimpleDataLoaderMapProvider",
            "SequenceDataLoaderMapProvider",
            "TrainEvalDataLoaderMapProvider",
        ]:
            args.data_loader_map_provider_class_type = sampler_type
            data_source = ImplicitronDataSource(**args)
            _, data_loaders = data_source.get_datasets_and_dataloaders()
            self.assertEqual(len(data_loaders.train), 102)
            self.assertEqual(len(data_loaders.val), 100)
            self.assertEqual(len(data_loaders.test), 100)
            for split in ["train", "val", "test"]:
                for frame in data_loaders[split]:
                    self.assertEqual(frame.frame_type, [split])
                    self.assertEqual(
                        frame.image_rgb.shape[-1], 800
                    )  # check loading blobs
                    break

    @unittest.skip("It takes 75 seconds; skipping by default")
    def test_huge_subsets(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_lists_path = "set_lists/set_lists_fewview_dev.sqlite"

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True

        data_source = ImplicitronDataSource(**args)
        _, data_loaders = data_source.get_datasets_and_dataloaders()
        self.assertEqual(len(data_loaders.train), 3158974)
        self.assertEqual(len(data_loaders.val), 518417)
        self.assertEqual(len(data_loaders.test), 518417)
        for split in ["train", "val", "test"]:
            for frame in data_loaders[split]:
                self.assertEqual(frame.frame_type, [split])
                self.assertEqual(frame.image_rgb.shape[-1], 800)  # check loading blobs
                break

    def test_broken_subsets(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_lists_path = "et_non_est"
        provider_args.dataset_SqlIndexDataset_args.pick_categories = ["skateboard"]
        with self.assertRaises(FileNotFoundError) as err:
            ImplicitronDataSource(**args)

        # check the hint text
        self.assertIn("Subset lists path given but not found", str(err.exception))

    def test_eval_batches(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
        provider_args.eval_batches_path = (
            "skateboard/eval_batches/eval_batches_manyview_dev_0.json"
        )

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True
        dataset_args.pick_categories = ["skateboard"]

        data_source = ImplicitronDataSource(**args)
        _, data_loaders = data_source.get_datasets_and_dataloaders()
        self.assertEqual(len(data_loaders.train), 102)
        self.assertEqual(len(data_loaders.val), 100)
        self.assertEqual(len(data_loaders.test), 50)
        for split in ["train", "val", "test"]:
            for frame in data_loaders[split]:
                self.assertEqual(frame.frame_type, [split])
                self.assertEqual(frame.image_rgb.shape[-1], 800)  # check loading blobs
                break

    def test_eval_batches_from_subset_list_name(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_list_name = "manyview_dev_0"
        provider_args.category = "skateboard"

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True

        data_source = ImplicitronDataSource(**args)
        dataset, data_loaders = data_source.get_datasets_and_dataloaders()
        self.assertListEqual(list(dataset.train.pick_categories), ["skateboard"])
        self.assertEqual(len(data_loaders.train), 102)
        self.assertEqual(len(data_loaders.val), 100)
        self.assertEqual(len(data_loaders.test), 50)
        for split in ["train", "val", "test"]:
            for frame in data_loaders[split]:
                self.assertEqual(frame.frame_type, [split])
                self.assertEqual(frame.image_rgb.shape[-1], 800)  # check loading blobs
                break

    def test_frame_access(self):
        args = get_default_args(ImplicitronDataSource)
        args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
        args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
        provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
        provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"

        dataset_args = provider_args.dataset_SqlIndexDataset_args
        dataset_args.remove_empty_masks = True
        dataset_args.pick_categories = ["skateboard"]
        frame_builder_args = dataset_args.frame_data_builder_FrameDataBuilder_args
        frame_builder_args.load_point_clouds = True
        frame_builder_args.box_crop = False  # required for .meta

        data_source = ImplicitronDataSource(**args)
        dataset_map, _ = data_source.get_datasets_and_dataloaders()
        dataset = dataset_map["train"]

        for idx in [10, ("245_26182_52130", 22)]:
            example_meta = dataset.meta[idx]
            example = dataset[idx]

            self.assertIsNone(example_meta.image_rgb)
            self.assertIsNone(example_meta.fg_probability)
            self.assertIsNone(example_meta.depth_map)
            self.assertIsNone(example_meta.sequence_point_cloud)
            self.assertIsNotNone(example_meta.camera)

            self.assertIsNotNone(example.image_rgb)
            self.assertIsNotNone(example.fg_probability)
            self.assertIsNotNone(example.depth_map)
            self.assertIsNotNone(example.sequence_point_cloud)
            self.assertIsNotNone(example.camera)

            self.assertEqual(example_meta.sequence_name, example.sequence_name)
            self.assertEqual(example_meta.frame_number, example.frame_number)
            self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
            self.assertEqual(example_meta.sequence_category, example.sequence_category)
            torch.testing.assert_close(example_meta.camera.R, example.camera.R)
            torch.testing.assert_close(example_meta.camera.T, example.camera.T)
            torch.testing.assert_close(
                example_meta.camera.focal_length, example.camera.focal_length
            )
            torch.testing.assert_close(
                example_meta.camera.principal_point, example.camera.principal_point
            )
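The tests above all follow one configuration pattern. As a usage note, a minimal sketch of that pattern outside a test harness (the subset-list path is a placeholder, and CO3D_SQL_DATASET_ROOT must point at a CO3D dataset in the SQL format):

    from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
    from pytorch3d.implicitron.tools.config import get_default_args

    args = get_default_args(ImplicitronDataSource)
    args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
    args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
    provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
    provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"  # placeholder

    data_source = ImplicitronDataSource(**args)
    _, data_loaders = data_source.get_datasets_and_dataloaders()
    for frame in data_loaders.train:
        ...  # frame.image_rgb, frame.camera, frame.frame_type, etc.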
522 tests/implicitron/test_sql_dataset.py Normal file
@@ -0,0 +1,522 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import unittest
from collections import Counter

import pkg_resources

import torch

from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset

NO_BLOBS_KWARGS = {
    "dataset_root": "",
    "load_images": False,
    "load_depths": False,
    "load_masks": False,
    "load_depth_masks": False,
    "box_crop": False,
}

logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)


DATASET_ROOT = pkg_resources.resource_filename(__name__, "data/sql_dataset")
METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite")
SET_LIST_FILE = os.path.join(DATASET_ROOT, "set_lists_100.json")


class TestSqlDataset(unittest.TestCase):
    def test_basic(self, sequence="cat1_seq2", frame_number=4):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

        # check the items are consecutive
        past_sequences = set()
        last_frame_number = -1
        last_sequence = ""
        for i in range(len(dataset)):
            item = dataset[i]

            if item.frame_number == 0:
                self.assertNotIn(item.sequence_name, past_sequences)
                past_sequences.add(item.sequence_name)
                last_sequence = item.sequence_name
            else:
                self.assertEqual(item.sequence_name, last_sequence)
                self.assertEqual(item.frame_number, last_frame_number + 1)

            last_frame_number = item.frame_number

        # test indexing
        with self.assertRaises(IndexError):
            dataset[len(dataset) + 1]

        # test sequence-frame indexing
        item = dataset[sequence, frame_number]
        self.assertEqual(item.sequence_name, sequence)
        self.assertEqual(item.frame_number, frame_number)

        with self.assertRaises(IndexError):
            dataset[sequence, 13]

    def test_filter_empty_masks(self):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 78)

    def test_pick_frames_sql_clause(self):
        dataset_no_empty_masks = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        # check the datasets are equal
        self.assertEqual(len(dataset), len(dataset_no_empty_masks))
        for i in range(len(dataset)):
            item_nem = dataset_no_empty_masks[i]
            item = dataset[i]
            self.assertEqual(item_nem.image_path, item.image_path)

        # remove_empty_masks together with the custom criterion
        dataset_ts = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            pick_frames_sql_clause="frame_timestamp < 0.15",
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )
        self.assertEqual(len(dataset_ts), 19)

    def test_limit_categories(self, category="cat0"):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            pick_categories=[category],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 50)
        for i in range(len(dataset)):
            self.assertEqual(dataset[i].sequence_category, category)

    def test_limit_sequences(self, num_sequences=3):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            limit_sequences_to=num_sequences,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 10 * num_sequences)

        def delist(sequence_name):
            return sequence_name if isinstance(sequence_name, str) else sequence_name[0]

        unique_seqs = {delist(dataset[i].sequence_name) for i in range(len(dataset))}
        self.assertEqual(len(unique_seqs), num_sequences)

    def test_pick_exclude_sequences(self, sequence="cat1_seq2"):
        # pick sequence
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            pick_sequences=[sequence],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 10)
        unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
        self.assertCountEqual(unique_seqs, {sequence})

        item = dataset[sequence, 0]
        self.assertEqual(item.sequence_name, sequence)
        self.assertEqual(item.frame_number, 0)

        # exclude sequence
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            exclude_sequences=[sequence],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 90)
        unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
        self.assertNotIn(sequence, unique_seqs)

        with self.assertRaises(IndexError):
            dataset[sequence, 0]

    def test_limit_frames(self, num_frames=13):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            limit_to=num_frames,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), num_frames)
        unique_seqs = {dataset[i].sequence_name for i in range(len(dataset))}
        self.assertEqual(len(unique_seqs), 2)

        # test when the limit is not binding
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            limit_to=1000,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

    def test_limit_frames_per_sequence(self, num_frames=2):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            n_frames_per_sequence=num_frames,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), num_frames * 10)
        seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
        self.assertEqual(len(seq_counts), 10)
        self.assertCountEqual(
            set(seq_counts.values()), {2}
        )  # all counts are num_frames

        with self.assertRaises(IndexError):
            dataset[next(iter(seq_counts)), num_frames + 1]

        # test when the limit is not binding
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            n_frames_per_sequence=13,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )
        self.assertEqual(len(dataset), 100)

    def test_filter_medley(self):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            pick_categories=["cat1"],
            exclude_sequences=["cat1_seq0"],  # retaining "cat1_seq1" and on
            limit_sequences_to=2,  # retaining "cat1_seq1" and "cat1_seq2"
            limit_to=14,  # retaining full "cat1_seq1" and 4 from "cat1_seq2"
            n_frames_per_sequence=6,  # cutting "cat1_seq1" to 6 frames
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        # result: preserved 6 frames from cat1_seq1 and 4 from cat1_seq2
        seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
        self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"])
        self.assertEqual(seq_counts["cat1_seq1"], 6)
        self.assertEqual(seq_counts["cat1_seq2"], 4)

    def test_subsets_trivial(self):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            subset_lists_file=SET_LIST_FILE,
            limit_to=100,  # force sorting
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

        # check the items are consecutive
        past_sequences = set()
        last_frame_number = -1
        last_sequence = ""
        for i in range(len(dataset)):
            item = dataset[i]

            if item.frame_number == 0:
                self.assertNotIn(item.sequence_name, past_sequences)
                past_sequences.add(item.sequence_name)
                last_sequence = item.sequence_name
            else:
                self.assertEqual(item.sequence_name, last_sequence)
                self.assertEqual(item.frame_number, last_frame_number + 1)

            last_frame_number = item.frame_number

    def test_subsets_filter_empty_masks(self):
        # we need to test this case as it uses quite different logic with `df.drop()`
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 78)

    def test_subsets_pick_frames_sql_clause(self):
        dataset_no_empty_masks = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            pick_frames_sql_clause="_mask_mass IS NULL OR _mask_mass > 0",
            subset_lists_file=SET_LIST_FILE,
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        # check the datasets are equal
        self.assertEqual(len(dataset), len(dataset_no_empty_masks))
        for i in range(len(dataset)):
            item_nem = dataset_no_empty_masks[i]
            item = dataset[i]
            self.assertEqual(item_nem.image_path, item.image_path)

        # remove_empty_masks together with the custom criterion
        dataset_ts = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            pick_frames_sql_clause="frame_timestamp < 0.15",
            subset_lists_file=SET_LIST_FILE,
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset_ts), 19)

    def test_single_subset(self):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 50)

        with self.assertRaises(IndexError):
            dataset[51]

        # check the items are consecutive
        past_sequences = set()
        last_frame_number = -1
        last_sequence = ""
        for i in range(len(dataset)):
            item = dataset[i]

            if item.frame_number < 2:
                self.assertNotIn(item.sequence_name, past_sequences)
                past_sequences.add(item.sequence_name)
                last_sequence = item.sequence_name
            else:
                self.assertEqual(item.sequence_name, last_sequence)
                self.assertEqual(item.frame_number, last_frame_number + 2)

            last_frame_number = item.frame_number

        item = dataset[last_sequence, 0]
        self.assertEqual(item.sequence_name, last_sequence)

        with self.assertRaises(IndexError):
            dataset[last_sequence, 1]

    def test_subset_with_filters(self):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train"],
            pick_categories=["cat1"],
            exclude_sequences=["cat1_seq0"],  # retaining "cat1_seq1" and on
            limit_sequences_to=2,  # retaining "cat1_seq1" and "cat1_seq2"
            limit_to=7,  # retaining full train set of "cat1_seq1" and 2 from "cat1_seq2"
            n_frames_per_sequence=3,  # cutting "cat1_seq1" to 3 frames
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        # result: preserved 3 frames from cat1_seq1 and 2 from cat1_seq2
        seq_counts = Counter(dataset[i].sequence_name for i in range(len(dataset)))
        self.assertCountEqual(seq_counts.keys(), ["cat1_seq1", "cat1_seq2"])
        self.assertEqual(seq_counts["cat1_seq1"], 3)
        self.assertEqual(seq_counts["cat1_seq2"], 2)

    def test_visitor(self):
        dataset_sorted = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        sequences = dataset_sorted.sequence_names()
        i = 0
        for seq in sequences:
            last_ts = float("-Inf")
            for ts, _, idx in dataset_sorted.sequence_frames_in_order(seq):
                self.assertEqual(i, idx)
                i += 1
                self.assertGreaterEqual(ts, last_ts)
                last_ts = ts

        # test legacy visitor
        old_indices = None
        for seq in sequences:
            last_ts = float("-Inf")
            rows = dataset_sorted._index.index.get_loc(seq)
            indices = list(range(rows.start or 0, rows.stop, rows.step or 1))
            fn_ts_list = dataset_sorted.get_frame_numbers_and_timestamps(indices)
            self.assertEqual(len(fn_ts_list), len(indices))

            if old_indices:
                # check raising if we ask for multiple sequences
                with self.assertRaises(ValueError):
                    dataset_sorted.get_frame_numbers_and_timestamps(
                        indices + old_indices
                    )

            old_indices = indices

    def test_visitor_subsets(self):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            limit_to=100,  # force sorting
            subset_lists_file=SET_LIST_FILE,
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        sequences = dataset.sequence_names()
        i = 0
        for seq in sequences:
            last_ts = float("-Inf")
            seq_frames = list(dataset.sequence_frames_in_order(seq))
            self.assertEqual(len(seq_frames), 10)
            for ts, _, idx in seq_frames:
                self.assertEqual(i, idx)
                i += 1
                self.assertGreaterEqual(ts, last_ts)
                last_ts = ts

            last_ts = float("-Inf")
            train_frames = list(dataset.sequence_frames_in_order(seq, "train"))
            self.assertEqual(len(train_frames), 5)
            for ts, _, _ in train_frames:
                self.assertGreaterEqual(ts, last_ts)
                last_ts = ts

    def test_category_to_sequence_names(self):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        cat_to_seqs = dataset.category_to_sequence_names()
        self.assertEqual(len(cat_to_seqs), 2)
        self.assertIn("cat1", cat_to_seqs)
        self.assertEqual(len(cat_to_seqs["cat1"]), 5)

        # check that the override preserves the behavior
        cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names()
        self.assertDictEqual(cat_to_seqs, cat_to_seqs_base)

    def test_category_to_sequence_names_filters(self):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=True,
            subset_lists_file=SET_LIST_FILE,
            exclude_sequences=["cat1_seq0"],
            subsets=["train", "test"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        cat_to_seqs = dataset.category_to_sequence_names()
        self.assertEqual(len(cat_to_seqs), 2)
        self.assertIn("cat1", cat_to_seqs)
        self.assertEqual(len(cat_to_seqs["cat1"]), 4)  # minus one

        # check that the override preserves the behavior
        cat_to_seqs_base = super(SqlIndexDataset, dataset).category_to_sequence_names()
        self.assertDictEqual(cat_to_seqs, cat_to_seqs_base)

    def test_meta_access(self):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train"],
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 50)

        for idx in [10, ("cat0_seq2", 2)]:
            example_meta = dataset.meta[idx]
            example = dataset[idx]
            self.assertEqual(example_meta.sequence_name, example.sequence_name)
            self.assertEqual(example_meta.frame_number, example.frame_number)
            self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
            self.assertEqual(example_meta.sequence_category, example.sequence_category)
            torch.testing.assert_close(example_meta.camera.R, example.camera.R)
            torch.testing.assert_close(example_meta.camera.T, example.camera.T)
            torch.testing.assert_close(
                example_meta.camera.focal_length, example.camera.focal_length
            )
            torch.testing.assert_close(
                example_meta.camera.principal_point, example.camera.principal_point
            )

    def test_meta_access_no_blobs(self):
        dataset = SqlIndexDataset(
            sqlite_metadata_file=METADATA_FILE,
            remove_empty_masks=False,
            subset_lists_file=SET_LIST_FILE,
            subsets=["train"],
            frame_data_builder_FrameDataBuilder_args={
                "dataset_root": ".",
                "box_crop": False,  # required by the blob-less accessor
            },
        )

        self.assertIsNone(dataset.meta[0].image_rgb)
        self.assertIsNone(dataset.meta[0].fg_probability)
        self.assertIsNone(dataset.meta[0].depth_map)
        self.assertIsNone(dataset.meta[0].sequence_point_cloud)
        self.assertIsNotNone(dataset.meta[0].camera)