mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 06:40:35 +08:00
SQL Index Dataset
Summary: Moving SQL dataset to PyTorch3D. It has been extensively tested in pixar_replay. It requires SQLAlchemy 2.0, which is not supported in fbcode. So I exclude the sources and tests that depend on it from buck TARGETS. Reviewed By: bottler Differential Revision: D45086611 fbshipit-source-id: 0285f03e5824c0478c70ad13731525bb5ec7deef
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7aeedd17a4
commit
32e1992924
246
tests/implicitron/test_co3d_sql.py
Normal file
246
tests/implicitron/test_co3d_sql.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import ( # noqa
|
||||
SequenceDataLoaderMapProvider,
|
||||
SimpleDataLoaderMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset # noqa
|
||||
from pytorch3d.implicitron.dataset.sql_dataset_provider import ( # noqa
|
||||
SqlIndexDatasetMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.train_eval_data_loader_provider import (
|
||||
TrainEvalDataLoaderMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import get_default_args
|
||||
|
||||
logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
|
||||
sh = logging.StreamHandler()
|
||||
logger.addHandler(sh)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
|
||||
_CO3D_SQL_DATASET_ROOT: str = os.getenv("CO3D_SQL_DATASET_ROOT", "")
|
||||
|
||||
|
||||
@unittest.skipUnless(_CO3D_SQL_DATASET_ROOT, "Run only if CO3D is available")
|
||||
class TestCo3dSqlDataSource(unittest.TestCase):
|
||||
def test_no_subsets(self):
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
|
||||
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||
provider_args.ignore_subsets = True
|
||||
|
||||
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||
dataset_args.pick_categories = ["skateboard"]
|
||||
dataset_args.limit_sequences_to = 1
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
self.assertIsInstance(
|
||||
data_source.data_loader_map_provider, TrainEvalDataLoaderMapProvider
|
||||
)
|
||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||
self.assertEqual(len(data_loaders.train), 202)
|
||||
for frame in data_loaders.train:
|
||||
self.assertIsNone(frame.frame_type)
|
||||
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
|
||||
break
|
||||
|
||||
def test_subsets(self):
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||
provider_args.subset_lists_path = (
|
||||
"skateboard/set_lists/set_lists_manyview_dev_0.json"
|
||||
)
|
||||
# this will naturally limit to one sequence (no need to limit by cat/sequence)
|
||||
|
||||
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||
dataset_args.remove_empty_masks = True
|
||||
|
||||
for sampler_type in [
|
||||
"SimpleDataLoaderMapProvider",
|
||||
"SequenceDataLoaderMapProvider",
|
||||
"TrainEvalDataLoaderMapProvider",
|
||||
]:
|
||||
args.data_loader_map_provider_class_type = sampler_type
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||
self.assertEqual(len(data_loaders.train), 102)
|
||||
self.assertEqual(len(data_loaders.val), 100)
|
||||
self.assertEqual(len(data_loaders.test), 100)
|
||||
for split in ["train", "val", "test"]:
|
||||
for frame in data_loaders[split]:
|
||||
self.assertEqual(frame.frame_type, [split])
|
||||
# check loading blobs
|
||||
self.assertEqual(frame.image_rgb.shape[-1], 800)
|
||||
break
|
||||
|
||||
def test_sql_subsets(self):
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
|
||||
|
||||
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||
dataset_args.remove_empty_masks = True
|
||||
dataset_args.pick_categories = ["skateboard"]
|
||||
|
||||
for sampler_type in [
|
||||
"SimpleDataLoaderMapProvider",
|
||||
"SequenceDataLoaderMapProvider",
|
||||
"TrainEvalDataLoaderMapProvider",
|
||||
]:
|
||||
args.data_loader_map_provider_class_type = sampler_type
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||
self.assertEqual(len(data_loaders.train), 102)
|
||||
self.assertEqual(len(data_loaders.val), 100)
|
||||
self.assertEqual(len(data_loaders.test), 100)
|
||||
for split in ["train", "val", "test"]:
|
||||
for frame in data_loaders[split]:
|
||||
self.assertEqual(frame.frame_type, [split])
|
||||
self.assertEqual(
|
||||
frame.image_rgb.shape[-1], 800
|
||||
) # check loading blobs
|
||||
break
|
||||
|
||||
@unittest.skip("It takes 75 seconds; skipping by default")
|
||||
def test_huge_subsets(self):
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
|
||||
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||
provider_args.subset_lists_path = "set_lists/set_lists_fewview_dev.sqlite"
|
||||
|
||||
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||
dataset_args.remove_empty_masks = True
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||
self.assertEqual(len(data_loaders.train), 3158974)
|
||||
self.assertEqual(len(data_loaders.val), 518417)
|
||||
self.assertEqual(len(data_loaders.test), 518417)
|
||||
for split in ["train", "val", "test"]:
|
||||
for frame in data_loaders[split]:
|
||||
self.assertEqual(frame.frame_type, [split])
|
||||
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
|
||||
break
|
||||
|
||||
def test_broken_subsets(self):
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
|
||||
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||
provider_args.subset_lists_path = "et_non_est"
|
||||
provider_args.dataset_SqlIndexDataset_args.pick_categories = ["skateboard"]
|
||||
with self.assertRaises(FileNotFoundError) as err:
|
||||
ImplicitronDataSource(**args)
|
||||
|
||||
# check the hint text
|
||||
self.assertIn("Subset lists path given but not found", str(err.exception))
|
||||
|
||||
def test_eval_batches(self):
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
|
||||
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
|
||||
provider_args.eval_batches_path = (
|
||||
"skateboard/eval_batches/eval_batches_manyview_dev_0.json"
|
||||
)
|
||||
|
||||
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||
dataset_args.remove_empty_masks = True
|
||||
dataset_args.pick_categories = ["skateboard"]
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||
self.assertEqual(len(data_loaders.train), 102)
|
||||
self.assertEqual(len(data_loaders.val), 100)
|
||||
self.assertEqual(len(data_loaders.test), 50)
|
||||
for split in ["train", "val", "test"]:
|
||||
for frame in data_loaders[split]:
|
||||
self.assertEqual(frame.frame_type, [split])
|
||||
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
|
||||
break
|
||||
|
||||
def test_eval_batches_from_subset_list_name(self):
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
|
||||
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||
provider_args.subset_list_name = "manyview_dev_0"
|
||||
provider_args.category = "skateboard"
|
||||
|
||||
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||
dataset_args.remove_empty_masks = True
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
dataset, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||
self.assertListEqual(list(dataset.train.pick_categories), ["skateboard"])
|
||||
self.assertEqual(len(data_loaders.train), 102)
|
||||
self.assertEqual(len(data_loaders.val), 100)
|
||||
self.assertEqual(len(data_loaders.test), 50)
|
||||
for split in ["train", "val", "test"]:
|
||||
for frame in data_loaders[split]:
|
||||
self.assertEqual(frame.frame_type, [split])
|
||||
self.assertEqual(frame.image_rgb.shape[-1], 800) # check loading blobs
|
||||
break
|
||||
|
||||
def test_frame_access(self):
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "SqlIndexDatasetMapProvider"
|
||||
args.data_loader_map_provider_class_type = "TrainEvalDataLoaderMapProvider"
|
||||
provider_args = args.dataset_map_provider_SqlIndexDatasetMapProvider_args
|
||||
provider_args.subset_lists_path = "set_lists/set_lists_manyview_dev_0.sqlite"
|
||||
|
||||
dataset_args = provider_args.dataset_SqlIndexDataset_args
|
||||
dataset_args.remove_empty_masks = True
|
||||
dataset_args.pick_categories = ["skateboard"]
|
||||
frame_builder_args = dataset_args.frame_data_builder_FrameDataBuilder_args
|
||||
frame_builder_args.load_point_clouds = True
|
||||
frame_builder_args.box_crop = False # required for .meta
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
dataset_map, _ = data_source.get_datasets_and_dataloaders()
|
||||
dataset = dataset_map["train"]
|
||||
|
||||
for idx in [10, ("245_26182_52130", 22)]:
|
||||
example_meta = dataset.meta[idx]
|
||||
example = dataset[idx]
|
||||
|
||||
self.assertIsNone(example_meta.image_rgb)
|
||||
self.assertIsNone(example_meta.fg_probability)
|
||||
self.assertIsNone(example_meta.depth_map)
|
||||
self.assertIsNone(example_meta.sequence_point_cloud)
|
||||
self.assertIsNotNone(example_meta.camera)
|
||||
|
||||
self.assertIsNotNone(example.image_rgb)
|
||||
self.assertIsNotNone(example.fg_probability)
|
||||
self.assertIsNotNone(example.depth_map)
|
||||
self.assertIsNotNone(example.sequence_point_cloud)
|
||||
self.assertIsNotNone(example.camera)
|
||||
|
||||
self.assertEqual(example_meta.sequence_name, example.sequence_name)
|
||||
self.assertEqual(example_meta.frame_number, example.frame_number)
|
||||
self.assertEqual(example_meta.frame_timestamp, example.frame_timestamp)
|
||||
self.assertEqual(example_meta.sequence_category, example.sequence_category)
|
||||
torch.testing.assert_close(example_meta.camera.R, example.camera.R)
|
||||
torch.testing.assert_close(example_meta.camera.T, example.camera.T)
|
||||
torch.testing.assert_close(
|
||||
example_meta.camera.focal_length, example.camera.focal_length
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
example_meta.camera.principal_point, example.camera.principal_point
|
||||
)
|
||||
Reference in New Issue
Block a user