From d2119c285fa00438de168e17f8a453681dbed4d3 Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Mon, 22 May 2023 02:24:49 -0700 Subject: [PATCH] Serialising dynamic arrays in SQL; read-only SQLite connection in SQL Dataset Summary: 1. We may need to store arrays of unknown shape in the database. It implements and tests serialisation. 2. Previously, when a nonexistent metadata file was passed to SqlIndexDataset, it would try to open it and create an empty file, then crash. We now open the file in a read-only mode, so the error message is more intuitive. Note that the implementation is SQLite specific. Reviewed By: bottler Differential Revision: D46047857 fbshipit-source-id: 3064ae4f8122b4fc24ad3d6ab696572ebe8d0c26 --- pytorch3d/implicitron/dataset/orm_types.py | 32 ++++++++++++++++++-- pytorch3d/implicitron/dataset/sql_dataset.py | 6 ++-- tests/implicitron/test_orm_types.py | 27 ++++++++++++++++- 3 files changed, 60 insertions(+), 5 deletions(-) diff --git a/pytorch3d/implicitron/dataset/orm_types.py b/pytorch3d/implicitron/dataset/orm_types.py index 5736ab4b..2e916021 100644 --- a/pytorch3d/implicitron/dataset/orm_types.py +++ b/pytorch3d/implicitron/dataset/orm_types.py @@ -33,7 +33,35 @@ from sqlalchemy.types import TypeDecorator # these produce policies to serialize structured types to blobs -def ArrayTypeFactory(shape): +def ArrayTypeFactory(shape=None): + if shape is None: + + class VariableShapeNumpyArrayType(TypeDecorator): + impl = LargeBinary + + def process_bind_param(self, value, dialect): + if value is None: + return None + + ndim_bytes = np.int32(value.ndim).tobytes() + shape_bytes = np.array(value.shape, dtype=np.int64).tobytes() + value_bytes = value.astype(np.float32).tobytes() + return ndim_bytes + shape_bytes + value_bytes + + def process_result_value(self, value, dialect): + if value is None: + return None + + ndim = np.frombuffer(value[:4], dtype=np.int32)[0] + value_start = 4 + 8 * ndim + shape = np.frombuffer(value[4:value_start], 
dtype=np.int64) + assert shape.shape == (ndim,) + return np.frombuffer(value[value_start:], dtype=np.float32).reshape( + shape + ) + + return VariableShapeNumpyArrayType + class NumpyArrayType(TypeDecorator): impl = LargeBinary @@ -158,4 +186,4 @@ class SqlSequenceAnnotation(Base): mapped_column("_point_cloud_n_points", nullable=True), ) # the bigger the better - viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None) + viewpoint_quality_score: Mapped[Optional[float]] = mapped_column() diff --git a/pytorch3d/implicitron/dataset/sql_dataset.py b/pytorch3d/implicitron/dataset/sql_dataset.py index 4c9d3bb5..2c74e56c 100644 --- a/pytorch3d/implicitron/dataset/sql_dataset.py +++ b/pytorch3d/implicitron/dataset/sql_dataset.py @@ -142,8 +142,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore run_auto_creation(self) self.frame_data_builder.path_manager = self.path_manager - # pyre-ignore - self._sql_engine = sa.create_engine(f"sqlite:///{self.sqlite_metadata_file}") + # pyre-ignore # NOTE: sqlite-specific args (read-only mode). 
+ self._sql_engine = sa.create_engine( + f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true" + ) sequences = self._get_filtered_sequences_if_any() diff --git a/tests/implicitron/test_orm_types.py b/tests/implicitron/test_orm_types.py index 7570b002..e6f94c01 100644 --- a/tests/implicitron/test_orm_types.py +++ b/tests/implicitron/test_orm_types.py @@ -8,7 +8,7 @@ import unittest import numpy as np -from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory +from pytorch3d.implicitron.dataset.orm_types import ArrayTypeFactory, TupleTypeFactory class TestOrmTypes(unittest.TestCase): @@ -35,3 +35,28 @@ class TestOrmTypes(unittest.TestCase): self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0])) # we use float32 to serialise np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6) + + def test_array_serialization_none(self): + ttype = ArrayTypeFactory((3, 3))() + output = ttype.process_bind_param(None, None) + self.assertIsNone(output) + output = ttype.process_result_value(output, None) + self.assertIsNone(output) + + def test_array_serialization(self): + for input_list in [[1, 2, 3], [[4.5, 6.7], [8.9, 10.0]]]: + input_array = np.array(input_list) + + # first, dynamic-size array + ttype = ArrayTypeFactory()() + output = ttype.process_bind_param(input_array, None) + input_hat = ttype.process_result_value(output, None) + self.assertEqual(input_hat.dtype, np.float32) + np.testing.assert_almost_equal(input_hat, input_array, decimal=6) + + # second, fixed-size array + ttype = ArrayTypeFactory(tuple(input_array.shape))() + output = ttype.process_bind_param(input_array, None) + input_hat = ttype.process_result_value(output, None) + self.assertEqual(input_hat.dtype, np.float32) + np.testing.assert_almost_equal(input_hat, input_array, decimal=6)