mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 09:52:11 +08:00 
			
		
		
		
	Adding SQL Dataset related files to the build script
Summary: Now that we have SQLAlchemy 2.0, we can fully use them. Reviewed By: bottler Differential Revision: D66920096 fbshipit-source-id: 25c0ea1c4f7361e66348035519627dc961b9e6e6
This commit is contained in:
		
							parent
							
								
									055ab3a2e3
								
							
						
					
					
						commit
						64a5bfadc8
					
				@ -8,7 +8,8 @@ import hashlib
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
import os
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
import urllib
 | 
			
		||||
from dataclasses import dataclass, Field, field
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any,
 | 
			
		||||
    ClassVar,
 | 
			
		||||
@ -29,9 +30,9 @@ import sqlalchemy as sa
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.dataset.frame_data import (  # noqa
 | 
			
		||||
from pytorch3d.implicitron.dataset.frame_data import (
 | 
			
		||||
    FrameData,
 | 
			
		||||
    FrameDataBuilder,
 | 
			
		||||
    FrameDataBuilder,  # noqa
 | 
			
		||||
    FrameDataBuilderBase,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
@ -51,7 +52,7 @@ _SET_LISTS_TABLE: str = "set_lists"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
    """
 | 
			
		||||
    A dataset with annotations stored as SQLite tables. This is an index-based dataset.
 | 
			
		||||
    The length is returned after all sequence and frame filters are applied (see param
 | 
			
		||||
@ -125,9 +126,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
    seed: int = 0
 | 
			
		||||
    remove_empty_masks_poll_whole_table_threshold: int = 300_000
 | 
			
		||||
    # we set it manually in the constructor
 | 
			
		||||
    # _index: pd.DataFrame = field(init=False)
 | 
			
		||||
    _index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
 | 
			
		||||
    _sql_engine: sa.engine.Engine = field(
 | 
			
		||||
        init=False, metadata={"omegaconf_ignore": True}
 | 
			
		||||
    )
 | 
			
		||||
    eval_batches: Optional[List[Any]] = field(
 | 
			
		||||
        init=False, metadata={"omegaconf_ignore": True}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    frame_data_builder: FrameDataBuilderBase
 | 
			
		||||
    frame_data_builder: FrameDataBuilderBase  # pyre-ignore[13]
 | 
			
		||||
    frame_data_builder_class_type: str = "FrameDataBuilder"
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self) -> None:
 | 
			
		||||
@ -138,17 +145,23 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
            raise ValueError("sqlite_metadata_file must be set")
 | 
			
		||||
 | 
			
		||||
        if self.dataset_root:
 | 
			
		||||
            frame_builder_type = self.frame_data_builder_class_type
 | 
			
		||||
            getattr(self, f"frame_data_builder_{frame_builder_type}_args")[
 | 
			
		||||
                "dataset_root"
 | 
			
		||||
            ] = self.dataset_root
 | 
			
		||||
            frame_args = f"frame_data_builder_{self.frame_data_builder_class_type}_args"
 | 
			
		||||
            getattr(self, frame_args)["dataset_root"] = self.dataset_root
 | 
			
		||||
            getattr(self, frame_args)["path_manager"] = self.path_manager
 | 
			
		||||
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
        self.frame_data_builder.path_manager = self.path_manager
 | 
			
		||||
 | 
			
		||||
        # pyre-ignore  # NOTE: sqlite-specific args (read-only mode).
 | 
			
		||||
        if self.path_manager is not None:
 | 
			
		||||
            self.sqlite_metadata_file = self.path_manager.get_local_path(
 | 
			
		||||
                self.sqlite_metadata_file
 | 
			
		||||
            )
 | 
			
		||||
            self.subset_lists_file = self.path_manager.get_local_path(
 | 
			
		||||
                self.subset_lists_file
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # NOTE: sqlite-specific args (read-only mode).
 | 
			
		||||
        self._sql_engine = sa.create_engine(
 | 
			
		||||
            f"sqlite:///file:{self.sqlite_metadata_file}?mode=ro&uri=true"
 | 
			
		||||
            f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        sequences = self._get_filtered_sequences_if_any()
 | 
			
		||||
@ -166,16 +179,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
        if len(index) == 0:
 | 
			
		||||
            raise ValueError(f"There are no frames in the subsets: {self.subsets}!")
 | 
			
		||||
 | 
			
		||||
        self._index = index.set_index(["sequence_name", "frame_number"])  # pyre-ignore
 | 
			
		||||
        self._index = index.set_index(["sequence_name", "frame_number"])
 | 
			
		||||
 | 
			
		||||
        self.eval_batches = None  # pyre-ignore
 | 
			
		||||
        self.eval_batches = None
 | 
			
		||||
        if self.eval_batches_file:
 | 
			
		||||
            self.eval_batches = self._load_filter_eval_batches()
 | 
			
		||||
 | 
			
		||||
        logger.info(str(self))
 | 
			
		||||
 | 
			
		||||
    def __len__(self) -> int:
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        return len(self._index)
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, frame_idx: Union[int, Tuple[str, int]]) -> FrameData:
 | 
			
		||||
@ -250,7 +262,6 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
        return frame_data
 | 
			
		||||
 | 
			
		||||
    def __str__(self) -> str:
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        return f"SqlIndexDataset #frames={len(self._index)}"
 | 
			
		||||
 | 
			
		||||
    def sequence_names(self) -> Iterable[str]:
 | 
			
		||||
@ -335,12 +346,12 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
        rows = self._index.index.get_loc(seq_name)
 | 
			
		||||
        if isinstance(rows, slice):
 | 
			
		||||
            assert rows.stop is not None, "Unexpected result from pandas"
 | 
			
		||||
            rows = range(rows.start or 0, rows.stop, rows.step or 1)
 | 
			
		||||
            rows_seq = range(rows.start or 0, rows.stop, rows.step or 1)
 | 
			
		||||
        else:
 | 
			
		||||
            rows = np.where(rows)[0]
 | 
			
		||||
            rows_seq = list(np.where(rows)[0])
 | 
			
		||||
 | 
			
		||||
        index_slice, idx = self._get_frame_no_coalesced_ts_by_row_indices(
 | 
			
		||||
            rows, seq_name, subset_filter
 | 
			
		||||
            rows_seq, seq_name, subset_filter
 | 
			
		||||
        )
 | 
			
		||||
        index_slice["idx"] = idx
 | 
			
		||||
 | 
			
		||||
@ -461,14 +472,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
        return [SqlSequenceAnnotation.sequence_name.notin_(self.exclude_sequences)]
 | 
			
		||||
 | 
			
		||||
    def _load_subsets_from_json(self, subset_lists_path: str) -> pd.DataFrame:
 | 
			
		||||
        assert self.subsets is not None
 | 
			
		||||
        subsets = self.subsets
 | 
			
		||||
        assert subsets is not None
 | 
			
		||||
        with open(subset_lists_path, "r") as f:
 | 
			
		||||
            subset_to_seq_frame = json.load(f)
 | 
			
		||||
 | 
			
		||||
        seq_frame_list = sum(
 | 
			
		||||
            (
 | 
			
		||||
                [(*row, subset) for row in subset_to_seq_frame[subset]]
 | 
			
		||||
                for subset in self.subsets
 | 
			
		||||
                for subset in subsets
 | 
			
		||||
            ),
 | 
			
		||||
            [],
 | 
			
		||||
        )
 | 
			
		||||
@ -522,7 +534,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
                stmt = sa.select(
 | 
			
		||||
                    self.frame_annotations_type.sequence_name,
 | 
			
		||||
                    self.frame_annotations_type.frame_number,
 | 
			
		||||
                ).where(self.frame_annotations_type._mask_mass == 0)
 | 
			
		||||
                ).where(self.frame_annotations_type._mask_mass == 0)  # pyre-ignore[16]
 | 
			
		||||
                with Session(self._sql_engine) as session:
 | 
			
		||||
                    to_remove = session.execute(stmt).all()
 | 
			
		||||
 | 
			
		||||
@ -586,7 +598,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
        stmt = sa.select(
 | 
			
		||||
            self.frame_annotations_type.sequence_name,
 | 
			
		||||
            self.frame_annotations_type.frame_number,
 | 
			
		||||
            self.frame_annotations_type._image_path,
 | 
			
		||||
            self.frame_annotations_type._image_path,  # pyre-ignore[16]
 | 
			
		||||
            sa.null().label("subset"),
 | 
			
		||||
        )
 | 
			
		||||
        where_conditions = []
 | 
			
		||||
@ -600,7 +612,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
            logger.info("  excluding samples with empty masks")
 | 
			
		||||
            where_conditions.append(
 | 
			
		||||
                sa.or_(
 | 
			
		||||
                    self.frame_annotations_type._mask_mass.is_(None),
 | 
			
		||||
                    self.frame_annotations_type._mask_mass.is_(None),  # pyre-ignore[16]
 | 
			
		||||
                    self.frame_annotations_type._mask_mass != 0,
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
@ -634,7 +646,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
        assert self.eval_batches_file
 | 
			
		||||
        logger.info(f"Loading eval batches from {self.eval_batches_file}")
 | 
			
		||||
 | 
			
		||||
        if not os.path.isfile(self.eval_batches_file):
 | 
			
		||||
        if (
 | 
			
		||||
            self.path_manager and not self.path_manager.isfile(self.eval_batches_file)
 | 
			
		||||
        ) or (not self.path_manager and not os.path.isfile(self.eval_batches_file)):
 | 
			
		||||
            # The batch indices file does not exist.
 | 
			
		||||
            # Most probably the user has not specified the root folder.
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
@ -642,7 +656,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
                + "Please specify a correct dataset_root folder."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        with open(self.eval_batches_file, "r") as f:
 | 
			
		||||
        eval_batches_file = self._local_path(self.eval_batches_file)
 | 
			
		||||
        with open(eval_batches_file, "r") as f:
 | 
			
		||||
            eval_batches = json.load(f)
 | 
			
		||||
 | 
			
		||||
        # limit the dataset to sequences to allow multiple evaluations in one file
 | 
			
		||||
@ -758,11 +773,18 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
            prefixes=["TEMP"],  # NOTE SQLite specific!
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def pre_expand(cls) -> None:
 | 
			
		||||
        # remove dataclass annotations that are not meant to be init params
 | 
			
		||||
        # because they cause troubles for OmegaConf
 | 
			
		||||
        for attr, attr_value in list(cls.__dict__.items()):  # need to copy as we mutate
 | 
			
		||||
            if isinstance(attr_value, Field) and attr_value.metadata.get(
 | 
			
		||||
                "omegaconf_ignore", False
 | 
			
		||||
            ):
 | 
			
		||||
                delattr(cls, attr)
 | 
			
		||||
                del cls.__annotations__[attr]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _seq_name_to_seed(seq_name) -> int:
 | 
			
		||||
    """Generates numbers in [0, 2 ** 28)"""
 | 
			
		||||
    return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest()[:7], 16)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _safe_as_tensor(data, dtype):
 | 
			
		||||
    return torch.tensor(data, dtype=dtype) if data is not None else None
 | 
			
		||||
 | 
			
		||||
@ -43,7 +43,7 @@ logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class SqlIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
			
		||||
class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
 | 
			
		||||
    """
 | 
			
		||||
    Generates the training, validation, and testing dataset objects for
 | 
			
		||||
    a dataset laid out on disk like SQL-CO3D, with annotations in an SQLite data base.
 | 
			
		||||
@ -193,9 +193,9 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
			
		||||
 | 
			
		||||
    # this is a mould that is never constructed, used to build self._dataset_map values
 | 
			
		||||
    dataset_class_type: str = "SqlIndexDataset"
 | 
			
		||||
    dataset: SqlIndexDataset
 | 
			
		||||
    dataset: SqlIndexDataset  # pyre-ignore [13]
 | 
			
		||||
 | 
			
		||||
    path_manager_factory: PathManagerFactory
 | 
			
		||||
    path_manager_factory: PathManagerFactory  # pyre-ignore [13]
 | 
			
		||||
    path_manager_factory_class_type: str = "PathManagerFactory"
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user