Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-01 03:12:49 +08:00
Adding SQL Dataset related files to the build script
Summary: Now that we have SQLAlchemy 2.0, we can fully use these files.

Reviewed By: bottler

Differential Revision: D66920096

fbshipit-source-id: 25c0ea1c4f7361e66348035519627dc961b9e6e6
This commit is contained in:
parent 055ab3a2e3
commit 64a5bfadc8
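
The changes below lean on SQLAlchemy 2.0 idioms: sa.select(...) over individual columns, executed through a Session. A minimal, self-contained sketch of that pattern, assuming nothing but SQLAlchemy 2.0 itself (the FrameAnnotation model here is an illustrative stand-in, not the actual pytorch3d ORM class):

    from typing import Optional

    import sqlalchemy as sa
    from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


    class Base(DeclarativeBase):
        pass


    class FrameAnnotation(Base):  # stand-in model for the sketch
        __tablename__ = "frame_annots"
        sequence_name: Mapped[str] = mapped_column(primary_key=True)
        frame_number: Mapped[int] = mapped_column(primary_key=True)
        _mask_mass: Mapped[Optional[int]] = mapped_column(nullable=True)


    engine = sa.create_engine("sqlite://")  # throwaway in-memory DB for the sketch
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        session.add(FrameAnnotation(sequence_name="seq0", frame_number=1, _mask_mass=0))
        session.commit()

    # 2.0-style select of individual columns, as used throughout the diff
    stmt = sa.select(
        FrameAnnotation.sequence_name, FrameAnnotation.frame_number
    ).where(FrameAnnotation._mask_mass == 0)
    with Session(engine) as session:
        to_remove = session.execute(stmt).all()  # [('seq0', 1)]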
pytorch3d/implicitron/dataset/sql_dataset.py

@@ -8,7 +8,8 @@ import hashlib
 import json
 import logging
 import os
-from dataclasses import dataclass
+import urllib
+from dataclasses import dataclass, Field, field
 from typing import (
     Any,
     ClassVar,
@@ -29,9 +30,9 @@ import sqlalchemy as sa
 import torch
 from pytorch3d.implicitron.dataset.dataset_base import DatasetBase

-from pytorch3d.implicitron.dataset.frame_data import (  # noqa
+from pytorch3d.implicitron.dataset.frame_data import (
     FrameData,
-    FrameDataBuilder,
+    FrameDataBuilder,  # noqa
     FrameDataBuilderBase,
 )
 from pytorch3d.implicitron.tools.config import (
@@ -51,7 +52,7 @@ _SET_LISTS_TABLE: str = "set_lists"


 @registry.register
-class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
+class SqlIndexDataset(DatasetBase, ReplaceableBase):
     """
     A dataset with annotations stored as SQLite tables. This is an index-based dataset.
     The length is returned after all sequence and frame filters are applied (see param
@@ -125,9 +126,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
     seed: int = 0
     remove_empty_masks_poll_whole_table_threshold: int = 300_000
     # we set it manually in the constructor
-    # _index: pd.DataFrame = field(init=False)
+    _index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
+    _sql_engine: sa.engine.Engine = field(
+        init=False, metadata={"omegaconf_ignore": True}
+    )
+    eval_batches: Optional[List[Any]] = field(
+        init=False, metadata={"omegaconf_ignore": True}
+    )

-    frame_data_builder: FrameDataBuilderBase
+    frame_data_builder: FrameDataBuilderBase  # pyre-ignore[13]
     frame_data_builder_class_type: str = "FrameDataBuilder"

     def __post_init__(self) -> None:
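
The omegaconf_ignore marker above is plain dataclasses field metadata: dataclasses itself never interprets it, so a later step has to (the pre_expand hook added near the end of this file's diff). A small sketch of the tagging side, with made-up field names:

    from dataclasses import dataclass, field, fields


    @dataclass
    class Example:
        visible: int = 0  # ordinary config/init param
        # init=False: not a constructor argument; the metadata merely tags
        # the field so config machinery can be taught to skip it
        _hidden: int = field(init=False, default=0, metadata={"omegaconf_ignore": True})


    skipped = [f.name for f in fields(Example) if f.metadata.get("omegaconf_ignore", False)]
    print(skipped)  # ['_hidden']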
@@ -138,17 +145,23 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
             raise ValueError("sqlite_metadata_file must be set")

         if self.dataset_root:
-            frame_builder_type = self.frame_data_builder_class_type
-            getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
-                "dataset_root"
-            ] = self.dataset_root
+            frame_args = f"frame_data_builder_{self.frame_data_builder_class_type}_args"
+            getattr(self, frame_args)["dataset_root"] = self.dataset_root
+            getattr(self, frame_args)["path_manager"] = self.path_manager

         run_auto_creation(self)
-        self.frame_data_builder.path_manager = self.path_manager

-        # pyre-ignore # NOTE: sqlite-specific args (read-only mode).
+        if self.path_manager is not None:
+            self.sqlite_metadata_file = self.path_manager.get_local_path(
+                self.sqlite_metadata_file
+            )
+            self.subset_lists_file = self.path_manager.get_local_path(
+                self.subset_lists_file
+            )
+
+        # NOTE: sqlite-specific args (read-only mode).
         self._sql_engine = sa.create_engine(
-            f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
+            f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
         )

         sequences = self._get_filtered_sequences_if_any()
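
Percent-encoding the path in the create_engine change above matters because in the file: URI form, ? and # are reserved: an unquoted path containing either would be cut short when SQLite parses the URI. An illustration with a made-up path (note that the diff adds only `import urllib`; attribute access to `urllib.parse` then works only if some other module has already imported the submodule, so `import urllib.parse` would be the more explicit form):

    import urllib.parse

    path = "/data/co3d#v2/metadata.sqlite"  # hypothetical path with a reserved char
    print(f"sqlite:///file:{path}?mode=ro&uri=true")
    # '#' starts a URI fragment: the filename would be truncated to /data/co3d
    print(f"sqlite:///file:{urllib.parse.quote(path)}?mode=ro&uri=true")
    # sqlite:///file:/data/co3d%23v2/metadata.sqlite?mode=ro&uri=true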
@@ -166,16 +179,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
         if len(index) == 0:
             raise ValueError(f"There are no frames in the subsets: {self.subsets}!")

-        self._index = index.set_index(["sequence_name", "frame_number"])  # pyre-ignore
+        self._index = index.set_index(["sequence_name", "frame_number"])

-        self.eval_batches = None  # pyre-ignore
+        self.eval_batches = None
         if self.eval_batches_file:
             self.eval_batches = self._load_filter_eval_batches()

         logger.info(str(self))

     def __len__(self) -> int:
-        # pyre-ignore[16]
         return len(self._index)

     def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
@@ -250,7 +262,6 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
         return frame_data

     def __str__(self) -> str:
-        # pyre-ignore[16]
         return f"SqlIndexDataset #frames={len(self._index)}"

     def sequence_names(self) -> Iterable[str]:
@@ -335,12 +346,12 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
         rows = self._index.index.get_loc(seq_name)
         if isinstance(rows, slice):
             assert rows.stop is not None, "Unexpected result from pandas"
-            rows = range(rows.start or 0, rows.stop, rows.step or 1)
+            rows_seq = range(rows.start or 0, rows.stop, rows.step or 1)
         else:
-            rows = np.where(rows)[0]
+            rows_seq = list(np.where(rows)[0])

         index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
-            rows, seq_name, subset_filter
+            rows_seq, seq_name, subset_filter
         )
         index_slice["idx"] = idx

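
The rename to rows_seq above exists because pd.Index.get_loc on a non-unique index returns different types: a slice when the index is monotonic, a boolean mask otherwise. Normalizing both into one value under a fresh name keeps the types straight. A sketch:

    import numpy as np
    import pandas as pd

    mono = pd.Index(["a", "a", "b"])  # non-unique, monotonic
    print(mono.get_loc("a"))          # slice(0, 2, None)

    scattered = pd.Index(["a", "b", "a"])  # non-unique, non-monotonic
    loc = scattered.get_loc("a")
    print(loc)                        # array([ True, False,  True])

    # normalize both shapes into integer row positions, as the diff does
    rows_seq = (
        range(loc.start or 0, loc.stop, loc.step or 1)
        if isinstance(loc, slice)
        else list(np.where(loc)[0])   # [0, 2]
    )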
@@ -461,14 +472,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
         return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]

     def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
-        assert self.subsets is not None
+        subsets = self.subsets
+        assert subsets is not None
         with open(subset_lists_path, "r") as f:
             subset_to_seq_frame = json.load(f)

         seq_frame_list = sum(
             (
                 [(*row, subset) for row in subset_to_seq_frame[subset]]
-                for subset in self.subsets
+                for subset in subsets
             ),
             [],
         )
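
Copying self.subsets into a local before the assert, as above, is the usual trick for type checkers: they narrow a local to non-Optional after the assert, but are reluctant to trust repeated attribute access, which could in principle change between the check and the use. A minimal sketch:

    from typing import List, Optional


    class Loader:
        subsets: Optional[List[str]] = None

        def subset_names(self) -> List[str]:
            subsets = self.subsets      # bind once...
            assert subsets is not None  # ...so the narrowing sticks
            return subsets              # checker sees List[str], not Optional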
@@ -522,7 +534,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
         stmt = sa.select(
             self.frame_annotations_type.sequence_name,
             self.frame_annotations_type.frame_number,
-        ).where(self.frame_annotations_type._mask_mass == 0)
+        ).where(self.frame_annotations_type._mask_mass == 0)  # pyre-ignore[16]
         with Session(self._sql_engine) as session:
             to_remove = session.execute(stmt).all()

@@ -586,7 +598,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
         stmt = sa.select(
             self.frame_annotations_type.sequence_name,
             self.frame_annotations_type.frame_number,
-            self.frame_annotations_type._image_path,
+            self.frame_annotations_type._image_path,  # pyre-ignore[16]
             sa.null().label("subset"),
         )
         where_conditions = []
@@ -600,7 +612,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
             logger.info(" excluding samples with empty masks")
             where_conditions.append(
                 sa.or_(
-                    self.frame_annotations_type._mask_mass.is_(None),
+                    self.frame_annotations_type._mask_mass.is_(None),  # pyre-ignore[16]
                     self.frame_annotations_type._mask_mass != 0,
                 )
             )
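
The IS NULL arm of the sa.or_ above is load-bearing: in SQL, _mask_mass != 0 evaluates to NULL (not true) when _mask_mass is NULL, so rows with unknown mask mass would be silently filtered out without it. The emitted predicate, sketched with a bare column object:

    import sqlalchemy as sa

    mask_mass = sa.column("_mask_mass")
    cond = sa.or_(mask_mass.is_(None), mask_mass != 0)
    print(cond)  # _mask_mass IS NULL OR _mask_mass != :_mask_mass_1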
@@ -634,7 +646,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
         assert self.eval_batches_file
         logger.info(f"Loading eval batches from {self.eval_batches_file}")

-        if not os.path.isfile(self.eval_batches_file):
+        if (
+            self.path_manager and not self.path_manager.isfile(self.eval_batches_file)
+        ) or (not self.path_manager and not os.path.isfile(self.eval_batches_file)):
             # The batch indices file does not exist.
             # Most probably the user has not specified the root folder.
             raise ValueError(
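
The doubled condition above reads as "the file is missing under whichever file system applies". If the pattern recurred, it could be factored into a helper along these lines (a sketch, not part of the diff; it assumes the iopath-style PathManager.isfile):

    import os
    from typing import Optional

    from iopath.common.file_io import PathManager


    def file_exists(path: str, path_manager: Optional[PathManager]) -> bool:
        # route through the PathManager when one is configured (it may
        # resolve non-local URIs), else fall back to the local filesystem
        if path_manager is not None:
            return path_manager.isfile(path)
        return os.path.isfile(path)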
@@ -642,7 +656,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
                 + "Please specify a correct dataset_root folder."
             )

-        with open(self.eval_batches_file, "r") as f:
+        eval_batches_file = self._local_path(self.eval_batches_file)
+        with open(eval_batches_file, "r") as f:
             eval_batches = json.load(f)

         # limit the dataset to sequences to allow multiple evaluations in one file
@@ -758,11 +773,18 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
             prefixes=["TEMP"],  # NOTE SQLite specific!
         )

+    @classmethod
+    def pre_expand(cls) -> None:
+        # remove dataclass annotations that are not meant to be init params
+        # because they cause troubles for OmegaConf
+        for attr, attr_value in list(cls.__dict__.items()):  # need to copy as we mutate
+            if isinstance(attr_value, Field) and attr_value.metadata.get(
+                "omegaconf_ignore", False
+            ):
+                delattr(cls, attr)
+                del cls.__annotations__[attr]
+

 def _seq_name_to_seed(seq_name) -> int:
     """Generates numbers in [0, 2 ** 28)"""
     return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
-
-
-def _safe_as_tensor(data, dtype):
-    return torch.tensor(data, dtype=dtype) if data is not None else None
pytorch3d/implicitron/dataset/sql_dataset_provider.py

@@ -43,7 +43,7 @@ logger = logging.getLogger(__name__)


 @registry.register
-class SqlIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
+class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
     """
     Generates the training, validation, and testing dataset objects for
     a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
@@ -193,9 +193,9 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]

     # this is a mould that is never constructed, used to build self._dataset_map values
     dataset_class_type: str = "SqlIndexDataset"
-    dataset: SqlIndexDataset
+    dataset: SqlIndexDataset  # pyre-ignore [13]

-    path_manager_factory: PathManagerFactory
+    path_manager_factory: PathManagerFactory  # pyre-ignore [13]
     path_manager_factory_class_type: str = "PathManagerFactory"

     def __post_init__(self):