mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 09:52:11 +08:00 
			
		
		
		
	Updates to Implicitron dataset, metrics and tools
Summary: Update Pytorch3D to be able to run assetgen (see later diffs in the stack) Reviewed By: shapovalov Differential Revision: D65942513 fbshipit-source-id: 1d01141c9f7e106608fa591be6e0d3262cb5944f
This commit is contained in:
		
							parent
							
								
									42a4a7d432
								
							
						
					
					
						commit
						43cd681d4f
					
				@ -48,6 +48,7 @@ from pytorch3d.implicitron.dataset.utils import (
 | 
			
		||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 | 
			
		||||
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 | 
			
		||||
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
 | 
			
		||||
from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
 | 
			
		||||
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
 | 
			
		||||
 | 
			
		||||
FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
 | 
			
		||||
@ -158,7 +159,7 @@ class FrameData(Mapping[str, Any]):
 | 
			
		||||
        new_params = {}
 | 
			
		||||
        for field_name in iter(self):
 | 
			
		||||
            value = getattr(self, field_name)
 | 
			
		||||
            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
 | 
			
		||||
            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
 | 
			
		||||
                new_params[field_name] = value.to(*args, **kwargs)
 | 
			
		||||
            else:
 | 
			
		||||
                new_params[field_name] = value
 | 
			
		||||
@ -420,7 +421,6 @@ class FrameData(Mapping[str, Any]):
 | 
			
		||||
            for f in fields(elem):
 | 
			
		||||
                if not f.init:
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                list_values = override_fields.get(
 | 
			
		||||
                    f.name, [getattr(d, f.name) for d in batch]
 | 
			
		||||
                )
 | 
			
		||||
@ -429,7 +429,7 @@ class FrameData(Mapping[str, Any]):
 | 
			
		||||
                    if all(list_value is not None for list_value in list_values)
 | 
			
		||||
                    else None
 | 
			
		||||
                )
 | 
			
		||||
            return cls(**collated)
 | 
			
		||||
            return type(elem)(**collated)
 | 
			
		||||
 | 
			
		||||
        elif isinstance(elem, Pointclouds):
 | 
			
		||||
            return join_pointclouds_as_batch(batch)
 | 
			
		||||
@ -437,6 +437,8 @@ class FrameData(Mapping[str, Any]):
 | 
			
		||||
        elif isinstance(elem, CamerasBase):
 | 
			
		||||
            # TODO: don't store K; enforce working in NDC space
 | 
			
		||||
            return join_cameras_as_batch(batch)
 | 
			
		||||
        elif isinstance(elem, Meshes):
 | 
			
		||||
            return join_meshes_as_batch(batch)
 | 
			
		||||
        else:
 | 
			
		||||
            return torch.utils.data.dataloader.default_collate(batch)
 | 
			
		||||
 | 
			
		||||
@ -592,6 +594,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
        fg_mask_np: np.ndarray | None = None
 | 
			
		||||
        bbox_xywh: tuple[float, float, float, float] | None = None
 | 
			
		||||
        mask_annotation = frame_annotation.mask
 | 
			
		||||
 | 
			
		||||
        if mask_annotation is not None:
 | 
			
		||||
            if load_blobs and self.load_masks:
 | 
			
		||||
                fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
 | 
			
		||||
 | 
			
		||||
@ -10,6 +10,7 @@ import hashlib
 | 
			
		||||
import json
 | 
			
		||||
import logging
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
import urllib
 | 
			
		||||
from dataclasses import dataclass, Field, field
 | 
			
		||||
