Updates to Implicitron dataset, metrics and tools

Summary: Update PyTorch3D to be able to run assetgen (see later diffs in the stack).

Reviewed By: shapovalov

Differential Revision: D65942513

fbshipit-source-id: 1d01141c9f7e106608fa591be6e0d3262cb5944f
Antoine Toisoul 2025-01-27 09:43:42 -08:00 committed by Facebook GitHub Bot
parent 42a4a7d432
commit 43cd681d4f
7 changed files with 240 additions and 57 deletions

View File

@@ -48,6 +48,7 @@ from pytorch3d.implicitron.dataset.utils import (
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
+from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
 from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
 
 FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
@@ -158,7 +159,7 @@ class FrameData(Mapping[str, Any]):
         new_params = {}
         for field_name in iter(self):
             value = getattr(self, field_name)
-            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
+            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
                 new_params[field_name] = value.to(*args, **kwargs)
             else:
                 new_params[field_name] = value
@@ -420,7 +421,6 @@ class FrameData(Mapping[str, Any]):
             for f in fields(elem):
                 if not f.init:
                     continue
 
                 list_values = override_fields.get(
                     f.name, [getattr(d, f.name) for d in batch]
                 )
@@ -429,7 +429,7 @@ class FrameData(Mapping[str, Any]):
                     if all(list_value is not None for list_value in list_values)
                     else None
                 )
-            return cls(**collated)
+            return type(elem)(**collated)
 
         elif isinstance(elem, Pointclouds):
             return join_pointclouds_as_batch(batch)
@@ -437,6 +437,8 @@ class FrameData(Mapping[str, Any]):
         elif isinstance(elem, CamerasBase):
             # TODO: don't store K; enforce working in NDC space
             return join_cameras_as_batch(batch)
+        elif isinstance(elem, Meshes):
+            return join_meshes_as_batch(batch)
         else:
             return torch.utils.data.dataloader.default_collate(batch)
@@ -592,6 +594,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
         fg_mask_np: np.ndarray | None = None
         bbox_xywh: tuple[float, float, float, float] | None = None
         mask_annotation = frame_annotation.mask
 
         if mask_annotation is not None:
             if load_blobs and self.load_masks:
                 fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
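
For context (not part of the diff): a minimal sketch of the behaviour the new collate branch and the extended FrameData.to() rely on, assuming only that torch and pytorch3d are installed. Meshes entries are merged with join_meshes_as_batch and moved between devices like tensors, point clouds and cameras.

import torch
from pytorch3d.structures.meshes import Meshes, join_meshes_as_batch

# two single-mesh Meshes objects, as they might appear in two FrameData samples
verts = torch.rand(4, 3)
faces = torch.tensor([[0, 1, 2], [1, 2, 3]])
mesh_a = Meshes(verts=[verts], faces=[faces])
mesh_b = Meshes(verts=[verts + 1.0], faces=[faces])

batched = join_meshes_as_batch([mesh_a, mesh_b])  # what the Meshes collate branch returns
print(len(batched))          # 2 meshes in one batched Meshes object
batched = batched.to("cpu")  # Meshes supports .to(), which FrameData.to() now forwards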

View File

@@ -10,6 +10,7 @@ import hashlib
 import json
 import logging
 import os
 import urllib
 from dataclasses import dataclass, Field, field
 from typing import (
@@ -37,12 +38,13 @@ from pytorch3d.implicitron.dataset.frame_data import (
     FrameDataBuilder,  # noqa
     FrameDataBuilderBase,
 )
 from pytorch3d.implicitron.tools.config import (
     registry,
     ReplaceableBase,
     run_auto_creation,
 )
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import scoped_session, Session, sessionmaker
 
 from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
@@ -91,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
            engine verbatim. Don't expose it to end users of your application!
        pick_categories: Restrict the dataset to the given list of categories.
        pick_sequences: A Sequence of sequence names to restrict the dataset to.
+       pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
        exclude_sequences: A Sequence of the names of the sequences to exclude.
        limit_sequences_per_category_to: Limit the dataset to the first up to N
            sequences within each category (applies after all other sequence filters
@@ -105,6 +108,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
            more frames than that; applied after other frame-level filters.
        seed: The seed of the random generator sampling `n_frames_per_sequence`
            random frames per sequence.
+       preload_metadata: If True, the metadata is preloaded into memory.
+       precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
+       scoped_session: If True, allows different parts of the code to share
+           a global session to access the database.
     """
 
     frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
@@ -123,6 +130,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     pick_categories: Tuple[str, ...] = ()
     pick_sequences: Tuple[str, ...] = ()
+    pick_sequences_sql_clause: Optional[str] = None
     exclude_sequences: Tuple[str, ...] = ()
     limit_sequences_per_category_to: int = 0
     limit_sequences_to: int = 0
@@ -130,6 +138,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     n_frames_per_sequence: int = -1
     seed: int = 0
     remove_empty_masks_poll_whole_table_threshold: int = 300_000
+    preload_metadata: bool = False
+    precompute_seq_to_idx: bool = False
 
     # we set it manually in the constructor
     _index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
     _sql_engine: sa.engine.Engine = field(
@@ -142,6 +152,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     frame_data_builder: FrameDataBuilderBase  # pyre-ignore[13]
     frame_data_builder_class_type: str = "FrameDataBuilder"
 
+    scoped_session: bool = False
+
     def __post_init__(self) -> None:
         if sa.__version__ < "2.0":
             raise ImportError("This class requires SQL Alchemy 2.0 or later")
@@ -169,6 +181,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
         )
 
+        if self.preload_metadata:
+            self._sql_engine = self._preload_database(self._sql_engine)
+
         sequences = self._get_filtered_sequences_if_any()
 
         if self.subsets:
@@ -192,6 +207,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         logger.info(str(self))
 
+        if self.scoped_session:
+            self._session_factory = sessionmaker(bind=self._sql_engine)  # pyre-ignore
+
+        if self.precompute_seq_to_idx:
+            # This is deprecated and will be removed in the future.
+            # After we backport https://github.com/facebookresearch/uco3d/pull/3
+            logger.warning(
+                "Using precompute_seq_to_idx is deprecated and will be removed in the future."
+            )
+            self._index["rowid"] = np.arange(len(self._index))
+            groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
+            self._seq_to_indices = dict(groupby.apply(list))  # pyre-ignore
+            del self._index["rowid"]
+
     def __len__(self) -> int:
         return len(self._index)
@@ -252,9 +281,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         seq_stmt = sa.select(self.sequence_annotations_type).where(
             self.sequence_annotations_type.sequence_name == seq
         )
-        with Session(self._sql_engine) as session:
-            entry = session.scalars(stmt).one()
-            seq_metadata = session.scalars(seq_stmt).one()
+        if self.scoped_session:
+            # pyre-ignore
+            with scoped_session(self._session_factory)() as session:
+                entry = session.scalars(stmt).one()
+                seq_metadata = session.scalars(seq_stmt).one()
+        else:
+            with Session(self._sql_engine) as session:
+                entry = session.scalars(stmt).one()
+                seq_metadata = session.scalars(seq_stmt).one()
 
         assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
@@ -363,6 +398,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         yield from index_slice.itertuples(index=False)
 
+    # override
+    def sequence_indices_in_order(
+        self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
+    ) -> Iterator[int]:
+        """Same as `sequence_frames_in_order` but returns the iterator over
+        only dataset indices.
+        """
+        if self.precompute_seq_to_idx and subset_filter is None:
+            # pyre-ignore
+            yield from self._seq_to_indices[seq_name]
+        else:
+            for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
+                yield idx
+
     # override
     def get_eval_batches(self) -> Optional[List[Any]]:
         """
@@ -396,11 +445,35 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             or self.limit_sequences_to > 0
             or self.limit_sequences_per_category_to > 0
             or len(self.pick_sequences) > 0
+            or self.pick_sequences_sql_clause is not None
             or len(self.exclude_sequences) > 0
             or len(self.pick_categories) > 0
             or self.n_frames_per_sequence > 0
         )
 
+    def _preload_database(
+        self, source_engine: sa.engine.base.Engine
+    ) -> sa.engine.base.Engine:
+        destination_engine = sa.create_engine("sqlite:///:memory:")
+        metadata = sa.MetaData()
+        metadata.reflect(bind=source_engine)
+        metadata.create_all(bind=destination_engine)
+
+        with source_engine.connect() as source_conn:
+            with destination_engine.connect() as destination_conn:
+                for table_obj in metadata.tables.values():
+                    # Select all rows from the source table
+                    source_rows = source_conn.execute(table_obj.select())
+
+                    # Insert rows into the destination table
+                    for row in source_rows:
+                        destination_conn.execute(table_obj.insert().values(row))
+
+                    # Commit the changes for each table
+                    destination_conn.commit()
+
+        return destination_engine
+
     def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
         # maximum possible filter (if limit_sequences_per_category_to == 0):
         # WHERE category IN 'self.pick_categories'
@@ -413,6 +486,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             *self._get_pick_filters(),
             *self._get_exclude_filters(),
         ]
+        if self.pick_sequences_sql_clause:
+            print("Applying the custom SQL clause.")
+            where_conditions.append(sa.text(self.pick_sequences_sql_clause))
 
         def add_where(stmt):
             return stmt.where(*where_conditions) if where_conditions else stmt
@@ -749,9 +825,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             self.frame_annotations_type.sequence_name == seq_name,
             self.frame_annotations_type.frame_number.in_(frames),
         )
-        with self._sql_engine.connect() as connection:
-            frame_no_ts = pd.read_sql_query(stmt, connection)
+        frame_no_ts = None
+        if self.scoped_session:
+            stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
+            with scoped_session(self._session_factory)() as session:  # pyre-ignore
+                frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
+        else:
+            with self._sql_engine.connect() as connection:
+                frame_no_ts = pd.read_sql_query(stmt, connection)
 
         if len(frame_no_ts) != len(index_slice):
             raise ValueError(
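
A hedged usage sketch of the new dataset options (not part of the diff; the module path is assumed to be pytorch3d.implicitron.dataset.sql_dataset, "metadata.sqlite" is a hypothetical path, and a real Implicitron SQL dataset plus frame-data-builder configuration are needed for this to actually run):

from pytorch3d.implicitron.dataset.sql_dataset import SqlIndexDataset
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(SqlIndexDataset)
dataset = SqlIndexDataset(
    sqlite_metadata_file="metadata.sqlite",          # hypothetical path
    preload_metadata=True,                           # copy the SQLite metadata into an in-memory engine
    scoped_session=True,                             # share a sessionmaker-backed session across queries
    pick_sequences_sql_clause="category = 'chair'",  # raw SQL appended to the sequence WHERE clause
)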

View File

@@ -284,8 +284,14 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
         logger.info(f"Val dataset: {str(val_dataset)}")
 
         logger.debug("Extracting test dataset.")
-        eval_batches_file = self._get_lists_file("eval_batches")
-        del common_dataset_kwargs["eval_batches_file"]
+        if self.eval_batches_path is None:
+            eval_batches_file = None
+        else:
+            eval_batches_file = self._get_lists_file("eval_batches")
+
+        if "eval_batches_file" in common_dataset_kwargs:
+            common_dataset_kwargs.pop("eval_batches_file", None)
+
         test_dataset = dataset_type(
             **common_dataset_kwargs,
             subsets=self._get_subsets(self.test_subsets, True),

View File

@@ -211,14 +211,21 @@ def resize_image(
     if isinstance(image, np.ndarray):
         image = torch.from_numpy(image)
 
-    if image_height is None or image_width is None:
+    if (
+        image_height is None
+        or image_width is None
+        or image.shape[-2] == 0
+        or image.shape[-1] == 0
+    ):
         # skip the resizing
         return image, 1.0, torch.ones_like(image[:1])
 
     # takes numpy array or tensor, returns pytorch tensor
     minscale = min(
         image_height / image.shape[-2],
         image_width / image.shape[-1],
     )
 
     imre = torch.nn.functional.interpolate(
         image[None],
         scale_factor=minscale,
@@ -226,6 +233,7 @@ def resize_image(
         align_corners=False if mode == "bilinear" else None,
         recompute_scale_factor=True,
     )[0]
 
     imre_ = torch.zeros(image.shape[0], image_height, image_width)
     imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
     mask = torch.zeros(1, image_height, image_width)
@@ -238,20 +246,21 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
     return im.astype(np.float32) / 255.0
 
 
-def load_image(path: str, try_read_alpha: bool = False) -> np.ndarray:
+def load_image(
+    path: str, try_read_alpha: bool = False, pil_format: str = "RGB"
+) -> np.ndarray:
     """
     Load an image from a path and return it as a numpy array.
     If try_read_alpha is True, the image is read as RGBA and the alpha channel is
     returned as the fourth channel.
     Otherwise, the image is read as RGB and a three-channel image is returned.
     """
     with Image.open(path) as pil_im:
         # Check if the image has an alpha channel
         if try_read_alpha and pil_im.mode == "RGBA":
             im = np.array(pil_im)
         else:
-            im = np.array(pil_im.convert("RGB"))
+            im = np.array(pil_im.convert(pil_format))
 
     return transpose_normalize_image(im)
@@ -389,7 +398,7 @@ def adjust_camera_to_image_scale_(
     )
     camera.focal_length = focal_length_scaled[None]
     # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
-    camera.principal_point = principal_point_scaled[None]
+    camera.principal_point = principal_point_scaled[None]  # pyre-ignore[16]
 
 
 # NOTE this cache is per-worker; they are implemented as processes.
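
For context (not part of the diff): a hedged sketch of how the widened load_image signature can be used, assuming the module path pytorch3d.implicitron.dataset.utils and a hypothetical file "frame.png". The new pil_format argument is passed straight to PIL's convert(), so formats other than "RGB" can be requested.

from pytorch3d.implicitron.dataset.utils import load_image

rgb = load_image("frame.png")                         # float32, (3, H, W), values in [0, 1]
gray = load_image("frame.png", pil_format="L")        # single-channel output
rgba = load_image("frame.png", try_read_alpha=True)   # four channels if the file has alpha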

View File

@@ -6,7 +6,6 @@
 # pyre-unsafe
 
 import warnings
 from typing import Any, Dict, Optional
@@ -298,9 +297,8 @@ class ViewMetrics(ViewMetricsBase):
                 _rgb_metrics(
                     image_rgb,
                     image_rgb_pred,
-                    fg_probability,
-                    fg_probability_pred,
-                    mask_crop,
+                    masks=fg_probability,
+                    masks_crop=mask_crop,
                 )
             )
@@ -310,9 +308,21 @@ class ViewMetrics(ViewMetricsBase):
             metrics["mask_neg_iou"] = utils.neg_iou_loss(
                 fg_probability_pred, fg_probability, mask=mask_crop
             )
-            metrics["mask_bce"] = utils.calc_bce(
-                fg_probability_pred, fg_probability, mask=mask_crop
-            )
+            if torch.is_autocast_enabled():
+                # To avoid issues with mixed precision
+                metrics["mask_bce"] = utils.calc_bce(
+                    fg_probability_pred.logit(),
+                    fg_probability,
+                    mask=mask_crop,
+                    pred_logits=True,
+                )
+            else:
+                metrics["mask_bce"] = utils.calc_bce(
+                    fg_probability_pred,
+                    fg_probability,
+                    mask=mask_crop,
+                    pred_logits=False,
+                )
 
         if depth_map is not None and depth_map_pred is not None:
             assert mask_crop is not None
@@ -324,7 +334,11 @@ class ViewMetrics(ViewMetricsBase):
             if fg_probability is not None:
                 mask = fg_probability * mask_crop
                 _, abs_ = utils.eval_depth(
-                    depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
+                    depth_map_pred,
+                    depth_map,
+                    get_best_scale=True,
+                    mask=mask,
+                    crop=0,
                 )
                 metrics["depth_abs_fg"] = abs_.mean()
@@ -346,18 +360,26 @@ class ViewMetrics(ViewMetricsBase):
     return metrics
 
 
-def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
+def _rgb_metrics(
+    images,
+    images_pred,
+    masks=None,
+    masks_crop=None,
+    huber_scaling: float = 0.03,
+):
     assert masks_crop is not None
     if images.shape[1] != images_pred.shape[1]:
         raise ValueError(
             f"Network output's RGB images had {images_pred.shape[1]} "
             f"channels. {images.shape[1]} expected."
         )
+
+    rgb_abs = ((images_pred - images).abs()).mean(dim=1, keepdim=True)
     rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
-    rgb_loss = utils.huber(rgb_squared, scaling=0.03)
+    rgb_loss = utils.huber(rgb_squared, scaling=huber_scaling)
     crop_mass = masks_crop.sum().clamp(1.0)
     results = {
         "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
+        "rgb_l1": (rgb_abs * masks_crop).sum() / crop_mass,
         "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
         "rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
     }
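
For context (not part of the diff): the autocast branch above works because binary cross entropy on probabilities and on the corresponding logits agree, and only the logits form is safe under mixed precision. A small self-contained check using plain torch:

import torch
import torch.nn.functional as F

prob = torch.rand(4, 1, 8, 8).clamp(0.01, 0.99)   # predicted foreground probabilities
gt = (torch.rand(4, 1, 8, 8) > 0.5).float()       # binary ground-truth mask

bce_probs = F.binary_cross_entropy(prob, gt)
bce_logits = F.binary_cross_entropy_with_logits(prob.logit(), gt)
print(torch.allclose(bce_probs, bce_logits, atol=1e-5))  # True, up to numerics

# The logits form stays well-defined under torch.autocast, which is why the
# new code converts probabilities with .logit() and passes pred_logits=True.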

View File

@@ -6,12 +6,15 @@
 # pyre-unsafe
 
+import logging
 import math
 from typing import Optional, Tuple
 
 import torch
 from torch.nn import functional as F
 
+logger = logging.getLogger(__name__)
+
 
 def eval_depth(
     pred: torch.Tensor,
@@ -21,6 +24,8 @@ def eval_depth(
     get_best_scale: bool = True,
     mask_thr: float = 0.5,
     best_scale_clamp_thr: float = 1e-4,
+    use_disparity: bool = False,
+    disparity_eps: float = 1e-4,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Evaluate the depth error between the prediction `pred` and the ground
@@ -64,6 +69,13 @@ def eval_depth(
         # s.t. we get best possible mse error
         scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr)
         pred = pred * scale_best[:, None, None, None]
 
+    if use_disparity:
+        gt = torch.div(1.0, (gt + disparity_eps))
+        pred = torch.div(1.0, (pred + disparity_eps))
+        scale_best = estimate_depth_scale_factor(
+            pred, gt, dmask, best_scale_clamp_thr
+        ).detach()
+        pred = pred * scale_best[:, None, None, None]
 
     df = gt - pred
@@ -117,6 +129,7 @@ def calc_bce(
     pred_eps: float = 0.01,
     mask: Optional[torch.Tensor] = None,
     lerp_bound: Optional[float] = None,
+    pred_logits: bool = False,
 ) -> torch.Tensor:
     """
     Calculates the binary cross entropy.
@@ -139,9 +152,23 @@ def calc_bce(
         weight = torch.ones_like(gt) * mask
 
     if lerp_bound is not None:
+        # binary_cross_entropy_lerp requires pred to be in [0, 1]
+        if pred_logits:
+            pred = F.sigmoid(pred)
         return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound)
     else:
-        return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight)
+        if pred_logits:
+            loss = F.binary_cross_entropy_with_logits(
+                pred,
+                gt,
+                reduction="none",
+                weight=weight,
+            )
+        else:
+            loss = F.binary_cross_entropy(pred, gt, reduction="none", weight=weight)
+
+        return loss.mean()
 
 
 def binary_cross_entropy_lerp(
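
A hedged sketch of the new use_disparity option on eval_depth (not part of the diff; the module path is assumed to be pytorch3d.implicitron.tools.metric_utils). The error is computed in 1 / (depth + disparity_eps) space after re-estimating the best scale there, so a global scale offset between prediction and ground truth is factored out.

import torch
from pytorch3d.implicitron.tools.metric_utils import eval_depth

gt = torch.rand(2, 1, 32, 32) + 0.5   # positive ground-truth depths, (B, 1, H, W)
pred = gt * 1.7                       # prediction that is off by a global scale

mse_depth, abs_depth = eval_depth(pred, gt, use_disparity=True)
print(abs_depth.mean())  # close to zero, since the scale is compensated in disparity space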

View File

@@ -16,8 +16,17 @@ from typing import Optional, Tuple, Union
 import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
+import torch
 from PIL import Image
 
+_NO_TORCHVISION = False
+try:
+    import torchvision
+except ImportError:
+    _NO_TORCHVISION = True
+
 _DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
 
 matplotlib.use("Agg")
@@ -36,6 +45,7 @@ class VideoWriter:
         fps: int = 20,
         output_format: str = "visdom",
         rmdir_allowed: bool = False,
+        use_torchvision_video_writer: bool = False,
         **kwargs,
     ) -> None:
         """
@@ -49,6 +59,8 @@ class VideoWriter:
                 is supported.
             rmdir_allowed: If `True` delete and create `cache_dir` in case
                 it is not empty.
+            use_torchvision_video_writer: If `True` use `torchvision.io.write_video`
+                to write the video
         """
         self.rmdir_allowed = rmdir_allowed
         self.output_format = output_format
@@ -56,10 +68,14 @@ class VideoWriter:
         self.out_path = out_path
         self.cache_dir = cache_dir
         self.ffmpeg_bin = ffmpeg_bin
+        self.use_torchvision_video_writer = use_torchvision_video_writer
         self.frames = []
         self.regexp = "frame_%08d.png"
         self.frame_num = 0
 
+        if self.use_torchvision_video_writer:
+            assert not _NO_TORCHVISION, "torchvision not available"
+
         if self.cache_dir is not None:
             self.tmp_dir = None
             if os.path.isdir(self.cache_dir):
@@ -114,7 +130,7 @@ class VideoWriter:
             resize = im.size
         # make sure size is divisible by 2
         resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
+        # pyre-fixme[16]: Module `Image` has no attribute `ANTIALIAS`.
         im = im.resize(resize, Image.ANTIALIAS)
         im.save(outfile)
@@ -139,38 +155,56 @@ class VideoWriter:
         # got `Optional[str]`.
         regexp = os.path.join(self.cache_dir, self.regexp)
 
-        if shutil.which(self.ffmpeg_bin) is None:
-            raise ValueError(
-                f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
-                + "Please set FFMPEG in the environment or ffmpeg_bin on this class."
-            )
-
         if self.output_format == "visdom":  # works for ppt too
-            args = [
-                self.ffmpeg_bin,
-                "-r",
-                str(self.fps),
-                "-i",
-                regexp,
-                "-vcodec",
-                "h264",
-                "-f",
-                "mp4",
-                "-y",
-                "-crf",
-                "18",
-                "-b",
-                "2000k",
-                "-pix_fmt",
-                "yuv420p",
-                self.out_path,
-            ]
-            if quiet:
-                subprocess.check_call(
-                    args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
-                )
-            else:
-                subprocess.check_call(args)
+            # Video codec parameters
+            video_codec = "h264"
+            crf = "18"
+            b = "2000k"
+            pix_fmt = "yuv420p"
+
+            if self.use_torchvision_video_writer:
+                torchvision.io.write_video(
+                    self.out_path,
+                    torch.stack(
+                        [torch.from_numpy(np.array(Image.open(f))) for f in self.frames]
+                    ),
+                    fps=self.fps,
+                    video_codec=video_codec,
+                    options={"crf": crf, "b": b, "pix_fmt": pix_fmt},
+                )
+            else:
+                if shutil.which(self.ffmpeg_bin) is None:
+                    raise ValueError(
+                        f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
+                        + "Please set FFMPEG in the environment or ffmpeg_bin on this class."
+                    )
+
+                args = [
+                    self.ffmpeg_bin,
+                    "-r",
+                    str(self.fps),
+                    "-i",
+                    regexp,
+                    "-vcodec",
+                    video_codec,
+                    "-f",
+                    "mp4",
+                    "-y",
+                    "-crf",
+                    crf,
+                    "-b",
+                    b,
+                    "-pix_fmt",
+                    pix_fmt,
+                    self.out_path,
+                ]
+                if quiet:
+                    subprocess.check_call(
+                        args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
+                    )
+                else:
+                    subprocess.check_call(args)
         else:
             raise ValueError("no such output type %s" % str(self.output_format))
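
A hedged usage sketch of the new writer option (not part of the diff; the module path is assumed to be pytorch3d.implicitron.tools.video_writer, and the accepted frame layout of write_frame is an assumption):

import numpy as np
from pytorch3d.implicitron.tools.video_writer import VideoWriter

writer = VideoWriter(out_path="out.mp4", fps=20, use_torchvision_video_writer=True)
for _ in range(10):
    # synthetic uint8 HxWx3 frames; even dimensions keep the h264 encoder happy
    frame = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
    writer.write_frame(frame)
writer.get_video()  # encodes with torchvision.io.write_video instead of shelling out to ffmpeg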