Updates to Implicitron dataset, metrics and tools

Summary: Update PyTorch3D to be able to run assetgen (see later diffs in the stack)

Reviewed By: shapovalov

Differential Revision: D65942513

fbshipit-source-id: 1d01141c9f7e106608fa591be6e0d3262cb5944f
Antoine Toisoul, 2025-01-27 09:43:42 -08:00, committed by Facebook GitHub Bot
parent 42a4a7d432, commit 43cd681d4f
7 changed files with 240 additions and 57 deletions


@ -48,6 +48,7 @@ from pytorch3d.implicitron.dataset.utils import (
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
@ -158,7 +159,7 @@ class FrameData(Mapping[str, Any]):
new_params = {}
for field_name in iter(self):
value = getattr(self, field_name)
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
new_params[field_name] = value.to(*args, **kwargs)
else:
new_params[field_name] = value
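# Minimal sketch (not part of this diff): with Meshes added to the
# isinstance check above, mesh-valued FrameData fields now follow .to()
# like tensors, point clouds and cameras do; this is the call it forwards.
verts = torch.rand(4, 3)
faces = torch.tensor([[0, 1, 2], [0, 2, 3]])
mesh_on_device = Meshes(verts=[verts], faces=[faces]).to("cpu")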
@ -420,7 +421,6 @@ class FrameData(Mapping[str, Any]):
for f in fields(elem):
if not f.init:
continue
list_values = override_fields.get(
f.name, [getattr(d, f.name) for d in batch]
)
@ -429,7 +429,7 @@ class FrameData(Mapping[str, Any]):
if all(list_value is not None for list_value in list_values)
else None
)
return cls(**collated)
return type(elem)(**collated)
elif isinstance(elem, Pointclouds):
return join_pointclouds_as_batch(batch)
@ -437,6 +437,8 @@ class FrameData(Mapping[str, Any]):
elif isinstance(elem, CamerasBase):
# TODO: don't store K; enforce working in NDC space
return join_cameras_as_batch(batch)
elif isinstance(elem, Meshes):
return join_meshes_as_batch(batch)
else:
return torch.utils.data.dataloader.default_collate(batch)
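# Minimal sketch: Meshes elements in a batch are now collated by
# concatenation, mirroring the existing Pointclouds path above.
m1 = Meshes(verts=[torch.rand(4, 3)], faces=[torch.tensor([[0, 1, 2]])])
m2 = Meshes(verts=[torch.rand(3, 3)], faces=[torch.tensor([[0, 1, 2]])])
batched = join_meshes_as_batch([m1, m2])  # len(batched) == 2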
@ -592,6 +594,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
fg_mask_np: np.ndarray | None = None
bbox_xywh: tuple[float, float, float, float] | None = None
mask_annotation = frame_annotation.mask
if mask_annotation is not None:
if load_blobs and self.load_masks:
fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)


@ -10,6 +10,7 @@ import hashlib
import json
import logging
import os
import urllib
from dataclasses import dataclass, Field, field
from typing import (
@ -37,12 +38,13 @@ from pytorch3d.implicitron.dataset.frame_data import (
FrameDataBuilder, # noqa
FrameDataBuilderBase,
)
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from sqlalchemy.orm import Session
from sqlalchemy.orm import scoped_session, Session, sessionmaker
from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
@ -91,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
engine verbatim. Don't expose it to end users of your application!
pick_categories: Restrict the dataset to the given list of categories.
pick_sequences: A Sequence of sequence names to restrict the dataset to.
pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
exclude_sequences: A Sequence of the names of the sequences to exclude.
limit_sequences_per_category_to: Limit the dataset to the first up to N
sequences within each category (applies after all other sequence filters
@ -105,6 +108,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
more frames than that; applied after other frame-level filters.
seed: The seed of the random generator sampling `n_frames_per_sequence`
random frames per sequence.
preload_metadata: If True, the metadata is preloaded into memory.
precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
scoped_session: If True, allows different parts of the code to share
a global session to access the database.
"""
frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
@ -123,6 +130,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
pick_categories: Tuple[str, ...] = ()
pick_sequences: Tuple[str, ...] = ()
pick_sequences_sql_clause: Optional[str] = None
exclude_sequences: Tuple[str, ...] = ()
limit_sequences_per_category_to: int = 0
limit_sequences_to: int = 0
@ -130,6 +138,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
n_frames_per_sequence: int = -1
seed: int = 0
remove_empty_masks_poll_whole_table_threshold: int = 300_000
preload_metadata: bool = False
precompute_seq_to_idx: bool = False
# we set it manually in the constructor
_index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
_sql_engine: sa.engine.Engine = field(
@ -142,6 +152,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
frame_data_builder: FrameDataBuilderBase # pyre-ignore[13]
frame_data_builder_class_type: str = "FrameDataBuilder"
scoped_session: bool = False
def __post_init__(self) -> None:
if sa.__version__ < "2.0":
raise ImportError("This class requires SQLAlchemy 2.0 or later")
@ -169,6 +181,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
)
if self.preload_metadata:
self._sql_engine = self._preload_database(self._sql_engine)
sequences = self._get_filtered_sequences_if_any()
if self.subsets:
@ -192,6 +207,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
logger.info(str(self))
if self.scoped_session:
self._session_factory = sessionmaker(bind=self._sql_engine) # pyre-ignore
if self.precompute_seq_to_idx:
# This option is deprecated and will be removed once we backport
# https://github.com/facebookresearch/uco3d/pull/3
logger.warning(
"Using precompute_seq_to_idx is deprecated and will be removed in the future."
)
self._index["rowid"] = np.arange(len(self._index))
groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
self._seq_to_indices = dict(groupby.apply(list)) # pyre-ignore
del self._index["rowid"]
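# Standalone sketch of the grouping above (toy frame, not the real index):
# build the sequence-name -> dataset-row-indices mapping with a pandas
# groupby, exactly as done for self._seq_to_indices.
import numpy as np
import pandas as pd
toy = pd.DataFrame({"sequence_name": ["a", "a", "b"]})
toy["rowid"] = np.arange(len(toy))
seq_to_indices = dict(toy.groupby("sequence_name", sort=False)["rowid"].apply(list))
assert seq_to_indices == {"a": [0, 1], "b": [2]}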
def __len__(self) -> int:
return len(self._index)
@ -252,9 +281,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
seq_stmt = sa.select(self.sequence_annotations_type).where(
self.sequence_annotations_type.sequence_name == seq
)
with Session(self._sql_engine) as session:
entry = session.scalars(stmt).one()
seq_metadata = session.scalars(seq_stmt).one()
if self.scoped_session:
# pyre-ignore
with scoped_session(self._session_factory)() as session:
entry = session.scalars(stmt).one()
seq_metadata = session.scalars(seq_stmt).one()
else:
with Session(self._sql_engine) as session:
entry = session.scalars(stmt).one()
seq_metadata = session.scalars(seq_stmt).one()
assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
@ -363,6 +398,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
yield from index_slice.itertuples(index=False)
# override
def sequence_indices_in_order(
self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
) -> Iterator[int]:
"""Same as `sequence_frames_in_order` but returns the iterator over
only dataset indices.
"""
if self.precompute_seq_to_idx and subset_filter is None:
# pyre-ignore
yield from self._seq_to_indices[seq_name]
else:
for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
yield idx
# override
def get_eval_batches(self) -> Optional[List[Any]]:
"""
@ -396,11 +445,35 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
or self.limit_sequences_to > 0
or self.limit_sequences_per_category_to > 0
or len(self.pick_sequences) > 0
or self.pick_sequences_sql_clause is not None
or len(self.exclude_sequences) > 0
or len(self.pick_categories) > 0
or self.n_frames_per_sequence > 0
)
def _preload_database(
self, source_engine: sa.engine.base.Engine
) -> sa.engine.base.Engine:
destination_engine = sa.create_engine("sqlite:///:memory:")
metadata = sa.MetaData()
metadata.reflect(bind=source_engine)
metadata.create_all(bind=destination_engine)
with source_engine.connect() as source_conn:
with destination_engine.connect() as destination_conn:
for table_obj in metadata.tables.values():
# Select all rows from the source table
source_rows = source_conn.execute(table_obj.select())
# Insert rows into the destination table
for row in source_rows:
destination_conn.execute(table_obj.insert().values(row))
# Commit the changes for each table
destination_conn.commit()
return destination_engine
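# Design note, as a sketch: when the source is a plain SQLite file, the
# standard-library backup API achieves the same in-memory preload in a
# single call (the path below is hypothetical).
import sqlite3
src = sqlite3.connect("metadata.sqlite")
dst = sqlite3.connect(":memory:")
src.backup(dst)  # copies all tables and indices into memory at once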
def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
# maximum possible filter (if limit_sequences_per_category_to == 0):
# WHERE category IN 'self.pick_categories'
@ -413,6 +486,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
*self._get_pick_filters(),
*self._get_exclude_filters(),
]
if self.pick_sequences_sql_clause:
print("Applying the custom SQL clause.")
where_conditions.append(sa.text(self.pick_sequences_sql_clause))
def add_where(stmt):
return stmt.where(*where_conditions) if where_conditions else stmt
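# Hedged usage sketch: the clause is injected verbatim into the WHERE list
# via sa.text(), so it must be valid SQL against the sequence-annotation
# table and must never come from untrusted input. The column names below
# are real; the values are illustrative.
dataset_args = dict(
    pick_sequences_sql_clause="category = 'chair' AND sequence_name LIKE 'train%'",
)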
@ -749,9 +825,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
self.frame_annotations_type.sequence_name == seq_name,
self.frame_annotations_type.frame_number.in_(frames),
)
frame_no_ts = None
with self._sql_engine.connect() as connection:
frame_no_ts = pd.read_sql_query(stmt, connection)
if self.scoped_session:
stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
with scoped_session(self._session_factory)() as session: # pyre-ignore
frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
else:
with self._sql_engine.connect() as connection:
frame_no_ts = pd.read_sql_query(stmt, connection)
if len(frame_no_ts) != len(index_slice):
raise ValueError(


@ -284,8 +284,14 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
logger.info(f"Val dataset: {str(val_dataset)}")
logger.debug("Extracting test dataset.")
eval_batches_file = self._get_lists_file("eval_batches")
del common_dataset_kwargs["eval_batches_file"]
if self.eval_batches_path is None:
eval_batches_file = None
else:
eval_batches_file = self._get_lists_file("eval_batches")
if "eval_batches_file" in common_dataset_kwargs:
common_dataset_kwargs.pop("eval_batches_file", None)
test_dataset = dataset_type(
**common_dataset_kwargs,
subsets=self._get_subsets(self.test_subsets, True),


@ -211,14 +211,21 @@ def resize_image(
if isinstance(image, np.ndarray):
image = torch.from_numpy(image)
if image_height is None or image_width is None:
if (
image_height is None
or image_width is None
or image.shape[-2] == 0
or image.shape[-1] == 0
):
# skip the resizing
return image, 1.0, torch.ones_like(image[:1])
# takes numpy array or tensor, returns pytorch tensor
minscale = min(
image_height / image.shape[-2],
image_width / image.shape[-1],
)
imre = torch.nn.functional.interpolate(
image[None],
scale_factor=minscale,
@ -226,6 +233,7 @@ def resize_image(
align_corners=False if mode == "bilinear" else None,
recompute_scale_factor=True,
)[0]
imre_ = torch.zeros(image.shape[0], image_height, image_width)
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
mask = torch.zeros(1, image_height, image_width)
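# Worked sketch of the resize path above (toy numbers, assuming the
# (image, image_height, image_width) signature shown in this file): a
# 3x100x200 image fit into 80x80 gives minscale = min(80/100, 80/200) = 0.4,
# i.e. a 40x80 result zero-padded to 80x80, with the mask flagging the
# valid rows.
img = torch.rand(3, 100, 200)
out, scale, valid = resize_image(img, image_height=80, image_width=80)
# out.shape == (3, 80, 80); scale == 0.4; valid[:, :40, :] is all ones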
@ -238,20 +246,21 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
return im.astype(np.float32) / 255.0
def load_image(path: str, try_read_alpha: bool = False) -> np.ndarray:
def load_image(
path: str, try_read_alpha: bool = False, pil_format: str = "RGB"
) -> np.ndarray:
"""
Load an image from a path and return it as a numpy array.
If try_read_alpha is True, the image is read as RGBA and the alpha channel is
returned as the fourth channel.
Otherwise, the image is converted to the mode given by `pil_format` ("RGB"
by default) and the corresponding number of channels is returned.
"""
with Image.open(path) as pil_im:
# Check if the image has an alpha channel
if try_read_alpha and pil_im.mode == "RGBA":
im = np.array(pil_im)
else:
im = np.array(pil_im.convert("RGB"))
im = np.array(pil_im.convert(pil_format))
return transpose_normalize_image(im)
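# Usage sketch of the new pil_format parameter (path hypothetical): read a
# grayscale image as a single channel by passing PIL mode "L".
gray = load_image("frame.png", pil_format="L")  # float32 CHW in [0, 1]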
@ -389,7 +398,7 @@ def adjust_camera_to_image_scale_(
)
camera.focal_length = focal_length_scaled[None]
# pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
camera.principal_point = principal_point_scaled[None]
camera.principal_point = principal_point_scaled[None] # pyre-ignore[16]
# NOTE this cache is per-worker; workers are implemented as processes.


@ -6,7 +6,6 @@
# pyre-unsafe
import warnings
from typing import Any, Dict, Optional
@ -298,9 +297,8 @@ class ViewMetrics(ViewMetricsBase):
_rgb_metrics(
image_rgb,
image_rgb_pred,
fg_probability,
fg_probability_pred,
mask_crop,
masks=fg_probability,
masks_crop=mask_crop,
)
)
@ -310,9 +308,21 @@ class ViewMetrics(ViewMetricsBase):
metrics["mask_neg_iou"] = utils.neg_iou_loss(
fg_probability_pred, fg_probability, mask=mask_crop
)
metrics["mask_bce"] = utils.calc_bce(
fg_probability_pred, fg_probability, mask=mask_crop
)
if torch.is_autocast_enabled():
# Under autocast, convert probabilities back to logits and use the
# numerically stable BCE-with-logits path.
metrics["mask_bce"] = utils.calc_bce(
fg_probability_pred.logit(),
fg_probability,
mask=mask_crop,
pred_logits=True,
)
else:
metrics["mask_bce"] = utils.calc_bce(
fg_probability_pred,
fg_probability,
mask=mask_crop,
pred_logits=False,
)
if depth_map is not None and depth_map_pred is not None:
assert mask_crop is not None
@ -324,7 +334,11 @@ class ViewMetrics(ViewMetricsBase):
if fg_probability is not None:
mask = fg_probability * mask_crop
_, abs_ = utils.eval_depth(
depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
depth_map_pred,
depth_map,
get_best_scale=True,
mask=mask,
crop=0,
)
metrics["depth_abs_fg"] = abs_.mean()
@ -346,18 +360,26 @@ class ViewMetrics(ViewMetricsBase):
return metrics
def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
def _rgb_metrics(
images,
images_pred,
masks=None,
masks_crop=None,
huber_scaling: float = 0.03,
):
assert masks_crop is not None
if images.shape[1] != images_pred.shape[1]:
raise ValueError(
f"Network output's RGB images had {images_pred.shape[1]} "
f"channels. {images.shape[1]} expected."
)
rgb_abs = ((images_pred - images).abs()).mean(dim=1, keepdim=True)
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
rgb_loss = utils.huber(rgb_squared, scaling=huber_scaling)
crop_mass = masks_crop.sum().clamp(1.0)
results = {
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
"rgb_l1": (rgb_abs * masks_crop).sum() / crop_mass,
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
}


@ -6,12 +6,15 @@
# pyre-unsafe
import logging
import math
from typing import Optional, Tuple
import torch
from torch.nn import functional as F
logger = logging.getLogger(__name__)
def eval_depth(
pred: torch.Tensor,
@ -21,6 +24,8 @@ def eval_depth(
get_best_scale: bool = True,
mask_thr: float = 0.5,
best_scale_clamp_thr: float = 1e-4,
use_disparity: bool = False,
disparity_eps: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Evaluate the depth error between the prediction `pred` and the ground
@ -64,6 +69,13 @@ def eval_depth(
# s.t. we get best possible mse error
scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr)
pred = pred * scale_best[:, None, None, None]
if use_disparity:
gt = torch.div(1.0, (gt + disparity_eps))
pred = torch.div(1.0, (pred + disparity_eps))
scale_best = estimate_depth_scale_factor(
pred, gt, dmask, best_scale_clamp_thr
).detach()
pred = pred * scale_best[:, None, None, None]
df = gt - pred
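# Worked sketch of the use_disparity branch (toy tensors, no masking):
# depths become disparities 1/(d + eps), and a fresh scale factor realigns
# the prediction, so a constant prediction off by a scale yields ~zero error.
gt = torch.full((1, 1, 8, 8), 2.0)
pred = torch.full((1, 1, 8, 8), 4.0)
mse, abs_ = eval_depth(pred, gt, crop=0, get_best_scale=True, use_disparity=True)
# both errors are ~0: the scale factors absorb the constant offset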
@ -117,6 +129,7 @@ def calc_bce(
pred_eps: float = 0.01,
mask: Optional[torch.Tensor] = None,
lerp_bound: Optional[float] = None,
pred_logits: bool = False,
) -> torch.Tensor:
"""
Calculates the binary cross entropy.
@ -139,9 +152,23 @@ def calc_bce(
weight = torch.ones_like(gt) * mask
if lerp_bound is not None:
# binary_cross_entropy_lerp requires pred to be in [0, 1]
if pred_logits:
pred = torch.sigmoid(pred)
return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound)
else:
return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight)
if pred_logits:
loss = F.binary_cross_entropy_with_logits(
pred,
gt,
reduction="none",
weight=weight,
)
else:
loss = F.binary_cross_entropy(pred, gt, reduction="none", weight=weight)
return loss.mean()
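# Equivalence sketch for the new pred_logits flag (toy tensors, using this
# module's torch/F imports): BCE on raw logits matches BCE on sigmoid
# probabilities, but the fused log-sum-exp in
# binary_cross_entropy_with_logits stays finite under autocast.
logits = torch.tensor([[2.0, -1.0]])
gt01 = torch.tensor([[1.0, 0.0]])
a = F.binary_cross_entropy_with_logits(logits, gt01, reduction="none").mean()
b = F.binary_cross_entropy(torch.sigmoid(logits), gt01, reduction="none").mean()
assert torch.allclose(a, b)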
def binary_cross_entropy_lerp(


@ -16,8 +16,17 @@ from typing import Optional, Tuple, Union
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
_NO_TORCHVISION = False
try:
import torchvision
except ImportError:
_NO_TORCHVISION = True
_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
matplotlib.use("Agg")
@ -36,6 +45,7 @@ class VideoWriter:
fps: int = 20,
output_format: str = "visdom",
rmdir_allowed: bool = False,
use_torchvision_video_writer: bool = False,
**kwargs,
) -> None:
"""
@ -49,6 +59,8 @@ class VideoWriter:
is supported.
rmdir_allowed: If `True` delete and create `cache_dir` in case
it is not empty.
use_torchvision_video_writer: If `True` use `torchvision.io.write_video`
to write the video
"""
self.rmdir_allowed = rmdir_allowed
self.output_format = output_format
@ -56,10 +68,14 @@ class VideoWriter:
self.out_path = out_path
self.cache_dir = cache_dir
self.ffmpeg_bin = ffmpeg_bin
self.use_torchvision_video_writer = use_torchvision_video_writer
self.frames = []
self.regexp = "frame_%08d.png"
self.frame_num = 0
if self.use_torchvision_video_writer:
assert not _NO_TORCHVISION, "torchvision not available"
if self.cache_dir is not None:
self.tmp_dir = None
if os.path.isdir(self.cache_dir):
@ -114,7 +130,7 @@ class VideoWriter:
resize = im.size
# make sure size is divisible by 2
resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
im = im.resize(resize, Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10
im.save(outfile)
@ -139,38 +155,56 @@ class VideoWriter:
# got `Optional[str]`.
regexp = os.path.join(self.cache_dir, self.regexp)
if shutil.which(self.ffmpeg_bin) is None:
raise ValueError(
f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
+ "Please set FFMPEG in the environment or ffmpeg_bin on this class."
)
if self.output_format == "visdom": # works for ppt too
args = [
self.ffmpeg_bin,
"-r",
str(self.fps),
"-i",
regexp,
"-vcodec",
"h264",
"-f",
"mp4",
"-y",
"-crf",
"18",
"-b",
"2000k",
"-pix_fmt",
"yuv420p",
self.out_path,
]
if quiet:
subprocess.check_call(
args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
# Video codec parameters
video_codec = "h264"
crf = "18"
b = "2000k"
pix_fmt = "yuv420p"
if self.use_torchvision_video_writer:
torchvision.io.write_video(
self.out_path,
torch.stack(
[torch.from_numpy(np.array(Image.open(f))) for f in self.frames]
),
fps=self.fps,
video_codec=video_codec,
options={"crf": crf, "b": b, "pix_fmt": pix_fmt},
)
else:
subprocess.check_call(args)
if shutil.which(self.ffmpeg_bin) is None:
raise ValueError(
f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
+ "Please set FFMPEG in the environment or ffmpeg_bin on this class."
)
args = [
self.ffmpeg_bin,
"-r",
str(self.fps),
"-i",
regexp,
"-vcodec",
video_codec,
"-f",
"mp4",
"-y",
"-crf",
crf,
"-b",
b,
"-pix_fmt",
pix_fmt,
self.out_path,
]
if quiet:
subprocess.check_call(
args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
else:
subprocess.check_call(args)
else:
raise ValueError("no such output type %s" % str(self.output_format))