mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Serialising dynamic arrays in SQL; read-only SQLite connection in SQL Dataset
Summary: 1. We may need to store arrays of unknown shape in the database. This change implements and tests their serialisation. 2. Previously, when a nonexistent metadata file was passed to SqlIndexDataset, it would try to open it, create an empty file, and then crash. We now open the file in read-only mode, so the error message is more intuitive. Note that the implementation is SQLite-specific. Reviewed By: bottler Differential Revision: D46047857 fbshipit-source-id: 3064ae4f8122b4fc24ad3d6ab696572ebe8d0c26
This commit is contained in:
parent
ff80183fdb
commit
d2119c285f
@ -33,7 +33,35 @@ from sqlalchemy.types import TypeDecorator
|
|||||||
|
|
||||||
|
|
||||||
# these produce policies to serialize structured types to blobs
def ArrayTypeFactory(shape=None):
    """Create a SQLAlchemy TypeDecorator that (de)serialises NumPy arrays to blobs.

    Args:
        shape: fixed shape of the stored arrays, or None for arrays of
            unknown/dynamic shape. When None, the returned type prepends a
            header with ndim and shape to the blob so it can be decoded later.

    Returns:
        A TypeDecorator subclass suitable for use in mapped_column().
    """
    if shape is None:

        class VariableShapeNumpyArrayType(TypeDecorator):
            """Serialises arrays of arbitrary shape.

            Blob layout: [int32 ndim][int64 shape x ndim][float32 data].
            Values are always stored as float32, so round-tripping loses
            precision beyond ~7 significant digits.
            """

            impl = LargeBinary

            def process_bind_param(self, value, dialect):
                # Encode a NumPy array (or None) into the blob representation.
                if value is None:
                    return None

                ndim_bytes = np.int32(value.ndim).tobytes()
                shape_bytes = np.array(value.shape, dtype=np.int64).tobytes()
                value_bytes = value.astype(np.float32).tobytes()
                return ndim_bytes + shape_bytes + value_bytes

            def process_result_value(self, value, dialect):
                # Decode the blob: read ndim (4 bytes), then the shape
                # (8 bytes per dimension), then reshape the float32 payload.
                if value is None:
                    return None

                ndim = np.frombuffer(value[:4], dtype=np.int32)[0]
                value_start = 4 + 8 * ndim
                shape = np.frombuffer(value[4:value_start], dtype=np.int64)
                assert shape.shape == (ndim,)
                return np.frombuffer(value[value_start:], dtype=np.float32).reshape(
                    shape
                )

        return VariableShapeNumpyArrayType
||||||
class NumpyArrayType(TypeDecorator):
|
class NumpyArrayType(TypeDecorator):
|
||||||
impl = LargeBinary
|
impl = LargeBinary
|
||||||
|
|
||||||
@ -158,4 +186,4 @@ class SqlSequenceAnnotation(Base):
|
|||||||
mapped_column("_point_cloud_n_points", nullable=True),
|
mapped_column("_point_cloud_n_points", nullable=True),
|
||||||
)
|
)
|
||||||
# the bigger the better
|
# the bigger the better
|
||||||
viewpoint_quality_score: Mapped[Optional[float]] = mapped_column(default=None)
|
viewpoint_quality_score: Mapped[Optional[float]] = mapped_column()
|
||||||
|
@ -142,8 +142,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
|||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
self.frame_data_builder.path_manager = self.path_manager
|
self.frame_data_builder.path_manager = self.path_manager
|
||||||
|
|
||||||
# pyre-ignore
|
# pyre-ignore # NOTE: sqlite-specific args (read-only mode).
|
||||||
self._sql_engine = sa.create_engine(f"sqlite:///{self.sqlite_metadata_file}")
|
self._sql_engine = sa.create_engine(
|
||||||
|
f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
|
||||||
|
)
|
||||||
|
|
||||||
sequences = self._get_filtered_sequences_if_any()
|
sequences = self._get_filtered_sequences_if_any()
|
||||||
|
|
||||||
|
@ -8,7 +8,7 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory
|
from pytorch3d.implicitron.dataset.orm_types import ArrayTypeFactory, TupleTypeFactory
|
||||||
|
|
||||||
|
|
||||||
class TestOrmTypes(unittest.TestCase):
|
class TestOrmTypes(unittest.TestCase):
|
||||||
@ -35,3 +35,28 @@ class TestOrmTypes(unittest.TestCase):
|
|||||||
self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0]))
|
self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0]))
|
||||||
# we use float32 to serialise
|
# we use float32 to serialise
|
||||||
np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)
|
np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)
|
||||||
|
|
||||||
|
def test_array_serialization_none(self):
|
||||||
|
ttype = ArrayTypeFactory((3, 3))()
|
||||||
|
output = ttype.process_bind_param(None, None)
|
||||||
|
self.assertIsNone(output)
|
||||||
|
output = ttype.process_result_value(output, None)
|
||||||
|
self.assertIsNone(output)
|
||||||
|
|
||||||
|
def test_array_serialization(self):
|
||||||
|
for input_list in [[1, 2, 3], [[4.5, 6.7], [8.9, 10.0]]]:
|
||||||
|
input_array = np.array(input_list)
|
||||||
|
|
||||||
|
# first, dynamic-size array
|
||||||
|
ttype = ArrayTypeFactory()()
|
||||||
|
output = ttype.process_bind_param(input_array, None)
|
||||||
|
input_hat = ttype.process_result_value(output, None)
|
||||||
|
self.assertEqual(input_hat.dtype, np.float32)
|
||||||
|
np.testing.assert_almost_equal(input_hat, input_array, decimal=6)
|
||||||
|
|
||||||
|
# second, fixed-size array
|
||||||
|
ttype = ArrayTypeFactory(tuple(input_array.shape))()
|
||||||
|
output = ttype.process_bind_param(input_array, None)
|
||||||
|
input_hat = ttype.process_result_value(output, None)
|
||||||
|
self.assertEqual(input_hat.dtype, np.float32)
|
||||||
|
np.testing.assert_almost_equal(input_hat, input_array, decimal=6)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user