mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-11 23:06:04 +08:00
Updates to Implicitron dataset, metrics and tools
Summary: Update PyTorch3D to be able to run assetgen (see later diffs in the stack).

Reviewed By: shapovalov

Differential Revision: D65942513

fbshipit-source-id: 1d01141c9f7e106608fa591be6e0d3262cb5944f
This commit is contained in:
committed by
Facebook GitHub Bot
parent
42a4a7d432
commit
43cd681d4f
@@ -10,6 +10,7 @@ import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
import urllib
|
||||
from dataclasses import dataclass, Field, field
|
||||
from typing import (
|
||||
@@ -37,12 +38,13 @@ from pytorch3d.implicitron.dataset.frame_data import (
|
||||
FrameDataBuilder, # noqa
|
||||
FrameDataBuilderBase,
|
||||
)
|
||||
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
registry,
|
||||
ReplaceableBase,
|
||||
run_auto_creation,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import scoped_session, Session, sessionmaker
|
||||
|
||||
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
|
||||
|
||||
@@ -91,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
engine verbatim. Don’t expose it to end users of your application!
|
||||
pick_categories: Restrict the dataset to the given list of categories.
|
||||
pick_sequences: A Sequence of sequence names to restrict the dataset to.
|
||||
pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
|
||||
exclude_sequences: A Sequence of the names of the sequences to exclude.
|
||||
limit_sequences_per_category_to: Limit the dataset to the first up to N
|
||||
sequences within each category (applies after all other sequence filters
|
||||
@@ -105,6 +108,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
more frames than that; applied after other frame-level filters.
|
||||
seed: The seed of the random generator sampling `n_frames_per_sequence`
|
||||
random frames per sequence.
|
||||
preload_metadata: If True, the metadata is preloaded into memory.
|
||||
precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
|
||||
scoped_session: If True, allows different parts of the code to share
|
||||
a global session to access the database.
|
||||
"""
|
||||
|
||||
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
|
||||
@@ -123,6 +130,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
pick_categories: Tuple[str, ...] = ()
|
||||
|
||||
pick_sequences: Tuple[str, ...] = ()
|
||||
pick_sequences_sql_clause: Optional[str] = None
|
||||
exclude_sequences: Tuple[str, ...] = ()
|
||||
limit_sequences_per_category_to: int = 0
|
||||
limit_sequences_to: int = 0
|
||||
@@ -130,6 +138,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
n_frames_per_sequence: int = -1
|
||||
seed: int = 0
|
||||
remove_empty_masks_poll_whole_table_threshold: int = 300_000
|
||||
preload_metadata: bool = False
|
||||
precompute_seq_to_idx: bool = False
|
||||
# we set it manually in the constructor
|
||||
_index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
|
||||
_sql_engine: sa.engine.Engine = field(
|
||||
@@ -142,6 +152,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
frame_data_builder: FrameDataBuilderBase # pyre-ignore[13]
|
||||
frame_data_builder_class_type: str = "FrameDataBuilder"
|
||||
|
||||
scoped_session: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if sa.__version__ < "2.0":
|
||||
raise ImportError("This class requires SQL Alchemy 2.0 or later")
|
||||
@@ -169,6 +181,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
|
||||
)
|
||||
|
||||
if self.preload_metadata:
|
||||
self._sql_engine = self._preload_database(self._sql_engine)
|
||||
|
||||
sequences = self._get_filtered_sequences_if_any()
|
||||
|
||||
if self.subsets:
|
||||
@@ -192,6 +207,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
|
||||
logger.info(str(self))
|
||||
|
||||
if self.scoped_session:
|
||||
self._session_factory = sessionmaker(bind=self._sql_engine) # pyre-ignore
|
||||
|
||||
if self.precompute_seq_to_idx:
|
||||
# This is deprecated and will be removed in the future.
|
||||
# After we backport https://github.com/facebookresearch/uco3d/pull/3
|
||||
logger.warning(
|
||||
"Using precompute_seq_to_idx is deprecated and will be removed in the future."
|
||||
)
|
||||
self._index["rowid"] = np.arange(len(self._index))
|
||||
groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
|
||||
self._seq_to_indices = dict(groupby.apply(list)) # pyre-ignore
|
||||
del self._index["rowid"]
|
||||
|
||||
def __len__(self) -> int:
    """Return the number of entries in the frame index (`self._index`)."""
    frame_index = self._index
    return len(frame_index)
|
||||
|
||||
@@ -252,9 +281,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
seq_stmt = sa.select(self.sequence_annotations_type).where(
|
||||
self.sequence_annotations_type.sequence_name == seq
|
||||
)
|
||||
with Session(self._sql_engine) as session:
|
||||
entry = session.scalars(stmt).one()
|
||||
seq_metadata = session.scalars(seq_stmt).one()
|
||||
if self.scoped_session:
|
||||
# pyre-ignore
|
||||
with scoped_session(self._session_factory)() as session:
|
||||
entry = session.scalars(stmt).one()
|
||||
seq_metadata = session.scalars(seq_stmt).one()
|
||||
else:
|
||||
with Session(self._sql_engine) as session:
|
||||
entry = session.scalars(stmt).one()
|
||||
seq_metadata = session.scalars(seq_stmt).one()
|
||||
|
||||
assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
|
||||
|
||||
@@ -363,6 +398,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
|
||||
yield from index_slice.itertuples(index=False)
|
||||
|
||||
# override
|
||||
def sequence_indices_in_order(
    self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
) -> Iterator[int]:
    """Iterate over the dataset indices of the frames of one sequence.

    Counterpart of `sequence_frames_in_order` that yields only the dataset
    index of each frame rather than the full per-frame tuple.

    Args:
        seq_name: Name of the sequence whose frame indices are requested.
        subset_filter: Optional subset names, forwarded unchanged to
            `sequence_frames_in_order`.

    Yields:
        The dataset index of each frame of `seq_name`.
    """
    use_lookup = self.precompute_seq_to_idx and subset_filter is None
    if not use_lookup:
        # General path: delegate to the frame iterator and keep only the
        # third element (the dataset index) of every yielded triple.
        frames = self.sequence_frames_in_order(seq_name, subset_filter)
        yield from (dataset_idx for _, _, dataset_idx in frames)
        return

    # Fast path: the sequence-name -> indices mapping was built up front and
    # is only used when no subset filter is requested.
    # pyre-ignore
    yield from self._seq_to_indices[seq_name]
|
||||
|
||||
# override
|
||||
def get_eval_batches(self) -> Optional[List[Any]]:
|
||||
"""
|
||||
@@ -396,11 +445,35 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
or self.limit_sequences_to > 0
|
||||
or self.limit_sequences_per_category_to > 0
|
||||
or len(self.pick_sequences) > 0
|
||||
or self.pick_sequences_sql_clause is not None
|
||||
or len(self.exclude_sequences) > 0
|
||||
or len(self.pick_categories) > 0
|
||||
or self.n_frames_per_sequence > 0
|
||||
)
|
||||
|
||||
def _preload_database(
    self, source_engine: sa.engine.base.Engine
) -> sa.engine.base.Engine:
    """Copy the database behind `source_engine` into an in-memory SQLite engine.

    The schema is reflected from the source, re-created in a fresh
    `sqlite:///:memory:` database, and every reflected table is then copied
    row by row.

    Args:
        source_engine: Engine connected to the database to be copied.

    Returns:
        A new engine bound to the in-memory copy.
    """
    # NOTE(review): an in-memory SQLite database is private to its
    # connection; confirm the engine's pooling keeps this copy reachable
    # from every thread/process that later queries the dataset.
    memory_engine = sa.create_engine("sqlite:///:memory:")

    # Reflect the source schema and replicate it in the destination.
    schema = sa.MetaData()
    schema.reflect(bind=source_engine)
    schema.create_all(bind=memory_engine)

    with source_engine.connect() as src, memory_engine.connect() as dst:
        for table in schema.tables.values():
            # Read every row of the source table and insert it, one at a
            # time, into the matching destination table.
            for record in src.execute(table.select()):
                dst.execute(table.insert().values(record))

            # Commit the changes for each table
            dst.commit()

    return memory_engine
|
||||
|
||||
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
|
||||
# maximum possible filter (if limit_sequences_per_category_to == 0):
|
||||
# WHERE category IN 'self.pick_categories'
|
||||
@@ -413,6 +486,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
*self._get_pick_filters(),
|
||||
*self._get_exclude_filters(),
|
||||
]
|
||||
if self.pick_sequences_sql_clause:
|
||||
print("Applying the custom SQL clause.")
|
||||
where_conditions.append(sa.text(self.pick_sequences_sql_clause))
|
||||
|
||||
def add_where(stmt):
|
||||
return stmt.where(*where_conditions) if where_conditions else stmt
|
||||
@@ -749,9 +825,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
|
||||
self.frame_annotations_type.sequence_name == seq_name,
|
||||
self.frame_annotations_type.frame_number.in_(frames),
|
||||
)
|
||||
frame_no_ts = None
|
||||
|
||||
with self._sql_engine.connect() as connection:
|
||||
frame_no_ts = pd.read_sql_query(stmt, connection)
|
||||
if self.scoped_session:
|
||||
stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
|
||||
with scoped_session(self._session_factory)() as session: # pyre-ignore
|
||||
frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
|
||||
else:
|
||||
with self._sql_engine.connect() as connection:
|
||||
frame_no_ts = pd.read_sql_query(stmt, connection)
|
||||
|
||||
if len(frame_no_ts) != len(index_slice):
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user