mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-11-04 18:02:14 +08:00
More tests for SQL Dataset

Summary: I forgot to include these tests in D45086611 when transferring code from the pixar_replay repo. They test the new ORM types used in the SQL dataset and are SQLAlchemy 2.0 specific. Importantly, the test for extending types is a proof of concept for the generality of SQL Dataset: the idea is to extend FrameAnnotation and FrameData in parallel.

Reviewed By: bottler

Differential Revision: D45529284

fbshipit-source-id: 2a634e518f580c312602107c85fc320db43abcf5
This commit is contained in:
parent 178a7774d4
commit 3e3644e534
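In condensed form, the parallel-extension pattern that the new test exercises looks like the sketch below. All names are taken from tests/implicitron/test_extending_orm_types.py in the diff that follows; this is a distillation of that file, not additional API:

# Condensed sketch of the pattern under test: extend the ORM row type
# (SqlFrameAnnotation) and the runtime container (FrameData) in parallel,
# then point a dataset subclass at the extended annotation type.
from typing import ClassVar, Optional, Type

from sqlalchemy.orm import Mapped, mapped_column

from pytorch3d.implicitron.dataset.frame_data import FrameData
from pytorch3d.implicitron.dataset.orm_types import SqlFrameAnnotation
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset


class ExtendedSqlFrameAnnotation(SqlFrameAnnotation):
    # new nullable column, added to the mapped table by SQLAlchemy 2.0
    num_dogs: Mapped[Optional[int]] = mapped_column(default=None)


class CanineFrameData(FrameData):
    # matching runtime field, populated by a custom FrameDataBuilder
    num_dogs: Optional[int] = None


class ExtendedSqlIndexDataset(SqlIndexDataset):
    frame_annotations_type: ClassVar[
        Type[SqlFrameAnnotation]
    ] = ExtendedSqlFrameAnnotation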
							
								
								
									
230  tests/implicitron/test_extending_orm_types.py  (new file)

@@ -0,0 +1,230 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
import logging
import os
import tempfile
import unittest
from typing import ClassVar, Optional, Type

import pandas as pd
import pkg_resources
import sqlalchemy as sa

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.frame_data import FrameData, GenericFrameDataBuilder
from pytorch3d.implicitron.dataset.orm_types import (
    SqlFrameAnnotation,
    SqlSequenceAnnotation,
)
from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
from pytorch3d.implicitron.dataset.utils import GenericWorkaround
from pytorch3d.implicitron.tools.config import registry
from sqlalchemy.orm import composite, Mapped, mapped_column, Session

NO_BLOBS_KWARGS = {
    "dataset_root": "",
    "load_images": False,
    "load_depths": False,
    "load_masks": False,
    "load_depth_masks": False,
    "box_crop": False,
}

DATASET_ROOT = pkg_resources.resource_filename(__name__, "data/sql_dataset")
METADATA_FILE = os.path.join(DATASET_ROOT, "sql_dataset_100.sqlite")

logger = logging.getLogger("pytorch3d.implicitron.dataset.sql_dataset")
sh = logging.StreamHandler()
logger.addHandler(sh)
logger.setLevel(logging.DEBUG)


@dataclasses.dataclass
class MagneticFieldAnnotation:
    path: str
    average_flux_density: Optional[float] = None


class ExtendedSqlFrameAnnotation(SqlFrameAnnotation):
    num_dogs: Mapped[Optional[int]] = mapped_column(default=None)

    magnetic_field: Mapped[MagneticFieldAnnotation] = composite(
        mapped_column("_magnetic_field_path", nullable=True),
        mapped_column("_magnetic_field_average_flux_density", nullable=True),
        default_factory=lambda: None,
    )


class ExtendedSqlIndexDataset(SqlIndexDataset):
    frame_annotations_type: ClassVar[
        Type[SqlFrameAnnotation]
    ] = ExtendedSqlFrameAnnotation


class CanineFrameData(FrameData):
    num_dogs: Optional[int] = None
    magnetic_field_average_flux_density: Optional[float] = None


@registry.register
class CanineFrameDataBuilder(
    GenericWorkaround, GenericFrameDataBuilder[CanineFrameData]
):
    """
    A concrete class to build an extended FrameData object
    """

    frame_data_type: ClassVar[Type[FrameData]] = CanineFrameData

    def build(
        self,
        frame_annotation: ExtendedSqlFrameAnnotation,
        sequence_annotation: types.SequenceAnnotation,
        load_blobs: bool = True,
    ) -> CanineFrameData:
        frame_data = super().build(frame_annotation, sequence_annotation, load_blobs)
        frame_data.num_dogs = frame_annotation.num_dogs or 101
        frame_data.magnetic_field_average_flux_density = (
            frame_annotation.magnetic_field.average_flux_density
        )
        return frame_data


class CanineSqlIndexDataset(SqlIndexDataset):
    frame_annotations_type: ClassVar[
        Type[SqlFrameAnnotation]
    ] = ExtendedSqlFrameAnnotation

    frame_data_builder_class_type: str = "CanineFrameDataBuilder"


class TestExtendingOrmTypes(unittest.TestCase):
    def setUp(self):
        # create a temporary copy of the DB with an extended schema
        engine = sa.create_engine(f"sqlite:///{METADATA_FILE}")
        with Session(engine) as session:
            extended_annots = [
                ExtendedSqlFrameAnnotation(
                    **{
                        k: v
                        for k, v in frame_annot.__dict__.items()
                        if not k.startswith("_")  # remove mapped fields and SA metadata
                    }
                )
                for frame_annot in session.scalars(sa.select(SqlFrameAnnotation))
            ]
            seq_annots = session.scalars(
                sa.select(SqlSequenceAnnotation),
                execution_options={"prebuffer_rows": True},
            )
            session.expunge_all()

        self._temp_db = tempfile.NamedTemporaryFile(delete=False)
        engine_ext = sa.create_engine(f"sqlite:///{self._temp_db.name}")
        ExtendedSqlFrameAnnotation.metadata.create_all(engine_ext, checkfirst=True)
        with Session(engine_ext, expire_on_commit=False) as session_ext:
            session_ext.add_all(extended_annots)
            for instance in seq_annots:
                session_ext.merge(instance)
            session_ext.commit()

        # check the setup is correct
        with engine_ext.connect() as connection_ext:
            df = pd.read_sql_query(
                sa.select(ExtendedSqlFrameAnnotation), connection_ext
            )
            self.assertEqual(len(df), 100)
            self.assertIn("_magnetic_field_average_flux_density", df.columns)

            df_seq = pd.read_sql_query(sa.select(SqlSequenceAnnotation), connection_ext)
            self.assertEqual(len(df_seq), 10)

    def tearDown(self):
        self._temp_db.close()
        os.remove(self._temp_db.name)

    def test_basic(self, sequence="cat1_seq2", frame_number=4):
        dataset = ExtendedSqlIndexDataset(
            sqlite_metadata_file=self._temp_db.name,
            remove_empty_masks=False,
            frame_data_builder_FrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

        # check the items are consecutive
        past_sequences = set()
        last_frame_number = -1
        last_sequence = ""
        for i in range(len(dataset)):
            item = dataset[i]

            if item.frame_number == 0:
                self.assertNotIn(item.sequence_name, past_sequences)
                past_sequences.add(item.sequence_name)
                last_sequence = item.sequence_name
            else:
                self.assertEqual(item.sequence_name, last_sequence)
                self.assertEqual(item.frame_number, last_frame_number + 1)

            last_frame_number = item.frame_number

        # test indexing
        with self.assertRaises(IndexError):
            dataset[len(dataset) + 1]

        # test sequence-frame indexing
        item = dataset[sequence, frame_number]
        self.assertEqual(item.sequence_name, sequence)
        self.assertEqual(item.frame_number, frame_number)

        with self.assertRaises(IndexError):
            dataset[sequence, 13]

    def test_extending_frame_data(self, sequence="cat1_seq2", frame_number=4):
        dataset = CanineSqlIndexDataset(
            sqlite_metadata_file=self._temp_db.name,
            remove_empty_masks=False,
            frame_data_builder_CanineFrameDataBuilder_args=NO_BLOBS_KWARGS,
        )

        self.assertEqual(len(dataset), 100)

        # check the items are consecutive
        past_sequences = set()
        last_frame_number = -1
        last_sequence = ""
        for i in range(len(dataset)):
            item = dataset[i]
            self.assertIsInstance(item, CanineFrameData)
            self.assertEqual(item.num_dogs, 101)
            self.assertIsNone(item.magnetic_field_average_flux_density)

            if item.frame_number == 0:
                self.assertNotIn(item.sequence_name, past_sequences)
                past_sequences.add(item.sequence_name)
                last_sequence = item.sequence_name
            else:
                self.assertEqual(item.sequence_name, last_sequence)
                self.assertEqual(item.frame_number, last_frame_number + 1)

            last_frame_number = item.frame_number

        # test indexing
        with self.assertRaises(IndexError):
            dataset[len(dataset) + 1]

        # test sequence-frame indexing
        item = dataset[sequence, frame_number]
        self.assertIsInstance(item, CanineFrameData)
        self.assertEqual(item.sequence_name, sequence)
        self.assertEqual(item.frame_number, frame_number)
        self.assertEqual(item.num_dogs, 101)

        with self.assertRaises(IndexError):
            dataset[sequence, 13]
							
								
								
									
37  tests/implicitron/test_orm_types.py  (new file)

@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import numpy as np

from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory


class TestOrmTypes(unittest.TestCase):
    def test_tuple_serialization_none(self):
        ttype = TupleTypeFactory()()
        output = ttype.process_bind_param(None, None)
        self.assertIsNone(output)
        output = ttype.process_result_value(output, None)
        self.assertIsNone(output)

    def test_tuple_serialization_1d(self):
        for input_tuple in [(1, 2, 3), (4.5, 6.7)]:
            ttype = TupleTypeFactory(type(input_tuple[0]), (len(input_tuple),))()
            output = ttype.process_bind_param(input_tuple, None)
            input_hat = ttype.process_result_value(output, None)
            self.assertEqual(type(input_hat[0]), type(input_tuple[0]))
            np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)

    def test_tuple_serialization_2d(self):
        input_tuple = ((1.0, 2.0, 3.0), (4.5, 5.5, 6.6))
        ttype = TupleTypeFactory(type(input_tuple[0][0]), (2, 3))()
        output = ttype.process_bind_param(input_tuple, None)
        input_hat = ttype.process_result_value(output, None)
        self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0]))
        # we use float32 to serialise
        np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)
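For context on what these round-trips exercise: a TupleTypeFactory-style type can be built on SQLAlchemy's TypeDecorator, packing tuples into float32 bytes on bind and restoring them on fetch. The sketch below is an illustrative assumption, not pytorch3d's actual implementation; the name tuple_type_sketch and the exact storage layout are hypothetical, though float32 storage is consistent with the decimal=6 tolerance in the test above:

# Minimal sketch (assumed, not the library's code) of a factory returning
# a SQLAlchemy type that round-trips fixed-shape tuples through a BLOB.
import numpy as np
import sqlalchemy as sa
from sqlalchemy.types import TypeDecorator


def tuple_type_sketch(dtype=float, shape=(3,)):
    class _TupleType(TypeDecorator):
        impl = sa.LargeBinary  # stored as raw bytes, e.g. in SQLite
        cache_ok = True

        def process_bind_param(self, value, dialect):
            if value is None:
                return None  # NULL passes through, as the None test expects
            # pack as float32; this is why the tests compare with decimal=6
            return np.asarray(value, dtype=np.float32).tobytes()

        def process_result_value(self, value, dialect):
            if value is None:
                return None
            arr = np.frombuffer(value, dtype=np.float32).reshape(shape)
            if len(shape) == 1:
                # restore the caller's element type (int or float)
                return tuple(dtype(x) for x in arr)
            return tuple(tuple(dtype(x) for x in row) for row in arr)

    return _TupleType


# usage, mirroring test_tuple_serialization_1d above:
ttype = tuple_type_sketch(float, (2,))()
blob = ttype.process_bind_param((4.5, 6.7), None)
np.testing.assert_almost_equal(
    ttype.process_result_value(blob, None), (4.5, 6.7), decimal=6
)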