from typing import (
 | 
			
		||||
@ -37,12 +38,13 @@ from pytorch3d.implicitron.dataset.frame_data import (
 | 
			
		||||
    FrameDataBuilder,  # noqa
 | 
			
		||||
    FrameDataBuilderBase,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.tools.config import (
 | 
			
		||||
    registry,
 | 
			
		||||
    ReplaceableBase,
 | 
			
		||||
    run_auto_creation,
 | 
			
		||||
)
 | 
			
		||||
from sqlalchemy.orm import Session
 | 
			
		||||
from sqlalchemy.orm import scoped_session, Session, sessionmaker
 | 
			
		||||
 | 
			
		||||
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
 | 
			
		||||
 | 
			
		||||
@ -91,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
            engine verbatim. Don’t expose it to end users of your application!
 | 
			
		||||
        pick_categories: Restrict the dataset to the given list of categories.
 | 
			
		||||
        pick_sequences: A Sequence of sequence names to restrict the dataset to.
 | 
			
		||||
        pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
 | 
			
		||||
        exclude_sequences: A Sequence of the names of the sequences to exclude.
 | 
			
		||||
        limit_sequences_per_category_to: Limit the dataset to the first up to N
 | 
			
		||||
            sequences within each category (applies after all other sequence filters
 | 
			
		||||
@ -105,6 +108,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
            more frames than that; applied after other frame-level filters.
 | 
			
		||||
        seed: The seed of the random generator sampling `n_frames_per_sequence`
 | 
			
		||||
            random frames per sequence.
 | 
			
		||||
        preload_metadata: If True, the metadata is preloaded into memory.
 | 
			
		||||
        precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
 | 
			
		||||
        scoped_session: If True, allows different parts of the code to share
 | 
			
		||||
            a global session to access the database.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
 | 
			
		||||
@ -123,6 +130,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
    pick_categories: Tuple[str, ...] = ()
 | 
			
		||||
 | 
			
		||||
    pick_sequences: Tuple[str, ...] = ()
 | 
			
		||||
    pick_sequences_sql_clause: Optional[str] = None
 | 
			
		||||
    exclude_sequences: Tuple[str, ...] = ()
 | 
			
		||||
    limit_sequences_per_category_to: int = 0
 | 
			
		||||
    limit_sequences_to: int = 0
 | 
			
		||||
@ -130,6 +138,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
    n_frames_per_sequence: int = -1
 | 
			
		||||
    seed: int = 0
 | 
			
		||||
    remove_empty_masks_poll_whole_table_threshold: int = 300_000
 | 
			
		||||
    preload_metadata: bool = False
 | 
			
		||||
    precompute_seq_to_idx: bool = False
 | 
			
		||||
    # we set it manually in the constructor
 | 
			
		||||
    _index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
 | 
			
		||||
    _sql_engine: sa.engine.Engine = field(
 | 
			
		||||
@ -142,6 +152,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
    frame_data_builder: FrameDataBuilderBase  # pyre-ignore[13]
 | 
			
		||||
    frame_data_builder_class_type: str = "FrameDataBuilder"
 | 
			
		||||
 | 
			
		||||
    scoped_session: bool = False
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self) -> None:
 | 
			
		||||
        if sa.__version__ < "2.0":
 | 
			
		||||
            raise ImportError("This class requires SQL Alchemy 2.0 or later")
 | 
			
		||||
@ -169,6 +181,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
            f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        if self.preload_metadata:
 | 
			
		||||
            self._sql_engine = self._preload_database(self._sql_engine)
 | 
			
		||||
 | 
			
		||||
        sequences = self._get_filtered_sequences_if_any()
 | 
			
		||||
 | 
			
		||||
        if self.subsets:
 | 
			
		||||
@ -192,6 +207,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
 | 
			
		||||
        logger.info(str(self))
 | 
			
		||||
 | 
			
		||||
        if self.scoped_session:
 | 
			
		||||
            self._session_factory = sessionmaker(bind=self._sql_engine)  # pyre-ignore
 | 
			
		||||
 | 
			
		||||
        if self.precompute_seq_to_idx:
 | 
			
		||||
            # This is deprecated and will be removed in the future.
 | 
			
		||||
            # After we backport https://github.com/facebookresearch/uco3d/pull/3
 | 
			
		||||
            logger.warning(
 | 
			
		||||
                "Using precompute_seq_to_idx is deprecated and will be removed in the future."
 | 
			
		||||
            )
 | 
			
		||||
            self._index["rowid"] = np.arange(len(self._index))
 | 
			
		||||
            groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
 | 
			
		||||
            self._seq_to_indices = dict(groupby.apply(list))  # pyre-ignore
 | 
			
		||||
            del self._index["rowid"]
 | 
			
		||||
 | 
			
		||||
    def __len__(self) -> int:
 | 
			
		||||
        return len(self._index)
 | 
			
		||||
 | 
			
		||||
@ -252,9 +281,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
        seq_stmt = sa.select(self.sequence_annotations_type).where(
 | 
			
		||||
            self.sequence_annotations_type.sequence_name == seq
 | 
			
		||||
        )
 | 
			
		||||
        with Session(self._sql_engine) as session:
 | 
			
		||||
            entry = session.scalars(stmt).one()
 | 
			
		||||
            seq_metadata = session.scalars(seq_stmt).one()
 | 
			
		||||
        if self.scoped_session:
 | 
			
		||||
            # pyre-ignore
 | 
			
		||||
            with scoped_session(self._session_factory)() as session:
 | 
			
		||||
                entry = session.scalars(stmt).one()
 | 
			
		||||
                seq_metadata = session.scalars(seq_stmt).one()
 | 
			
		||||
        else:
 | 
			
		||||
            with Session(self._sql_engine) as session:
 | 
			
		||||
                entry = session.scalars(stmt).one()
 | 
			
		||||
                seq_metadata = session.scalars(seq_stmt).one()
 | 
			
		||||
 | 
			
		||||
        assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
 | 
			
		||||
 | 
			
		||||
@ -363,6 +398,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
 | 
			
		||||
        yield from index_slice.itertuples(index=False)
 | 
			
		||||
 | 
			
		||||
    # override
 | 
			
		||||
    def sequence_indices_in_order(
 | 
			
		||||
        self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
 | 
			
		||||
    ) -> Iterator[int]:
 | 
			
		||||
        """Same as `sequence_frames_in_order` but returns the iterator over
 | 
			
		||||
        only dataset indices.
 | 
			
		||||
        """
 | 
			
		||||
        if self.precompute_seq_to_idx and subset_filter is None:
 | 
			
		||||
            # pyre-ignore
 | 
			
		||||
            yield from self._seq_to_indices[seq_name]
 | 
			
		||||
        else:
 | 
			
		||||
            for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
 | 
			
		||||
                yield idx
 | 
			
		||||
 | 
			
		||||
    # override
 | 
			
		||||
    def get_eval_batches(self) -> Optional[List[Any]]:
 | 
			
		||||
        """
 | 
			
		||||
@ -396,11 +445,35 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
            or self.limit_sequences_to > 0
 | 
			
		||||
            or self.limit_sequences_per_category_to > 0
 | 
			
		||||
            or len(self.pick_sequences) > 0
 | 
			
		||||
            or self.pick_sequences_sql_clause is not None
 | 
			
		||||
            or len(self.exclude_sequences) > 0
 | 
			
		||||
            or len(self.pick_categories) > 0
 | 
			
		||||
            or self.n_frames_per_sequence > 0
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def _preload_database(
 | 
			
		||||
        self, source_engine: sa.engine.base.Engine
 | 
			
		||||
    ) -> sa.engine.base.Engine:
 | 
			
		||||
        destination_engine = sa.create_engine("sqlite:///:memory:")
 | 
			
		||||
        metadata = sa.MetaData()
 | 
			
		||||
        metadata.reflect(bind=source_engine)
 | 
			
		||||
        metadata.create_all(bind=destination_engine)
 | 
			
		||||
 | 
			
		||||
        with source_engine.connect() as source_conn:
 | 
			
		||||
            with destination_engine.connect() as destination_conn:
 | 
			
		||||
                for table_obj in metadata.tables.values():
 | 
			
		||||
                    # Select all rows from the source table
 | 
			
		||||
                    source_rows = source_conn.execute(table_obj.select())
 | 
			
		||||
 | 
			
		||||
                    # Insert rows into the destination table
 | 
			
		||||
                    for row in source_rows:
 | 
			
		||||
                        destination_conn.execute(table_obj.insert().values(row))
 | 
			
		||||
 | 
			
		||||
                    # Commit the changes for each table
 | 
			
		||||
                    destination_conn.commit()
 | 
			
		||||
 | 
			
		||||
        return destination_engine
 | 
			
		||||
 | 
			
		||||
    def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
 | 
			
		||||
        # maximum possible filter (if limit_sequences_per_category_to == 0):
 | 
			
		||||
        # WHERE category IN 'self.pick_categories'
 | 
			
		||||
@ -413,6 +486,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
            *self._get_pick_filters(),
 | 
			
		||||
            *self._get_exclude_filters(),
 | 
			
		||||
        ]
 | 
			
		||||
        if self.pick_sequences_sql_clause:
 | 
			
		||||
            print("Applying the custom SQL clause.")
 | 
			
		||||
            where_conditions.append(sa.text(self.pick_sequences_sql_clause))
 | 
			
		||||
 | 
			
		||||
        def add_where(stmt):
 | 
			
		||||
            return stmt.where(*where_conditions) if where_conditions else stmt
 | 
			
		||||
@ -749,9 +825,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
            self.frame_annotations_type.sequence_name == seq_name,
 | 
			
		||||
            self.frame_annotations_type.frame_number.in_(frames),
 | 
			
		||||
        )
 | 
			
		||||
        frame_no_ts = None
 | 
			
		||||
 | 
			
		||||
        with self._sql_engine.connect() as connection:
 | 
			
		||||
            frame_no_ts = pd.read_sql_query(stmt, connection)
 | 
			
		||||
        if self.scoped_session:
 | 
			
		||||
            stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
 | 
			
		||||
            with scoped_session(self._session_factory)() as session:  # pyre-ignore
 | 
			
		||||
                frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
 | 
			
		||||
        else:
 | 
			
		||||
            with self._sql_engine.connect() as connection:
 | 
			
		||||
                frame_no_ts = pd.read_sql_query(stmt, connection)
 | 
			
		||||
 | 
			
		||||
        if len(frame_no_ts) != len(index_slice):
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
 | 
			
		||||
@ -284,8 +284,14 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
 | 
			
		||||
                logger.info(f"Val dataset: {str(val_dataset)}")
 | 
			
		||||
 | 
			
		||||
            logger.debug("Extracting test dataset.")
 | 
			
		||||
            eval_batches_file = self._get_lists_file("eval_batches")
 | 
			
		||||
            del common_dataset_kwargs["eval_batches_file"]
 | 
			
		||||
            if self.eval_batches_path is None:
 | 
			
		||||
                eval_batches_file = None
 | 
			
		||||
            else:
 | 
			
		||||
                eval_batches_file = self._get_lists_file("eval_batches")
 | 
			
		||||
 | 
			
		||||
            if "eval_batches_file" in common_dataset_kwargs:
 | 
			
		||||
                common_dataset_kwargs.pop("eval_batches_file", None)
 | 
			
		||||
 | 
			
		||||
            test_dataset = dataset_type(
 | 
			
		||||
                **common_dataset_kwargs,
 | 
			
		||||
                subsets=self._get_subsets(self.test_subsets, True),
 | 
			
		||||
 | 
			
		||||
@ -211,14 +211,21 @@ def resize_image(
 | 
			
		||||
    if isinstance(image, np.ndarray):
 | 
			
		||||
        image = torch.from_numpy(image)
 | 
			
		||||
 | 
			
		||||
    if image_height is None or image_width is None:
 | 
			
		||||
    if (
 | 
			
		||||
        image_height is None
 | 
			
		||||
        or image_width is None
 | 
			
		||||
        or image.shape[-2] == 0
 | 
			
		||||
        or image.shape[-1] == 0
 | 
			
		||||
    ):
 | 
			
		||||
        # skip the resizing
 | 
			
		||||
        return image, 1.0, torch.ones_like(image[:1])
 | 
			
		||||
 | 
			
		||||
    # takes numpy array or tensor, returns pytorch tensor
 | 
			
		||||
    minscale = min(
 | 
			
		||||
        image_height / image.shape[-2],
 | 
			
		||||
        image_width / image.shape[-1],
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    imre = torch.nn.functional.interpolate(
 | 
			
		||||
        image[None],
 | 
			
		||||
        scale_factor=minscale,
 | 
			
		||||
@ -226,6 +233,7 @@ def resize_image(
 | 
			
		||||
        align_corners=False if mode == "bilinear" else None,
 | 
			
		||||
        recompute_scale_factor=True,
 | 
			
		||||
    )[0]
 | 
			
		||||
 | 
			
		||||
    imre_ = torch.zeros(image.shape[0], image_height, image_width)
 | 
			
		||||
    imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
 | 
			
		||||
    mask = torch.zeros(1, image_height, image_width)
 | 
			
		||||
@ -238,20 +246,21 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
 | 
			
		||||
    return im.astype(np.float32) / 255.0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_image(path: str, try_read_alpha: bool = False) -> np.ndarray:
 | 
			
		||||
def load_image(
 | 
			
		||||
    path: str, try_read_alpha: bool = False, pil_format: str = "RGB"
 | 
			
		||||
) -> np.ndarray:
 | 
			
		||||
    """
 | 
			
		||||
    Load an image from a path and return it as a numpy array.
 | 
			
		||||
    If try_read_alpha is True, the image is read as RGBA and the alpha channel is
 | 
			
		||||
    returned as the fourth channel.
 | 
			
		||||
    Otherwise, the image is read as RGB and a three-channel image is returned.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    with Image.open(path) as pil_im:
 | 
			
		||||
        # Check if the image has an alpha channel
 | 
			
		||||
        if try_read_alpha and pil_im.mode == "RGBA":
 | 
			
		||||
            im = np.array(pil_im)
 | 
			
		||||
        else:
 | 
			
		||||
            im = np.array(pil_im.convert("RGB"))
 | 
			
		||||
            im = np.array(pil_im.convert(pil_format))
 | 
			
		||||
 | 
			
		||||
    return transpose_normalize_image(im)
 | 
			
		||||
 | 
			
		||||
@ -389,7 +398,7 @@ def adjust_camera_to_image_scale_(
 | 
			
		||||
    )
 | 
			
		||||
    camera.focal_length = focal_length_scaled[None]
 | 
			
		||||
    # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
 | 
			
		||||
    camera.principal_point = principal_point_scaled[None]
 | 
			
		||||
    camera.principal_point = principal_point_scaled[None]  # pyre-ignore[16]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# NOTE this cache is per-worker; they are implemented as processes.
 | 
			
		||||
 | 
			
		||||
@ -6,7 +6,6 @@
 | 
			
		||||
 | 
			
		||||
# pyre-unsafe
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import Any, Dict, Optional
 | 
			
		||||
 | 
			
		||||
@ -298,9 +297,8 @@ class ViewMetrics(ViewMetricsBase):
 | 
			
		||||
                _rgb_metrics(
 | 
			
		||||
                    image_rgb,
 | 
			
		||||
                    image_rgb_pred,
 | 
			
		||||
                    fg_probability,
 | 
			
		||||
                    fg_probability_pred,
 | 
			
		||||
                    mask_crop,
 | 
			
		||||
                    masks=fg_probability,
 | 
			
		||||
                    masks_crop=mask_crop,
 | 
			
		||||
                )
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
@ -310,9 +308,21 @@ class ViewMetrics(ViewMetricsBase):
 | 
			
		||||
            metrics["mask_neg_iou"] = utils.neg_iou_loss(
 | 
			
		||||
                fg_probability_pred, fg_probability, mask=mask_crop
 | 
			
		||||
            )
 | 
			
		||||
            metrics["mask_bce"] = utils.calc_bce(
 | 
			
		||||
                fg_probability_pred, fg_probability, mask=mask_crop
 | 
			
		||||
            )
 | 
			
		||||
            if torch.is_autocast_enabled():
 | 
			
		||||
                # To avoid issues with mixed precision
 | 
			
		||||
                metrics["mask_bce"] = utils.calc_bce(
 | 
			
		||||
                    fg_probability_pred.logit(),
 | 
			
		||||
                    fg_probability,
 | 
			
		||||
                    mask=mask_crop,
 | 
			
		||||
                    pred_logits=True,
 | 
			
		||||
                )
 | 
			
		||||
            else:
 | 
			
		||||
                metrics["mask_bce"] = utils.calc_bce(
 | 
			
		||||
                    fg_probability_pred,
 | 
			
		||||
                    fg_probability,
 | 
			
		||||
                    mask=mask_crop,
 | 
			
		||||
                    pred_logits=False,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        if depth_map is not None and depth_map_pred is not None:
 | 
			
		||||
            assert mask_crop is not None
 | 
			
		||||
@ -324,7 +334,11 @@ class ViewMetrics(ViewMetricsBase):
 | 
			
		||||
            if fg_probability is not None:
 | 
			
		||||
                mask = fg_probability * mask_crop
 | 
			
		||||
                _, abs_ = utils.eval_depth(
 | 
			
		||||
                    depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
 | 
			
		||||
                    depth_map_pred,
 | 
			
		||||
                    depth_map,
 | 
			
		||||
                    get_best_scale=True,
 | 
			
		||||
                    mask=mask,
 | 
			
		||||
                    crop=0,
 | 
			
		||||
                )
 | 
			
		||||
                metrics["depth_abs_fg"] = abs_.mean()
 | 
			
		||||
 | 
			
		||||
@ -346,18 +360,26 @@ class ViewMetrics(ViewMetricsBase):
 | 
			
		||||
        return metrics
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
 | 
			
		||||
def _rgb_metrics(
 | 
			
		||||
    images,
 | 
			
		||||
    images_pred,
 | 
			
		||||
    masks=None,
 | 
			
		||||
    masks_crop=None,
 | 
			
		||||
    huber_scaling: float = 0.03,
 | 
			
		||||
):
 | 
			
		||||
    assert masks_crop is not None
 | 
			
		||||
    if images.shape[1] != images_pred.shape[1]:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f"Network output's RGB images had {images_pred.shape[1]} "
 | 
			
		||||
            f"channels. {images.shape[1]} expected."
 | 
			
		||||
        )
 | 
			
		||||
    rgb_abs = ((images_pred - images).abs()).mean(dim=1, keepdim=True)
 | 
			
		||||
    rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
 | 
			
		||||
    rgb_loss = utils.huber(rgb_squared, scaling=0.03)
 | 
			
		||||
    rgb_loss = utils.huber(rgb_squared, scaling=huber_scaling)
 | 
			
		||||
    crop_mass = masks_crop.sum().clamp(1.0)
 | 
			
		||||
    results = {
 | 
			
		||||
        "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
 | 
			
		||||
        "rgb_l1": (rgb_abs * masks_crop).sum() / crop_mass,
 | 
			
		||||
        "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
 | 
			
		||||
        "rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -6,12 +6,15 @@
 | 
			
		||||
 | 
			
		||||
# pyre-unsafe
 | 
			
		||||
 | 
			
		||||
import logging
 | 
			
		||||
import math
 | 
			
		||||
from typing import Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from torch.nn import functional as F
 | 
			
		||||
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def eval_depth(
 | 
			
		||||
    pred: torch.Tensor,
 | 
			
		||||
@ -21,6 +24,8 @@ def eval_depth(
 | 
			
		||||
    get_best_scale: bool = True,
 | 
			
		||||
    mask_thr: float = 0.5,
 | 
			
		||||
    best_scale_clamp_thr: float = 1e-4,
 | 
			
		||||
    use_disparity: bool = False,
 | 
			
		||||
    disparity_eps: float = 1e-4,
 | 
			
		||||
) -> Tuple[torch.Tensor, torch.Tensor]:
 | 
			
		||||
    """
 | 
			
		||||
    Evaluate the depth error between the prediction `pred` and the ground
 | 
			
		||||
@ -64,6 +69,13 @@ def eval_depth(
 | 
			
		||||
        # 	s.t. we get best possible mse error
 | 
			
		||||
        scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr)
 | 
			
		||||
        pred = pred * scale_best[:, None, None, None]
 | 
			
		||||
    if use_disparity:
 | 
			
		||||
        gt = torch.div(1.0, (gt + disparity_eps))
 | 
			
		||||
        pred = torch.div(1.0, (pred + disparity_eps))
 | 
			
		||||
        scale_best = estimate_depth_scale_factor(
 | 
			
		||||
            pred, gt, dmask, best_scale_clamp_thr
 | 
			
		||||
        ).detach()
 | 
			
		||||
        pred = pred * scale_best[:, None, None, None]
 | 
			
		||||
 | 
			
		||||
    df = gt - pred
 | 
			
		||||
 | 
			
		||||
@ -117,6 +129,7 @@ def calc_bce(
 | 
			
		||||
    pred_eps: float = 0.01,
 | 
			
		||||
    mask: Optional[torch.Tensor] = None,
 | 
			
		||||
    lerp_bound: Optional[float] = None,
 | 
			
		||||
    pred_logits: bool = False,
 | 
			
		||||
) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Calculates the binary cross entropy.
 | 
			
		||||
@ -139,9 +152,23 @@ def calc_bce(
 | 
			
		||||
        weight = torch.ones_like(gt) * mask
 | 
			
		||||
 | 
			
		||||
    if lerp_bound is not None:
 | 
			
		||||
        # binary_cross_entropy_lerp requires pred to be in [0, 1]
 | 
			
		||||
        if pred_logits:
 | 
			
		||||
            pred = F.sigmoid(pred)
 | 
			
		||||
 | 
			
		||||
        return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound)
 | 
			
		||||
    else:
 | 
			
		||||
        return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight)
 | 
			
		||||
        if pred_logits:
 | 
			
		||||
            loss = F.binary_cross_entropy_with_logits(
 | 
			
		||||
                pred,
 | 
			
		||||
                gt,
 | 
			
		||||
                reduction="none",
 | 
			
		||||
                weight=weight,
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            loss = F.binary_cross_entropy(pred, gt, reduction="none", weight=weight)
 | 
			
		||||
 | 
			
		||||
        return loss.mean()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def binary_cross_entropy_lerp(
 | 
			
		||||
 | 
			
		||||
@ -16,8 +16,17 @@ from typing import Optional, Tuple, Union
 | 
			
		||||
import matplotlib
 | 
			
		||||
import matplotlib.pyplot as plt
 | 
			
		||||
import numpy as np
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from PIL import Image
 | 
			
		||||
 | 
			
		||||
_NO_TORCHVISION = False
 | 
			
		||||
try:
 | 
			
		||||
    import torchvision
 | 
			
		||||
except ImportError:
 | 
			
		||||
    _NO_TORCHVISION = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
 | 
			
		||||
 | 
			
		||||
matplotlib.use("Agg")
 | 
			
		||||
@ -36,6 +45,7 @@ class VideoWriter:
 | 
			
		||||
        fps: int = 20,
 | 
			
		||||
        output_format: str = "visdom",
 | 
			
		||||
        rmdir_allowed: bool = False,
 | 
			
		||||
        use_torchvision_video_writer: bool = False,
 | 
			
		||||
        **kwargs,
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        """
 | 
			
		||||
@ -49,6 +59,8 @@ class VideoWriter:
 | 
			
		||||
                is supported.
 | 
			
		||||
            rmdir_allowed: If `True` delete and create `cache_dir` in case
 | 
			
		||||
                it is not empty.
 | 
			
		||||
            use_torchvision_video_writer: If `True` use `torchvision.io.write_video`
 | 
			
		||||
            to write the video
 | 
			
		||||
        """
 | 
			
		||||
        self.rmdir_allowed = rmdir_allowed
 | 
			
		||||
        self.output_format = output_format
 | 
			
		||||
@ -56,10 +68,14 @@ class VideoWriter:
 | 
			
		||||
        self.out_path = out_path
 | 
			
		||||
        self.cache_dir = cache_dir
 | 
			
		||||
        self.ffmpeg_bin = ffmpeg_bin
 | 
			
		||||
        self.use_torchvision_video_writer = use_torchvision_video_writer
 | 
			
		||||
        self.frames = []
 | 
			
		||||
        self.regexp = "frame_%08d.png"
 | 
			
		||||
        self.frame_num = 0
 | 
			
		||||
 | 
			
		||||
        if self.use_torchvision_video_writer:
 | 
			
		||||
            assert not _NO_TORCHVISION, "torchvision not available"
 | 
			
		||||
 | 
			
		||||
        if self.cache_dir is not None:
 | 
			
		||||
            self.tmp_dir = None
 | 
			
		||||
            if os.path.isdir(self.cache_dir):
 | 
			
		||||
@ -114,7 +130,7 @@ class VideoWriter:
 | 
			
		||||
                resize = im.size
 | 
			
		||||
            # make sure size is divisible by 2
 | 
			
		||||
            resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
 | 
			
		||||
            # pyre-fixme[16]: Module `Image` has no attribute `ANTIALIAS`.
 | 
			
		||||
 | 
			
		||||
            im = im.resize(resize, Image.ANTIALIAS)
 | 
			
		||||
            im.save(outfile)
 | 
			
		||||
 | 
			
		||||
@ -139,38 +155,56 @@ class VideoWriter:
 | 
			
		||||
        #  got `Optional[str]`.
 | 
			
		||||
        regexp = os.path.join(self.cache_dir, self.regexp)
 | 
			
		||||
 | 
			
		||||
        if shutil.which(self.ffmpeg_bin) is None:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
 | 
			
		||||
                + "Please set FFMPEG in the environment or ffmpeg_bin on this class."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if self.output_format == "visdom":  # works for ppt too
 | 
			
		||||
            args = [
 | 
			
		||||
                self.ffmpeg_bin,
 | 
			
		||||
                "-r",
 | 
			
		||||
                str(self.fps),
 | 
			
		||||
                "-i",
 | 
			
		||||
                regexp,
 | 
			
		||||
                "-vcodec",
 | 
			
		||||
                "h264",
 | 
			
		||||
                "-f",
 | 
			
		||||
                "mp4",
 | 
			
		||||
                "-y",
 | 
			
		||||
                "-crf",
 | 
			
		||||
                "18",
 | 
			
		||||
                "-b",
 | 
			
		||||
                "2000k",
 | 
			
		||||
                "-pix_fmt",
 | 
			
		||||
                "yuv420p",
 | 
			
		||||
                self.out_path,
 | 
			
		||||
            ]
 | 
			
		||||
            if quiet:
 | 
			
		||||
                subprocess.check_call(
 | 
			
		||||
                    args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
 | 
			
		||||
            # Video codec parameters
 | 
			
		||||
            video_codec = "h264"
 | 
			
		||||
            crf = "18"
 | 
			
		||||
            b = "2000k"
 | 
			
		||||
            pix_fmt = "yuv420p"
 | 
			
		||||
 | 
			
		||||
            if self.use_torchvision_video_writer:
 | 
			
		||||
                torchvision.io.write_video(
 | 
			
		||||
                    self.out_path,
 | 
			
		||||
                    torch.stack(
 | 
			
		||||
                        [torch.from_numpy(np.array(Image.open(f))) for f in self.frames]
 | 
			
		||||
                    ),
 | 
			
		||||
                    fps=self.fps,
 | 
			
		||||
                    video_codec=video_codec,
 | 
			
		||||
                    options={"crf": crf, "b": b, "pix_fmt": pix_fmt},
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                subprocess.check_call(args)
 | 
			
		||||
                if shutil.which(self.ffmpeg_bin) is None:
 | 
			
		||||
                    raise ValueError(
 | 
			
		||||
                        f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
 | 
			
		||||
                        + "Please set FFMPEG in the environment or ffmpeg_bin on this class."
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                args = [
 | 
			
		||||
                    self.ffmpeg_bin,
 | 
			
		||||
                    "-r",
 | 
			
		||||
                    str(self.fps),
 | 
			
		||||
                    "-i",
 | 
			
		||||
                    regexp,
 | 
			
		||||
                    "-vcodec",
 | 
			
		||||
                    video_codec,
 | 
			
		||||
                    "-f",
 | 
			
		||||
                    "mp4",
 | 
			
		||||
                    "-y",
 | 
			
		||||
                    "-crf",
 | 
			
		||||
                    crf,
 | 
			
		||||
                    "-b",
 | 
			
		||||
                    b,
 | 
			
		||||
                    "-pix_fmt",
 | 
			
		||||
                    pix_fmt,
 | 
			
		||||
                    self.out_path,
 | 
			
		||||
                ]
 | 
			
		||||
                if quiet:
 | 
			
		||||
                    subprocess.check_call(
 | 
			
		||||
                        args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
 | 
			
		||||
                    )
 | 
			
		||||
                else:
 | 
			
		||||
                    subprocess.check_call(args)
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError("no such output type %s" % str(self.output_format))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user