Serialising dynamic arrays in SQL; read-only SQLite connection in SQL Dataset

Summary:
1. We may need to store arrays of unknown shape in the database. This commit implements and tests their serialisation.

2. Previously, when a non-existent metadata file was passed to SqlIndexDataset, it would try to open it, create an empty file, and then crash. We now open the file in read-only mode, so the error message is more intuitive. Note that the implementation is SQLite-specific.

Reviewed By: bottler

Differential Revision: D46047857

fbshipit-source-id: 3064ae4f8122b4fc24ad3d6ab696572ebe8d0c26
This commit is contained in:
Roman Shapovalov 2023-05-22 02:24:49 -07:00 committed by Facebook GitHub Bot
parent ff80183fdb
commit d2119c285f
3 changed files with 60 additions and 5 deletions

View File

@ -33,7 +33,35 @@ from sqlalchemy.types import TypeDecorator
# these produce policies to serialize structured types to blobs # these produce policies to serialize structured types to blobs
def ArrayTypeFactory(shape): def ArrayTypeFactory(shape=None):
if shape is None:
class VariableShapeNumpyArrayType(TypeDecorator):
    """Blob-backed column type for NumPy arrays whose shape is not fixed.

    The serialised blob is the concatenation of:
      1. the number of dimensions, as one int32 (4 bytes);
      2. the shape, as ``ndim`` int64 values (8 bytes each);
      3. the array elements, converted to float32.
    """

    impl = LargeBinary

    def process_bind_param(self, value, dialect):
        """Serialise an array to bytes; ``None`` is passed through unchanged.

        Args:
            value: array exposing ``ndim``, ``shape`` and ``astype``
                (e.g. ``np.ndarray``), or ``None``.
            dialect: unused here.

        Returns:
            The blob described in the class docstring, or ``None``.
        """
        if value is None:
            return None

        ndim_bytes = np.int32(value.ndim).tobytes()
        shape_bytes = np.array(value.shape, dtype=np.int64).tobytes()
        # elements are always stored as float32, whatever the input dtype
        value_bytes = value.astype(np.float32).tobytes()
        return ndim_bytes + shape_bytes + value_bytes

    def process_result_value(self, value, dialect):
        """Rebuild the float32 array from a blob; ``None`` is passed through.

        Args:
            value: blob produced by ``process_bind_param``, or ``None``.
            dialect: unused here.

        Returns:
            A float32 array with the stored shape, or ``None``.

        Raises:
            ValueError: if the number of stored shape entries does not match
                the dimension count declared in the header.
        """
        if value is None:
            return None

        # header: 4 bytes of int32 ndim followed by ndim int64 shape entries
        ndim = np.frombuffer(value[:4], dtype=np.int32)[0]
        value_start = 4 + 8 * ndim
        shape = np.frombuffer(value[4:value_start], dtype=np.int64)
        # explicit check instead of `assert`, which is stripped under -O
        if shape.shape != (ndim,):
            raise ValueError(
                f"Corrupted array blob: header declares {ndim} dimension(s) "
                f"but {shape.size} shape value(s) could be read."
            )

        return np.frombuffer(value[value_start:], dtype=np.float32).reshape(
            shape
        )
return VariableShapeNumpyArrayType
class NumpyArrayType(TypeDecorator): class NumpyArrayType(TypeDecorator):
impl = LargeBinary impl = LargeBinary
@ -158,4 +186,4 @@ class SqlSequenceAnnotation(Base):
mapped_column("_point_cloud_n_points", nullable=True), mapped_column("_point_cloud_n_points", nullable=True),
) )
# the bigger the better # the bigger the better
viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None) viewpoint_quality_score: Mapped[Optional[float]] = mapped_column()

View File

@ -142,8 +142,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
run_auto_creation(self) run_auto_creation(self)
self.frame_data_builder.path_manager = self.path_manager self.frame_data_builder.path_manager = self.path_manager
# pyre-ignore # pyre-ignore # NOTE: sqlite-specific args (read-only mode).
self._sql_engine = sa.create_engine(f"sqlite:///{self.sqlite_metadata_file}") self._sql_engine = sa.create_engine(
f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
)
sequences = self._get_filtered_sequences_if_any() sequences = self._get_filtered_sequences_if_any()

View File

@ -8,7 +8,7 @@ import unittest
import numpy as np import numpy as np
from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory from pytorch3d.implicitron.dataset.orm_types import ArrayTypeFactory, TupleTypeFactory
class TestOrmTypes(unittest.TestCase): class TestOrmTypes(unittest.TestCase):
@ -35,3 +35,28 @@ class TestOrmTypes(unittest.TestCase):
self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0])) self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0]))
# we use float32 to serialise # we use float32 to serialise
np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6) np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)
def test_array_serialization_none(self):
    """``None`` must pass through both the fixed- and dynamic-shape types."""
    # Cover the fixed-shape type as well as the variable-shape one, which
    # has its own None handling.
    for ttype in (ArrayTypeFactory((3, 3))(), ArrayTypeFactory()()):
        with self.subTest(ttype=type(ttype).__name__):
            output = ttype.process_bind_param(None, None)
            self.assertIsNone(output)
            output = ttype.process_result_value(output, None)
            self.assertIsNone(output)
def test_array_serialization(self):
    """Round-trip arrays through the dynamic- and fixed-shape column types."""
    samples = ([1, 2, 3], [[4.5, 6.7], [8.9, 10.0]])
    for sample in samples:
        expected = np.array(sample)
        column_types = (
            # first, dynamic-size array
            ArrayTypeFactory()(),
            # second, fixed-size array
            ArrayTypeFactory(tuple(expected.shape))(),
        )
        for column_type in column_types:
            blob = column_type.process_bind_param(expected, None)
            restored = column_type.process_result_value(blob, None)
            # serialisation always goes through float32
            self.assertEqual(restored.dtype, np.float32)
            np.testing.assert_almost_equal(restored, expected, decimal=6)