mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 06:10:34 +08:00
Serialising dynamic arrays in SQL; read-only SQLite connection in SQL Dataset
Summary: 1. We may need to store arrays of unknown shape in the database. It implements and tests serialisation. 2. Previously, when an inexisting metadata file was passed to SqlIndexDataset, it would try to open it and create an empty file, then crash. We now open the file in a read-only mode, so the error message is more intuitive. Note that the implementation is SQLite specific. Reviewed By: bottler Differential Revision: D46047857 fbshipit-source-id: 3064ae4f8122b4fc24ad3d6ab696572ebe8d0c26
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ff80183fdb
commit
d2119c285f
@@ -8,7 +8,7 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pytorch3d.implicitron.dataset.orm_types import TupleTypeFactory
|
||||
from pytorch3d.implicitron.dataset.orm_types import ArrayTypeFactory, TupleTypeFactory
|
||||
|
||||
|
||||
class TestOrmTypes(unittest.TestCase):
|
||||
@@ -35,3 +35,28 @@ class TestOrmTypes(unittest.TestCase):
|
||||
self.assertEqual(type(input_hat[0][0]), type(input_tuple[0][0]))
|
||||
# we use float32 to serialise
|
||||
np.testing.assert_almost_equal(input_hat, input_tuple, decimal=6)
|
||||
|
||||
def test_array_serialization_none(self):
|
||||
ttype = ArrayTypeFactory((3, 3))()
|
||||
output = ttype.process_bind_param(None, None)
|
||||
self.assertIsNone(output)
|
||||
output = ttype.process_result_value(output, None)
|
||||
self.assertIsNone(output)
|
||||
|
||||
def test_array_serialization(self):
|
||||
for input_list in [[1, 2, 3], [[4.5, 6.7], [8.9, 10.0]]]:
|
||||
input_array = np.array(input_list)
|
||||
|
||||
# first, dynamic-size array
|
||||
ttype = ArrayTypeFactory()()
|
||||
output = ttype.process_bind_param(input_array, None)
|
||||
input_hat = ttype.process_result_value(output, None)
|
||||
self.assertEqual(input_hat.dtype, np.float32)
|
||||
np.testing.assert_almost_equal(input_hat, input_array, decimal=6)
|
||||
|
||||
# second, fixed-size array
|
||||
ttype = ArrayTypeFactory(tuple(input_array.shape))()
|
||||
output = ttype.process_bind_param(input_array, None)
|
||||
input_hat = ttype.process_result_value(output, None)
|
||||
self.assertEqual(input_hat.dtype, np.float32)
|
||||
np.testing.assert_almost_equal(input_hat, input_array, decimal=6)
|
||||
|
||||
Reference in New Issue
Block a user