From 43cd681d4fc55fefa9a47e70d3ca0a02818ce44a Mon Sep 17 00:00:00 2001 From: Antoine Toisoul Date: Mon, 27 Jan 2025 09:43:42 -0800 Subject: [PATCH] Updates to Implicitron dataset, metrics and tools Summary: Update Pytorch3D to be able to run assetgen (see later diffs in the stack) Reviewed By: shapovalov Differential Revision: D65942513 fbshipit-source-id: 1d01141c9f7e106608fa591be6e0d3262cb5944f --- pytorch3d/implicitron/dataset/frame_data.py | 9 +- pytorch3d/implicitron/dataset/sql_dataset.py | 94 +++++++++++++++++-- .../dataset/sql_dataset_provider.py | 10 +- pytorch3d/implicitron/dataset/utils.py | 19 +++- pytorch3d/implicitron/models/metrics.py | 42 +++++++-- pytorch3d/implicitron/tools/metric_utils.py | 29 +++++- pytorch3d/implicitron/tools/video_writer.py | 94 +++++++++++++------ 7 files changed, 240 insertions(+), 57 deletions(-) diff --git a/pytorch3d/implicitron/dataset/frame_data.py b/pytorch3d/implicitron/dataset/frame_data.py index 137b6324..ed88c0f8 100644 --- a/pytorch3d/implicitron/dataset/frame_data.py +++ b/pytorch3d/implicitron/dataset/frame_data.py @@ -48,6 +48,7 @@ from pytorch3d.implicitron.dataset.utils import ( from pytorch3d.implicitron.tools.config import registry, ReplaceableBase from pytorch3d.renderer.camera_utils import join_cameras_as_batch from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras +from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation @@ -158,7 +159,7 @@ class FrameData(Mapping[str, Any]): new_params = {} for field_name in iter(self): value = getattr(self, field_name) - if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)): + if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)): new_params[field_name] = value.to(*args, **kwargs) else: new_params[field_name] = value @@ -420,7 +421,6 @@ class FrameData(Mapping[str, Any]): for f in fields(elem): if not f.init: continue - list_values = override_fields.get( f.name, [getattr(d, f.name) for d in batch] ) @@ -429,7 +429,7 @@ class FrameData(Mapping[str, Any]): if all(list_value is not None for list_value in list_values) else None ) - return cls(**collated) + return type(elem)(**collated) elif isinstance(elem, Pointclouds): return join_pointclouds_as_batch(batch) @@ -437,6 +437,8 @@ class FrameData(Mapping[str, Any]): elif isinstance(elem, CamerasBase): # TODO: don't store K; enforce working in NDC space return join_cameras_as_batch(batch) + elif isinstance(elem, Meshes): + return join_meshes_as_batch(batch) else: return torch.utils.data.dataloader.default_collate(batch) @@ -592,6 +594,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): fg_mask_np: np.ndarray | None = None bbox_xywh: tuple[float, float, float, float] | None = None mask_annotation = frame_annotation.mask + if mask_annotation is not None: if load_blobs and self.load_masks: fg_mask_np, mask_path = self._load_fg_probability(frame_annotation) diff --git a/pytorch3d/implicitron/dataset/sql_dataset.py b/pytorch3d/implicitron/dataset/sql_dataset.py index ea42de43..4062cfc4 100644 --- a/pytorch3d/implicitron/dataset/sql_dataset.py +++ b/pytorch3d/implicitron/dataset/sql_dataset.py @@ -10,6 +10,7 @@ import hashlib import json import logging import os + import urllib from dataclasses import dataclass, Field, field from typing import ( @@ -37,12 +38,13 @@ from pytorch3d.implicitron.dataset.frame_data import ( 
FrameDataBuilder, # noqa FrameDataBuilderBase, ) + from pytorch3d.implicitron.tools.config import ( registry, ReplaceableBase, run_auto_creation, ) -from sqlalchemy.orm import Session +from sqlalchemy.orm import scoped_session, Session, sessionmaker from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation @@ -91,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): engine verbatim. Don’t expose it to end users of your application! pick_categories: Restrict the dataset to the given list of categories. pick_sequences: A Sequence of sequence names to restrict the dataset to. + pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations. exclude_sequences: A Sequence of the names of the sequences to exclude. limit_sequences_per_category_to: Limit the dataset to the first up to N sequences within each category (applies after all other sequence filters @@ -105,6 +108,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): more frames than that; applied after other frame-level filters. seed: The seed of the random generator sampling `n_frames_per_sequence` random frames per sequence. + preload_metadata: If True, the metadata is preloaded into memory. + precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices. + scoped_session: If True, allows different parts of the code to share + a global session to access the database. """ frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation @@ -123,6 +130,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): pick_categories: Tuple[str, ...] = () pick_sequences: Tuple[str, ...] = () + pick_sequences_sql_clause: Optional[str] = None exclude_sequences: Tuple[str, ...] = () limit_sequences_per_category_to: int = 0 limit_sequences_to: int = 0 @@ -130,6 +138,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): n_frames_per_sequence: int = -1 seed: int = 0 remove_empty_masks_poll_whole_table_threshold: int = 300_000 + preload_metadata: bool = False + precompute_seq_to_idx: bool = False # we set it manually in the constructor _index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True}) _sql_engine: sa.engine.Engine = field( @@ -142,6 +152,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): frame_data_builder: FrameDataBuilderBase # pyre-ignore[13] frame_data_builder_class_type: str = "FrameDataBuilder" + scoped_session: bool = False + def __post_init__(self) -> None: if sa.__version__ < "2.0": raise ImportError("This class requires SQL Alchemy 2.0 or later") @@ -169,6 +181,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true" ) + if self.preload_metadata: + self._sql_engine = self._preload_database(self._sql_engine) + sequences = self._get_filtered_sequences_if_any() if self.subsets: @@ -192,6 +207,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): logger.info(str(self)) + if self.scoped_session: + self._session_factory = sessionmaker(bind=self._sql_engine) # pyre-ignore + + if self.precompute_seq_to_idx: + # This is deprecated and will be removed in the future. + # After we backport https://github.com/facebookresearch/uco3d/pull/3 + logger.warning( + "Using precompute_seq_to_idx is deprecated and will be removed in the future." 
+ ) + self._index["rowid"] = np.arange(len(self._index)) + groupby = self._index.groupby("sequence_name", sort=False)["rowid"] + self._seq_to_indices = dict(groupby.apply(list)) # pyre-ignore + del self._index["rowid"] + def __len__(self) -> int: return len(self._index) @@ -252,9 +281,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): seq_stmt = sa.select(self.sequence_annotations_type).where( self.sequence_annotations_type.sequence_name == seq ) - with Session(self._sql_engine) as session: - entry = session.scalars(stmt).one() - seq_metadata = session.scalars(seq_stmt).one() + if self.scoped_session: + # pyre-ignore + with scoped_session(self._session_factory)() as session: + entry = session.scalars(stmt).one() + seq_metadata = session.scalars(seq_stmt).one() + else: + with Session(self._sql_engine) as session: + entry = session.scalars(stmt).one() + seq_metadata = session.scalars(seq_stmt).one() assert entry.image.path == self._index.loc[(seq, frame), "_image_path"] @@ -363,6 +398,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): yield from index_slice.itertuples(index=False) + # override + def sequence_indices_in_order( + self, seq_name: str, subset_filter: Optional[Sequence[str]] = None + ) -> Iterator[int]: + """Same as `sequence_frames_in_order` but returns the iterator over + only dataset indices. + """ + if self.precompute_seq_to_idx and subset_filter is None: + # pyre-ignore + yield from self._seq_to_indices[seq_name] + else: + for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter): + yield idx + # override def get_eval_batches(self) -> Optional[List[Any]]: """ @@ -396,11 +445,35 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): or self.limit_sequences_to > 0 or self.limit_sequences_per_category_to > 0 or len(self.pick_sequences) > 0 + or self.pick_sequences_sql_clause is not None or len(self.exclude_sequences) > 0 or len(self.pick_categories) > 0 or self.n_frames_per_sequence > 0 ) + def _preload_database( + self, source_engine: sa.engine.base.Engine + ) -> sa.engine.base.Engine: + destination_engine = sa.create_engine("sqlite:///:memory:") + metadata = sa.MetaData() + metadata.reflect(bind=source_engine) + metadata.create_all(bind=destination_engine) + + with source_engine.connect() as source_conn: + with destination_engine.connect() as destination_conn: + for table_obj in metadata.tables.values(): + # Select all rows from the source table + source_rows = source_conn.execute(table_obj.select()) + + # Insert rows into the destination table + for row in source_rows: + destination_conn.execute(table_obj.insert().values(row)) + + # Commit the changes for each table + destination_conn.commit() + + return destination_engine + def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]: # maximum possible filter (if limit_sequences_per_category_to == 0): # WHERE category IN 'self.pick_categories' @@ -413,6 +486,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): *self._get_pick_filters(), *self._get_exclude_filters(), ] + if self.pick_sequences_sql_clause: + print("Applying the custom SQL clause.") + where_conditions.append(sa.text(self.pick_sequences_sql_clause)) def add_where(stmt): return stmt.where(*where_conditions) if where_conditions else stmt @@ -749,9 +825,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): self.frame_annotations_type.sequence_name == seq_name, self.frame_annotations_type.frame_number.in_(frames), ) + frame_no_ts = None - with self._sql_engine.connect() as connection: - frame_no_ts = 
pd.read_sql_query(stmt, connection) + if self.scoped_session: + stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True})) + with scoped_session(self._session_factory)() as session: # pyre-ignore + frame_no_ts = pd.read_sql_query(stmt_text, session.connection()) + else: + with self._sql_engine.connect() as connection: + frame_no_ts = pd.read_sql_query(stmt, connection) if len(frame_no_ts) != len(index_slice): raise ValueError( diff --git a/pytorch3d/implicitron/dataset/sql_dataset_provider.py b/pytorch3d/implicitron/dataset/sql_dataset_provider.py index a03616cb..abc9a9c9 100644 --- a/pytorch3d/implicitron/dataset/sql_dataset_provider.py +++ b/pytorch3d/implicitron/dataset/sql_dataset_provider.py @@ -284,8 +284,14 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase): logger.info(f"Val dataset: {str(val_dataset)}") logger.debug("Extracting test dataset.") - eval_batches_file = self._get_lists_file("eval_batches") - del common_dataset_kwargs["eval_batches_file"] + if self.eval_batches_path is None: + eval_batches_file = None + else: + eval_batches_file = self._get_lists_file("eval_batches") + + if "eval_batches_file" in common_dataset_kwargs: + common_dataset_kwargs.pop("eval_batches_file", None) + test_dataset = dataset_type( **common_dataset_kwargs, subsets=self._get_subsets(self.test_subsets, True), diff --git a/pytorch3d/implicitron/dataset/utils.py b/pytorch3d/implicitron/dataset/utils.py index 3e8ba35b..be70973b 100644 --- a/pytorch3d/implicitron/dataset/utils.py +++ b/pytorch3d/implicitron/dataset/utils.py @@ -211,14 +211,21 @@ def resize_image( if isinstance(image, np.ndarray): image = torch.from_numpy(image) - if image_height is None or image_width is None: + if ( + image_height is None + or image_width is None + or image.shape[-2] == 0 + or image.shape[-1] == 0 + ): # skip the resizing return image, 1.0, torch.ones_like(image[:1]) + # takes numpy array or tensor, returns pytorch tensor minscale = min( image_height / image.shape[-2], image_width / image.shape[-1], ) + imre = torch.nn.functional.interpolate( image[None], scale_factor=minscale, @@ -226,6 +233,7 @@ def resize_image( align_corners=False if mode == "bilinear" else None, recompute_scale_factor=True, )[0] + imre_ = torch.zeros(image.shape[0], image_height, image_width) imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre mask = torch.zeros(1, image_height, image_width) @@ -238,20 +246,21 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray: return im.astype(np.float32) / 255.0 -def load_image(path: str, try_read_alpha: bool = False) -> np.ndarray: +def load_image( + path: str, try_read_alpha: bool = False, pil_format: str = "RGB" +) -> np.ndarray: """ Load an image from a path and return it as a numpy array. If try_read_alpha is True, the image is read as RGBA and the alpha channel is returned as the fourth channel. Otherwise, the image is read as RGB and a three-channel image is returned. """ - with Image.open(path) as pil_im: # Check if the image has an alpha channel if try_read_alpha and pil_im.mode == "RGBA": im = np.array(pil_im) else: - im = np.array(pil_im.convert("RGB")) + im = np.array(pil_im.convert(pil_format)) return transpose_normalize_image(im) @@ -389,7 +398,7 @@ def adjust_camera_to_image_scale_( ) camera.focal_length = focal_length_scaled[None] # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`. 
- camera.principal_point = principal_point_scaled[None] + camera.principal_point = principal_point_scaled[None] # pyre-ignore[16] # NOTE this cache is per-worker; they are implemented as processes. diff --git a/pytorch3d/implicitron/models/metrics.py b/pytorch3d/implicitron/models/metrics.py index 0e1a3b62..9555b62f 100644 --- a/pytorch3d/implicitron/models/metrics.py +++ b/pytorch3d/implicitron/models/metrics.py @@ -6,7 +6,6 @@ # pyre-unsafe - import warnings from typing import Any, Dict, Optional @@ -298,9 +297,8 @@ class ViewMetrics(ViewMetricsBase): _rgb_metrics( image_rgb, image_rgb_pred, - fg_probability, - fg_probability_pred, - mask_crop, + masks=fg_probability, + masks_crop=mask_crop, ) ) @@ -310,9 +308,21 @@ class ViewMetrics(ViewMetricsBase): metrics["mask_neg_iou"] = utils.neg_iou_loss( fg_probability_pred, fg_probability, mask=mask_crop ) - metrics["mask_bce"] = utils.calc_bce( - fg_probability_pred, fg_probability, mask=mask_crop - ) + if torch.is_autocast_enabled(): + # To avoid issues with mixed precision + metrics["mask_bce"] = utils.calc_bce( + fg_probability_pred.logit(), + fg_probability, + mask=mask_crop, + pred_logits=True, + ) + else: + metrics["mask_bce"] = utils.calc_bce( + fg_probability_pred, + fg_probability, + mask=mask_crop, + pred_logits=False, + ) if depth_map is not None and depth_map_pred is not None: assert mask_crop is not None @@ -324,7 +334,11 @@ class ViewMetrics(ViewMetricsBase): if fg_probability is not None: mask = fg_probability * mask_crop _, abs_ = utils.eval_depth( - depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0 + depth_map_pred, + depth_map, + get_best_scale=True, + mask=mask, + crop=0, ) metrics["depth_abs_fg"] = abs_.mean() @@ -346,18 +360,26 @@ class ViewMetrics(ViewMetricsBase): return metrics -def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop): +def _rgb_metrics( + images, + images_pred, + masks=None, + masks_crop=None, + huber_scaling: float = 0.03, +): assert masks_crop is not None if images.shape[1] != images_pred.shape[1]: raise ValueError( f"Network output's RGB images had {images_pred.shape[1]} " f"channels. {images.shape[1]} expected." ) + rgb_abs = ((images_pred - images).abs()).mean(dim=1, keepdim=True) rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True) - rgb_loss = utils.huber(rgb_squared, scaling=0.03) + rgb_loss = utils.huber(rgb_squared, scaling=huber_scaling) crop_mass = masks_crop.sum().clamp(1.0) results = { "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass, + "rgb_l1": (rgb_abs * masks_crop).sum() / crop_mass, "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass, "rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop), } diff --git a/pytorch3d/implicitron/tools/metric_utils.py b/pytorch3d/implicitron/tools/metric_utils.py index d4debeef..bdb8cfaf 100644 --- a/pytorch3d/implicitron/tools/metric_utils.py +++ b/pytorch3d/implicitron/tools/metric_utils.py @@ -6,12 +6,15 @@ # pyre-unsafe +import logging import math from typing import Optional, Tuple import torch from torch.nn import functional as F +logger = logging.getLogger(__name__) + def eval_depth( pred: torch.Tensor, @@ -21,6 +24,8 @@ def eval_depth( get_best_scale: bool = True, mask_thr: float = 0.5, best_scale_clamp_thr: float = 1e-4, + use_disparity: bool = False, + disparity_eps: float = 1e-4, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Evaluate the depth error between the prediction `pred` and the ground @@ -64,6 +69,13 @@ def eval_depth( # s.t. 
we get best possible mse error scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr) pred = pred * scale_best[:, None, None, None] + if use_disparity: + gt = torch.div(1.0, (gt + disparity_eps)) + pred = torch.div(1.0, (pred + disparity_eps)) + scale_best = estimate_depth_scale_factor( + pred, gt, dmask, best_scale_clamp_thr + ).detach() + pred = pred * scale_best[:, None, None, None] df = gt - pred @@ -117,6 +129,7 @@ def calc_bce( pred_eps: float = 0.01, mask: Optional[torch.Tensor] = None, lerp_bound: Optional[float] = None, + pred_logits: bool = False, ) -> torch.Tensor: """ Calculates the binary cross entropy. @@ -139,9 +152,23 @@ def calc_bce( weight = torch.ones_like(gt) * mask if lerp_bound is not None: + # binary_cross_entropy_lerp requires pred to be in [0, 1] + if pred_logits: + pred = F.sigmoid(pred) + return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound) else: - return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight) + if pred_logits: + loss = F.binary_cross_entropy_with_logits( + pred, + gt, + reduction="none", + weight=weight, + ) + else: + loss = F.binary_cross_entropy(pred, gt, reduction="none", weight=weight) + + return loss.mean() def binary_cross_entropy_lerp( diff --git a/pytorch3d/implicitron/tools/video_writer.py b/pytorch3d/implicitron/tools/video_writer.py index 86764071..4969466a 100644 --- a/pytorch3d/implicitron/tools/video_writer.py +++ b/pytorch3d/implicitron/tools/video_writer.py @@ -16,8 +16,17 @@ from typing import Optional, Tuple, Union import matplotlib import matplotlib.pyplot as plt import numpy as np +import torch + from PIL import Image +_NO_TORCHVISION = False +try: + import torchvision +except ImportError: + _NO_TORCHVISION = True + + _DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg") matplotlib.use("Agg") @@ -36,6 +45,7 @@ class VideoWriter: fps: int = 20, output_format: str = "visdom", rmdir_allowed: bool = False, + use_torchvision_video_writer: bool = False, **kwargs, ) -> None: """ @@ -49,6 +59,8 @@ class VideoWriter: is supported. rmdir_allowed: If `True` delete and create `cache_dir` in case it is not empty. + use_torchvision_video_writer: If `True` use `torchvision.io.write_video` + to write the video """ self.rmdir_allowed = rmdir_allowed self.output_format = output_format @@ -56,10 +68,14 @@ class VideoWriter: self.out_path = out_path self.cache_dir = cache_dir self.ffmpeg_bin = ffmpeg_bin + self.use_torchvision_video_writer = use_torchvision_video_writer self.frames = [] self.regexp = "frame_%08d.png" self.frame_num = 0 + if self.use_torchvision_video_writer: + assert not _NO_TORCHVISION, "torchvision not available" + if self.cache_dir is not None: self.tmp_dir = None if os.path.isdir(self.cache_dir): @@ -114,7 +130,7 @@ class VideoWriter: resize = im.size # make sure size is divisible by 2 resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)]) - # pyre-fixme[16]: Module `Image` has no attribute `ANTIALIAS`. + im = im.resize(resize, Image.ANTIALIAS) im.save(outfile) @@ -139,38 +155,56 @@ class VideoWriter: # got `Optional[str]`. regexp = os.path.join(self.cache_dir, self.regexp) - if shutil.which(self.ffmpeg_bin) is None: - raise ValueError( - f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. " - + "Please set FFMPEG in the environment or ffmpeg_bin on this class." 
- ) - if self.output_format == "visdom": # works for ppt too - args = [ - self.ffmpeg_bin, - "-r", - str(self.fps), - "-i", - regexp, - "-vcodec", - "h264", - "-f", - "mp4", - "-y", - "-crf", - "18", - "-b", - "2000k", - "-pix_fmt", - "yuv420p", - self.out_path, - ] - if quiet: - subprocess.check_call( - args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + # Video codec parameters + video_codec = "h264" + crf = "18" + b = "2000k" + pix_fmt = "yuv420p" + + if self.use_torchvision_video_writer: + torchvision.io.write_video( + self.out_path, + torch.stack( + [torch.from_numpy(np.array(Image.open(f))) for f in self.frames] + ), + fps=self.fps, + video_codec=video_codec, + options={"crf": crf, "b": b, "pix_fmt": pix_fmt}, ) + else: - subprocess.check_call(args) + if shutil.which(self.ffmpeg_bin) is None: + raise ValueError( + f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. " + + "Please set FFMPEG in the environment or ffmpeg_bin on this class." + ) + + args = [ + self.ffmpeg_bin, + "-r", + str(self.fps), + "-i", + regexp, + "-vcodec", + video_codec, + "-f", + "mp4", + "-y", + "-crf", + crf, + "-b", + b, + "-pix_fmt", + pix_fmt, + self.out_path, + ] + if quiet: + subprocess.check_call( + args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL + ) + else: + subprocess.check_call(args) else: raise ValueError("no such output type %s" % str(self.output_format))