Mirror of https://github.com/facebookresearch/pytorch3d.git
More tests for SQL Dataset
Summary: I forgot to include these tests in D45086611 when transferring code from the pixar_replay repo. They test the new ORM types used in the SQL dataset and are SQLAlchemy 2.0 specific. The test for extending types also serves as a proof of concept for the generality of the SQL dataset: the idea is to extend FrameAnnotation and FrameData in parallel.

Reviewed By: bottler

Differential Revision: D45529284

fbshipit-source-id: 2a634e518f580c312602107c85fc320db43abcf5
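In brief, the extension pattern the new test exercises looks like the following. This is a condensed, non-authoritative sketch assembled from the test file below (the composite-column case is omitted); the class names are the ones used in the test.

from typing import ClassVar, Optional, Type

from pytorch3d.implicitron.dataset.frame_data import FrameData, GenericFrameDataBuilder
from pytorch3d.implicitron.dataset.orm_types import SqlFrameAnnotation
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
from pytorch3d.implicitron.dataset.utils import GenericWorkaround
from pytorch3d.implicitron.tools.config import registry
from sqlalchemy.orm import Mapped, mapped_column


class ExtendedSqlFrameAnnotation(SqlFrameAnnotation):
    # extra column on the ORM / annotation side
    num_dogs: Mapped[Optional[int]] = mapped_column(default=None)


class CanineFrameData(FrameData):
    # matching field on the FrameData side
    num_dogs: Optional[int] = None


@registry.register
class CanineFrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[CanineFrameData]):
    frame_data_type: ClassVar[Type[FrameData]] = CanineFrameData

    def build(self, frame_annotation, sequence_annotation, load_blobs=True):
        frame_data = super().build(frame_annotation, sequence_annotation, load_blobs)
        frame_data.num_dogs = frame_annotation.num_dogs or 101  # copy the extra field over
        return frame_data


class CanineSqlIndexDataset(SqlIndexDataset):
    frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = ExtendedSqlFrameAnnotation
    frame_data_builder_class_type: str = "CanineFrameDataBuilder"

The test then copies the bundled sql_dataset_100.sqlite fixture into a temporary DB with the extended schema and checks that indexing returns frame data items with the extra fields populated.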
This commit is contained in:
parent 178a7774d4
commit 3e3644e534
tests/implicitron/test_extending_orm_types.py (new file, 230 lines)
@@ -0,0 +1,230 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
import logging
import os
import tempfile
import unittest
from typing import ClassVar, Optional, Type

import pandas as pd
import pkg_resources
import sqlalchemy as sa

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.frame_data import FrameData, GenericFrameDataBuilder
from pytorch3d.implicitron.dataset.orm_types import (
    SqlFrameAnnotation,
    SqlSequenceAnnotation,
)
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
from pytorch3d.implicitron.dataset.utils import GenericWorkaround
from pytorch3d.implicitron.tools.config import registry
from sqlalchemy.orm import composite, Mapped, mapped_column, Session


NO_BLOBS_KWARGS = {
    "dataset_root": "",
    "load_images": False,
    "load_depths": False,
    "load_masks": False,
    "load_depth_masks": False,
    "box_crop": False,
}

DATASET_ROOT = pkg_resources.resource_filename(__name__, "data/sql_dataset")
METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite")

logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)

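# The two classes below extend the annotation on the ORM side: MagneticFieldAnnotation is a
# small dataclass payload, and ExtendedSqlFrameAnnotation stores it on the annotation table
# as a SQLAlchemy composite over two extra columns, plus a plain extra column (num_dogs).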
@dataclasses.dataclass
class MagneticFieldAnnotation:
    path: str
    average_flux_density: Optional[float] = None


class ExtendedSqlFrameAnnotation(SqlFrameAnnotation):
    num_dogs: Mapped[Optional[int]] = mapped_column(default=None)

    magnetic_field: Mapped[MagneticFieldAnnotation] = composite(
        mapped_column("_magnetic_field_path", nullable=True),
        mapped_column("_magnetic_field_average_flux_density", nullable=True),
        default_factory=lambda: None,
    )

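# A dataset that serves the extended annotation type; FrameData and its builder are left
# at their defaults here.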
class ExtendedSqlIndexDataset(SqlIndexDataset):
    frame_annotations_type: ClassVar[
        Type[SqlFrameAnnotation]
    ] = ExtendedSqlFrameAnnotation

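# The parallel extension on the output side: CanineFrameData carries the extra fields, and a
# registered builder copies them over from the extended annotation when a frame is loaded.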
class CanineFrameData(FrameData):
    num_dogs: Optional[int] = None
    magnetic_field_average_flux_density: Optional[float] = None


@registry.register
class CanineFrameDataBuilder(
    GenericWorkaround, GenericFrameDataBuilder[CanineFrameData]
):
    """
    A concrete class to build an extended FrameData object
    """

    frame_data_type: ClassVar[Type[FrameData]] = CanineFrameData

    def build(
        self,
        frame_annotation: ExtendedSqlFrameAnnotation,
        sequence_annotation: types.SequenceAnnotation,
        load_blobs: bool = True,
    ) -> CanineFrameData:
        frame_data = super().build(frame_annotation, sequence_annotation, load_blobs)
        frame_data.num_dogs = frame_annotation.num_dogs or 101
        frame_data.magnetic_field_average_flux_density = (
            frame_annotation.magnetic_field.average_flux_density
        )
        return frame_data

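# Ties the two extensions together: the dataset reads ExtendedSqlFrameAnnotation rows and
# selects the extended builder by name via frame_data_builder_class_type.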
class CanineSqlIndexDataset(SqlIndexDataset):
    frame_annotations_type: ClassVar[
        Type[SqlFrameAnnotation]
    ] = ExtendedSqlFrameAnnotation

    frame_data_builder_class_type: str = "CanineFrameDataBuilder"


class TestExtendingOrmTypes(unittest.TestCase):
    def setUp(self):
        # create a temporary copy of the DB with an extended schema
        engine = sa.create_engine(f"sqlite:///{METADATA_FILE}")
        with Session(engine) as session:
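            # copy each frame annotation into an instance of the extended type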
            extended_annots = [
                ExtendedSqlFrameAnnotation(
                    **{
                        k: v
                        for k, v in frame_annot.__dict__.items()
                        if not k.startswith("_")  # remove mapped fields and SA metadata
                    }
                )
                for frame_annot in session.scalars(sa.select(SqlFrameAnnotation))
            ]
            seq_annots = session.scalars(
                sa.select(SqlSequenceAnnotation),
                execution_options={"prebuffer_rows": True},
            )
            session.expunge_all()

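        # write the copied rows to a fresh temporary SQLite file whose schema
        # includes the extended columns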
        self._temp_db = tempfile.NamedTemporaryFile(delete=False)
        engine_ext = sa.create_engine(f"sqlite:///{self._temp_db.name}")
        ExtendedSqlFrameAnnotation.metadata.create_all(engine_ext, checkfirst=True)
        with Session(engine_ext, expire_on_commit=False) as session_ext:
            session_ext.add_all(extended_annots)
            for instance in seq_annots:
                session_ext.merge(instance)
            session_ext.commit()

        # check the setup is correct
        with engine_ext.connect() as connection_ext:
            df = pd.read_sql_query(
                sa.select(ExtendedSqlFrameAnnotation), connection_ext
            )
            self.assertEqual(len(df), 100)
            self.assertIn("_magnetic_field_average_flux_density", df.columns)

            df_seq = pd.read_sql_query(sa.select(SqlSequenceAnnotation), connection_ext)
            self.assertEqual(len(df_seq), 10)

    def tearDown(self):
        self._temp_db.close()
        os.remove(self._temp_db.name)

    def test_basic(self, sequence="cat1_seq2", frame_number=4):
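        # blob loading is disabled via NO_BLOBS_KWARGS, so no image/depth/mask files are read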
        dataset = ExtendedSqlIndexDataset(
            sqlite_metadata_file=self._temp_db.name,
            remove_empty_masks=False,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

        # check the items are consecutive
        past_sequences = set()
        last_frame_number = -1
        last_sequence = ""
        for i in range(len(dataset)):
            item = dataset[i]

            if item.frame_number == 0:
                self.assertNotIn(item.sequence_name, past_sequences)
                past_sequences.add(item.sequence_name)
                last_sequence = item.sequence_name
            else:
                self.assertEqual(item.sequence_name, last_sequence)
                self.assertEqual(item.frame_number, last_frame_number + 1)

            last_frame_number = item.frame_number

        # test indexing
        with self.assertRaises(IndexError):
            dataset[len(dataset) + 1]

        # test sequence-frame indexing
        item = dataset[sequence, frame_number]
        self.assertEqual(item.sequence_name, sequence)
        self.assertEqual(item.frame_number, frame_number)

        with self.assertRaises(IndexError):
            dataset[sequence, 13]

    def test_extending_frame_data(self, sequence="cat1_seq2", frame_number=4):
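        # same checks as test_basic, but every item should be a CanineFrameData with the
        # extra fields filled in by the registered builder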
        dataset = CanineSqlIndexDataset(
            sqlite_metadata_file=self._temp_db.name,
            remove_empty_masks=False,
            frame_data_builder_CanineFrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

        # check the items are consecutive
        past_sequences = set()
        last_frame_number = -1
        last_sequence = ""
        for i in range(len(dataset)):
            item = dataset[i]
            self.assertIsInstance(item, CanineFrameData)
            self.assertEqual(item.num_dogs, 101)
            self.assertIsNone(item.magnetic_field_average_flux_density)

            if item.frame_number == 0:
                self.assertNotIn(item.sequence_name, past_sequences)
                past_sequences.add(item.sequence_name)
                last_sequence = item.sequence_name
            else:
                self.assertEqual(item.sequence_name, last_sequence)
                self.assertEqual(item.frame_number, last_frame_number + 1)

            last_frame_number = item.frame_number

        # test indexing
        with self.assertRaises(IndexError):
            dataset[len(dataset) + 1]

        # test sequence-frame indexing
        item = dataset[sequence, frame_number]
        self.assertIsInstance(item, CanineFrameData)
        self.assertEqual(item.sequence_name, sequence)
        self.assertEqual(item.frame_number, frame_number)
        self.assertEqual(item.num_dogs, 101)

        with self.assertRaises(IndexError):
            dataset[sequence, 13]
tests/implicitron/test_orm_types.py (new file, 37 lines)
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import numpy as np

from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory

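# TupleTypeFactory builds a custom SQLAlchemy column type that stores (possibly nested)
# tuples in a single column; the tests round-trip values through its bind/result conversions.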
class TestOrmTypes(unittest.TestCase):
    def test_tuple_serialization_none(self):
        ttype = TupleTypeFactory()()
        output = ttype.process_bind_param(None, None)
        self.assertIsNone(output)
        output = ttype.process_result_value(output, None)
        self.assertIsNone(output)

    def test_tuple_serialization_1d(self):
        for input_tuple in [(1, 2, 3), (4.5, 6.7)]:
            ttype = TupleTypeFactory(type(input_tuple[0]), (len(input_tuple),))()
            output = ttype.process_bind_param(input_tuple, None)
            input_hat = ttype.process_result_value(output, None)
            self.assertEqual(type(input_hat[0]), type(input_tuple[0]))
            np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)

    def test_tuple_serialization_2d(self):
        input_tuple = ((1.0, 2.0, 3.0), (4.5, 5.5, 6.6))
        ttype = TupleTypeFactory(type(input_tuple[0][0]), (2, 3))()
        output = ttype.process_bind_param(input_tuple, None)
        input_hat = ttype.process_result_value(output, None)
        self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0]))
        # we use float32 to serialise
        np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)