Serialising dynamic arrays in SQL; read-only SQLite connection in SQL Dataset

Summary:
1. We may need to store arrays of unknown shape in the database. It implements and tests serialisation.

2. Previously, when an inexisting metadata file was passed to SqlIndexDataset, it would try to open it and create an empty file, then crash. We now open the file in a read-only mode, so the error message is more intuitive. Note that the implementation is SQLite specific.

Reviewed By: bottler

Differential Revision: D46047857

fbshipit-source-id: 3064ae4f8122b4fc24ad3d6ab696572ebe8d0c26
This commit is contained in:
Roman Shapovalov
2023-05-22 02:24:49 -07:00
committed by Facebook GitHub Bot
parent ff80183fdb
commit d2119c285f
3 changed files with 60 additions and 5 deletions

View File

@@ -8,7 +8,7 @@ import unittest
import numpy as np
from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory
from pytorch3d.implicitron.dataset.orm_types import ArrayTypeFactory, TupleTypeFactory
class TestOrmTypes(unittest.TestCase):
@@ -35,3 +35,28 @@ class TestOrmTypes(unittest.TestCase):
self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0]))
# we use float32 to serialise
np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)
def test_array_serialization_none(self):
ttype = ArrayTypeFactory((3, 3))()
output = ttype.process_bind_param(None, None)
self.assertIsNone(output)
output = ttype.process_result_value(output, None)
self.assertIsNone(output)
def test_array_serialization(self):
for input_list in [[1, 2, 3], [[4.5, 6.7], [8.9, 10.0]]]:
input_array = np.array(input_list)
# first, dynamic-size array
ttype = ArrayTypeFactory()()
output = ttype.process_bind_param(input_array, None)
input_hat = ttype.process_result_value(output, None)
self.assertEqual(input_hat.dtype, np.float32)
np.testing.assert_almost_equal(input_hat, input_array, decimal=6)
# second, fixed-size array
ttype = ArrayTypeFactory(tuple(input_array.shape))()
output = ttype.process_bind_param(input_array, None)
input_hat = ttype.process_result_value(output, None)
self.assertEqual(input_hat.dtype, np.float32)
np.testing.assert_almost_equal(input_hat, input_array, decimal=6)