mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00

Updates to Implicitron dataset, metrics and tools

Summary: Update PyTorch3D to be able to run assetgen (see later diffs in the stack).
Reviewed By: shapovalov
Differential Revision: D65942513
fbshipit-source-id: 1d01141c9f7e106608fa591be6e0d3262cb5944f

This commit is contained in:
parent 42a4a7d432
commit 43cd681d4f
@@ -48,6 +48,7 @@ from pytorch3d.implicitron.dataset.utils import (
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
+from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
 from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds

 FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
@@ -158,7 +159,7 @@ class FrameData(Mapping[str, Any]):
         new_params = {}
         for field_name in iter(self):
             value = getattr(self, field_name)
-            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
+            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
                 new_params[field_name] = value.to(*args, **kwargs)
             else:
                 new_params[field_name] = value
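Note: the `to` change relies on `Meshes` exposing the same tensor-style `to(...)` as `Pointclouds` and the camera classes. A minimal sketch with toy geometry:

import torch
from pytorch3d.structures import Meshes

# one triangle; Meshes moves its packed tensors like a torch.Tensor would
mesh = Meshes(
    verts=[torch.rand(3, 3)],
    faces=[torch.tensor([[0, 1, 2]])],
)
print(mesh.to("cpu").device)  # cpu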
@@ -420,7 +421,6 @@ class FrameData(Mapping[str, Any]):
             for f in fields(elem):
                 if not f.init:
                     continue
-
                 list_values = override_fields.get(
                     f.name, [getattr(d, f.name) for d in batch]
                 )
@@ -429,7 +429,7 @@ class FrameData(Mapping[str, Any]):
                     if all(list_value is not None for list_value in list_values)
                     else None
                 )
-            return cls(**collated)
+            return type(elem)(**collated)

         elif isinstance(elem, Pointclouds):
             return join_pointclouds_as_batch(batch)
@@ -437,6 +437,8 @@ class FrameData(Mapping[str, Any]):
         elif isinstance(elem, CamerasBase):
             # TODO: don't store K; enforce working in NDC space
             return join_cameras_as_batch(batch)
+        elif isinstance(elem, Meshes):
+            return join_meshes_as_batch(batch)
         else:
             return torch.utils.data.dataloader.default_collate(batch)

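Note: `join_meshes_as_batch` concatenates a list of `Meshes` into one batched `Meshes`, mirroring what `join_pointclouds_as_batch` does for point clouds; a quick sketch:

import torch
from pytorch3d.structures import Meshes, join_meshes_as_batch

faces = torch.tensor([[0, 1, 2]])
m1 = Meshes(verts=[torch.rand(3, 3)], faces=[faces])
m2 = Meshes(verts=[torch.rand(4, 3)], faces=[faces])  # heterogeneous sizes are fine
batch = join_meshes_as_batch([m1, m2])
print(len(batch))  # 2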
@@ -592,6 +594,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
         fg_mask_np: np.ndarray | None = None
         bbox_xywh: tuple[float, float, float, float] | None = None
         mask_annotation = frame_annotation.mask
+
         if mask_annotation is not None:
             if load_blobs and self.load_masks:
                 fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
@@ -10,6 +10,7 @@ import hashlib
 import json
 import logging
 import os
+
 import urllib
 from dataclasses import dataclass, Field, field
 from typing import (
@@ -37,12 +38,13 @@ from pytorch3d.implicitron.dataset.frame_data import (
     FrameDataBuilder,  # noqa
     FrameDataBuilderBase,
 )
+
 from pytorch3d.implicitron.tools.config import (
     registry,
     ReplaceableBase,
     run_auto_creation,
 )
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import scoped_session, Session, sessionmaker

 from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation

@@ -91,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             engine verbatim. Don't expose it to end users of your application!
         pick_categories: Restrict the dataset to the given list of categories.
         pick_sequences: A Sequence of sequence names to restrict the dataset to.
+        pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
         exclude_sequences: A Sequence of the names of the sequences to exclude.
         limit_sequences_per_category_to: Limit the dataset to the first up to N
             sequences within each category (applies after all other sequence filters
@@ -105,6 +108,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             more frames than that; applied after other frame-level filters.
         seed: The seed of the random generator sampling `n_frames_per_sequence`
             random frames per sequence.
+        preload_metadata: If True, the metadata is preloaded into memory.
+        precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
+        scoped_session: If True, allows different parts of the code to share
+            a global session to access the database.
     """

     frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
@@ -123,6 +130,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     pick_categories: Tuple[str, ...] = ()

     pick_sequences: Tuple[str, ...] = ()
+    pick_sequences_sql_clause: Optional[str] = None
     exclude_sequences: Tuple[str, ...] = ()
     limit_sequences_per_category_to: int = 0
     limit_sequences_to: int = 0
@@ -130,6 +138,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     n_frames_per_sequence: int = -1
     seed: int = 0
     remove_empty_masks_poll_whole_table_threshold: int = 300_000
+    preload_metadata: bool = False
+    precompute_seq_to_idx: bool = False
     # we set it manually in the constructor
     _index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
     _sql_engine: sa.engine.Engine = field(
@@ -142,6 +152,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     frame_data_builder: FrameDataBuilderBase  # pyre-ignore[13]
     frame_data_builder_class_type: str = "FrameDataBuilder"

+    scoped_session: bool = False
+
     def __post_init__(self) -> None:
         if sa.__version__ < "2.0":
             raise ImportError("This class requires SQL Alchemy 2.0 or later")
@@ -169,6 +181,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
         )

+        if self.preload_metadata:
+            self._sql_engine = self._preload_database(self._sql_engine)
+
         sequences = self._get_filtered_sequences_if_any()

         if self.subsets:
@@ -192,6 +207,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):

         logger.info(str(self))

+        if self.scoped_session:
+            self._session_factory = sessionmaker(bind=self._sql_engine)  # pyre-ignore
+
+        if self.precompute_seq_to_idx:
+            # This is deprecated and will be removed in the future.
+            # After we backport https://github.com/facebookresearch/uco3d/pull/3
+            logger.warning(
+                "Using precompute_seq_to_idx is deprecated and will be removed in the future."
+            )
+            self._index["rowid"] = np.arange(len(self._index))
+            groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
+            self._seq_to_indices = dict(groupby.apply(list))  # pyre-ignore
+            del self._index["rowid"]
+
     def __len__(self) -> int:
         return len(self._index)

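Note: the deprecated `precompute_seq_to_idx` branch builds its sequence-name-to-row-indices map with a pandas groupby over a temporary `rowid` column; the same idiom in isolation (hypothetical sequence names):

import numpy as np
import pandas as pd

index = pd.DataFrame({"sequence_name": ["seq_a", "seq_a", "seq_b"]})
index["rowid"] = np.arange(len(index))  # positional dataset indices
groupby = index.groupby("sequence_name", sort=False)["rowid"]
seq_to_indices = dict(groupby.apply(list))
del index["rowid"]
print(seq_to_indices)  # {'seq_a': [0, 1], 'seq_b': [2]}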
@@ -252,9 +281,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         seq_stmt = sa.select(self.sequence_annotations_type).where(
             self.sequence_annotations_type.sequence_name == seq
         )
-        with Session(self._sql_engine) as session:
-            entry = session.scalars(stmt).one()
-            seq_metadata = session.scalars(seq_stmt).one()
+        if self.scoped_session:
+            # pyre-ignore
+            with scoped_session(self._session_factory)() as session:
+                entry = session.scalars(stmt).one()
+                seq_metadata = session.scalars(seq_stmt).one()
+        else:
+            with Session(self._sql_engine) as session:
+                entry = session.scalars(stmt).one()
+                seq_metadata = session.scalars(seq_stmt).one()

         assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]

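Note: `scoped_session` wraps the `sessionmaker` factory in a registry that hands each thread its own session, which is what lets different parts of the code share one global entry point to the database. A minimal sketch against a throwaway in-memory engine:

import sqlalchemy as sa
from sqlalchemy.orm import scoped_session, sessionmaker

engine = sa.create_engine("sqlite:///:memory:")
session_factory = sessionmaker(bind=engine)

# calling the registry returns the current thread's session
with scoped_session(session_factory)() as session:
    print(session.execute(sa.text("SELECT 1")).scalar())  # 1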
@@ -363,6 +398,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):

         yield from index_slice.itertuples(index=False)

+    # override
+    def sequence_indices_in_order(
+        self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
+    ) -> Iterator[int]:
+        """Same as `sequence_frames_in_order` but returns the iterator over
+        only dataset indices.
+        """
+        if self.precompute_seq_to_idx and subset_filter is None:
+            # pyre-ignore
+            yield from self._seq_to_indices[seq_name]
+        else:
+            for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
+                yield idx
+
     # override
     def get_eval_batches(self) -> Optional[List[Any]]:
         """
@@ -396,11 +445,35 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             or self.limit_sequences_to > 0
             or self.limit_sequences_per_category_to > 0
             or len(self.pick_sequences) > 0
+            or self.pick_sequences_sql_clause is not None
             or len(self.exclude_sequences) > 0
             or len(self.pick_categories) > 0
             or self.n_frames_per_sequence > 0
         )

+    def _preload_database(
+        self, source_engine: sa.engine.base.Engine
+    ) -> sa.engine.base.Engine:
+        destination_engine = sa.create_engine("sqlite:///:memory:")
+        metadata = sa.MetaData()
+        metadata.reflect(bind=source_engine)
+        metadata.create_all(bind=destination_engine)
+
+        with source_engine.connect() as source_conn:
+            with destination_engine.connect() as destination_conn:
+                for table_obj in metadata.tables.values():
+                    # Select all rows from the source table
+                    source_rows = source_conn.execute(table_obj.select())
+
+                    # Insert rows into the destination table
+                    for row in source_rows:
+                        destination_conn.execute(table_obj.insert().values(row))
+
+                    # Commit the changes for each table
+                    destination_conn.commit()
+
+        return destination_engine
+
     def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
         # maximum possible filter (if limit_sequences_per_category_to == 0):
         # WHERE category IN 'self.pick_categories'
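Note: `_preload_database` clones the read-only on-disk SQLite file into a `sqlite:///:memory:` engine by reflecting the schema and copying rows table by table. The same technique, self-contained (the table here is a stand-in, not the real schema):

import sqlalchemy as sa

source = sa.create_engine("sqlite:///:memory:")  # stands in for the on-disk DB
with source.connect() as conn:
    conn.execute(sa.text("CREATE TABLE frame_annots (frame_number INTEGER)"))
    conn.execute(sa.text("INSERT INTO frame_annots VALUES (0), (1)"))
    conn.commit()

destination = sa.create_engine("sqlite:///:memory:")
metadata = sa.MetaData()
metadata.reflect(bind=source)          # discover table definitions
metadata.create_all(bind=destination)  # recreate them in memory
with source.connect() as src, destination.connect() as dst:
    for table in metadata.tables.values():
        for row in src.execute(table.select()):
            dst.execute(table.insert().values(row))
        dst.commit()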
@@ -413,6 +486,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             *self._get_pick_filters(),
             *self._get_exclude_filters(),
         ]
+        if self.pick_sequences_sql_clause:
+            print("Applying the custom SQL clause.")
+            where_conditions.append(sa.text(self.pick_sequences_sql_clause))

         def add_where(stmt):
             return stmt.where(*where_conditions) if where_conditions else stmt
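Note: because the clause goes through `sa.text`, it is appended verbatim as a WHERE condition on the sequence-annotation query (hence the docstring's warning about trusted input). A sketch of how such a clause composes (illustrative table and column names):

import sqlalchemy as sa

seq_annots = sa.Table(
    "sequence_annots",
    sa.MetaData(),
    sa.Column("sequence_name", sa.String),
    sa.Column("category", sa.String),
)
clause = "category LIKE 'chair%'"  # would come from pick_sequences_sql_clause
stmt = sa.select(seq_annots.c.sequence_name).where(sa.text(clause))
print(stmt)  # SELECT ... FROM sequence_annots WHERE category LIKE 'chair%'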
@@ -749,9 +825,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             self.frame_annotations_type.sequence_name == seq_name,
             self.frame_annotations_type.frame_number.in_(frames),
         )
+        frame_no_ts = None

-        with self._sql_engine.connect() as connection:
-            frame_no_ts = pd.read_sql_query(stmt, connection)
+        if self.scoped_session:
+            stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
+            with scoped_session(self._session_factory)() as session:  # pyre-ignore
+                frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
+        else:
+            with self._sql_engine.connect() as connection:
+                frame_no_ts = pd.read_sql_query(stmt, connection)

         if len(frame_no_ts) != len(index_slice):
             raise ValueError(
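Note: the scoped-session branch renders the statement to a literal SQL string before handing it to pandas, presumably so `pd.read_sql_query` receives plain SQL; `literal_binds` inlines the bound parameter values:

import sqlalchemy as sa

t = sa.Table("frame_annots", sa.MetaData(), sa.Column("frame_number", sa.Integer))
stmt = sa.select(t).where(t.c.frame_number.in_([4, 8]))
stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
print(stmt_text)  # ... WHERE frame_annots.frame_number IN (4, 8)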
@@ -284,8 +284,14 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
         logger.info(f"Val dataset: {str(val_dataset)}")

         logger.debug("Extracting test dataset.")
-        eval_batches_file = self._get_lists_file("eval_batches")
-        del common_dataset_kwargs["eval_batches_file"]
+        if self.eval_batches_path is None:
+            eval_batches_file = None
+        else:
+            eval_batches_file = self._get_lists_file("eval_batches")
+
+        if "eval_batches_file" in common_dataset_kwargs:
+            common_dataset_kwargs.pop("eval_batches_file", None)

         test_dataset = dataset_type(
             **common_dataset_kwargs,
             subsets=self._get_subsets(self.test_subsets, True),
@@ -211,14 +211,21 @@ def resize_image(
     if isinstance(image, np.ndarray):
         image = torch.from_numpy(image)

-    if image_height is None or image_width is None:
+    if (
+        image_height is None
+        or image_width is None
+        or image.shape[-2] == 0
+        or image.shape[-1] == 0
+    ):
         # skip the resizing
         return image, 1.0, torch.ones_like(image[:1])

     # takes numpy array or tensor, returns pytorch tensor
     minscale = min(
         image_height / image.shape[-2],
         image_width / image.shape[-1],
     )

     imre = torch.nn.functional.interpolate(
         image[None],
         scale_factor=minscale,
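Note: the resize keeps the aspect ratio by using the smaller of the two target/source ratios, then pads the result into the target canvas; the widened guard now also skips degenerate zero-sized inputs. The core steps, condensed:

import torch
import torch.nn.functional as F

image = torch.rand(3, 100, 200)  # C, H, W
target_h, target_w = 128, 128
minscale = min(target_h / image.shape[-2], target_w / image.shape[-1])  # 0.64
imre = F.interpolate(
    image[None], scale_factor=minscale, mode="bilinear",
    align_corners=False, recompute_scale_factor=True,
)[0]
out = torch.zeros(image.shape[0], target_h, target_w)
out[:, : imre.shape[1], : imre.shape[2]] = imre  # zero-pad bottom/right
print(out.shape)  # torch.Size([3, 128, 128])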
@@ -226,6 +233,7 @@ def resize_image(
         align_corners=False if mode == "bilinear" else None,
         recompute_scale_factor=True,
     )[0]
+
     imre_ = torch.zeros(image.shape[0], image_height, image_width)
     imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
     mask = torch.zeros(1, image_height, image_width)
@@ -238,20 +246,21 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
     return im.astype(np.float32) / 255.0


-def load_image(path: str, try_read_alpha: bool = False) -> np.ndarray:
+def load_image(
+    path: str, try_read_alpha: bool = False, pil_format: str = "RGB"
+) -> np.ndarray:
     """
     Load an image from a path and return it as a numpy array.
     If try_read_alpha is True, the image is read as RGBA and the alpha channel is
     returned as the fourth channel.
     Otherwise, the image is read as RGB and a three-channel image is returned.
     """

     with Image.open(path) as pil_im:
         # Check if the image has an alpha channel
         if try_read_alpha and pil_im.mode == "RGBA":
             im = np.array(pil_im)
         else:
-            im = np.array(pil_im.convert("RGB"))
+            im = np.array(pil_im.convert(pil_format))

     return transpose_normalize_image(im)

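Note: `pil_format` is handed straight to `PIL.Image.convert`, so callers can now request modes other than RGB; a usage sketch with a synthetic file (the path is illustrative):

import numpy as np
from PIL import Image

Image.fromarray(np.zeros((8, 8, 3), dtype=np.uint8)).save("/tmp/example.png")
with Image.open("/tmp/example.png") as pil_im:
    gray = np.array(pil_im.convert("L"))  # mode "L": single channel
print(gray.shape)  # (8, 8)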
@@ -389,7 +398,7 @@ def adjust_camera_to_image_scale_(
     )
     camera.focal_length = focal_length_scaled[None]
-    # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
-    camera.principal_point = principal_point_scaled[None]
+    camera.principal_point = principal_point_scaled[None]  # pyre-ignore[16]


 # NOTE this cache is per-worker; they are implemented as processes.
@@ -6,7 +6,6 @@

 # pyre-unsafe

-
 import warnings
 from typing import Any, Dict, Optional

@@ -298,9 +297,8 @@ class ViewMetrics(ViewMetricsBase):
             _rgb_metrics(
                 image_rgb,
                 image_rgb_pred,
-                fg_probability,
-                fg_probability_pred,
-                mask_crop,
+                masks=fg_probability,
+                masks_crop=mask_crop,
             )
         )

@@ -310,9 +308,21 @@ class ViewMetrics(ViewMetricsBase):
             metrics["mask_neg_iou"] = utils.neg_iou_loss(
                 fg_probability_pred, fg_probability, mask=mask_crop
             )
-            metrics["mask_bce"] = utils.calc_bce(
-                fg_probability_pred, fg_probability, mask=mask_crop
-            )
+            if torch.is_autocast_enabled():
+                # To avoid issues with mixed precision
+                metrics["mask_bce"] = utils.calc_bce(
+                    fg_probability_pred.logit(),
+                    fg_probability,
+                    mask=mask_crop,
+                    pred_logits=True,
+                )
+            else:
+                metrics["mask_bce"] = utils.calc_bce(
+                    fg_probability_pred,
+                    fg_probability,
+                    mask=mask_crop,
+                    pred_logits=False,
+                )

         if depth_map is not None and depth_map_pred is not None:
             assert mask_crop is not None
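Note: `binary_cross_entropy` is unsafe (and rejected by PyTorch) under autocast because half-precision probabilities can saturate to exactly 0 or 1; converting the prediction back to logits and using the `with_logits` form is numerically stable and mathematically equivalent away from saturation:

import torch
import torch.nn.functional as F

prob = torch.rand(4).clamp(1e-4, 1 - 1e-4)  # predicted probabilities
gt = torch.randint(0, 2, (4,)).float()
a = F.binary_cross_entropy(prob, gt)
b = F.binary_cross_entropy_with_logits(prob.logit(), gt)
print(torch.allclose(a, b, atol=1e-5))  # True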
@@ -324,7 +334,11 @@ class ViewMetrics(ViewMetricsBase):
             if fg_probability is not None:
                 mask = fg_probability * mask_crop
                 _, abs_ = utils.eval_depth(
-                    depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
+                    depth_map_pred,
+                    depth_map,
+                    get_best_scale=True,
+                    mask=mask,
+                    crop=0,
                 )
                 metrics["depth_abs_fg"] = abs_.mean()

@@ -346,18 +360,26 @@ class ViewMetrics(ViewMetricsBase):
     return metrics


-def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
+def _rgb_metrics(
+    images,
+    images_pred,
+    masks=None,
+    masks_crop=None,
+    huber_scaling: float = 0.03,
+):
     assert masks_crop is not None
     if images.shape[1] != images_pred.shape[1]:
         raise ValueError(
             f"Network output's RGB images had {images_pred.shape[1]} "
             f"channels. {images.shape[1]} expected."
         )
     rgb_abs = ((images_pred - images).abs()).mean(dim=1, keepdim=True)
     rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
-    rgb_loss = utils.huber(rgb_squared, scaling=0.03)
+    rgb_loss = utils.huber(rgb_squared, scaling=huber_scaling)
     crop_mass = masks_crop.sum().clamp(1.0)
     results = {
         "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
         "rgb_l1": (rgb_abs * masks_crop).sum() / crop_mass,
         "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
         "rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
     }
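Note: exposing `huber_scaling` lets callers tune the robustness knee. `utils.huber` here consumes the squared residual; a pseudo-Huber of that shape (a sketch of the general form, not necessarily the library's exact implementation):

import torch

def pseudo_huber(dfsq: torch.Tensor, scaling: float = 0.03) -> torch.Tensor:
    # ~quadratic for residuals well below `scaling`, ~linear beyond it
    return (torch.sqrt(1.0 + dfsq / (scaling * scaling)) - 1.0) * scaling

rgb_squared = torch.rand(2, 1, 4, 4) * 1e-4
print(pseudo_huber(rgb_squared).shape)  # torch.Size([2, 1, 4, 4])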
@@ -6,12 +6,15 @@

 # pyre-unsafe

+import logging
 import math
 from typing import Optional, Tuple

 import torch
 from torch.nn import functional as F

+logger = logging.getLogger(__name__)
+

 def eval_depth(
     pred: torch.Tensor,
@@ -21,6 +24,8 @@ def eval_depth(
     get_best_scale: bool = True,
     mask_thr: float = 0.5,
     best_scale_clamp_thr: float = 1e-4,
+    use_disparity: bool = False,
+    disparity_eps: float = 1e-4,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Evaluate the depth error between the prediction `pred` and the ground
@@ -64,6 +69,13 @@ def eval_depth(
         # s.t. we get best possible mse error
         scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr)
         pred = pred * scale_best[:, None, None, None]
+    if use_disparity:
+        gt = torch.div(1.0, (gt + disparity_eps))
+        pred = torch.div(1.0, (pred + disparity_eps))
+        scale_best = estimate_depth_scale_factor(
+            pred, gt, dmask, best_scale_clamp_thr
+        ).detach()
+        pred = pred * scale_best[:, None, None, None]

     df = gt - pred

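Note: disparity-space evaluation inverts both maps before differencing, which emphasizes near-field error; the re-estimated scale then aligns the inverted prediction to the inverted ground truth. A toy version (with a plain least-squares scale standing in for `estimate_depth_scale_factor`, whose internals are not shown in this diff):

import torch

gt = torch.rand(1, 1, 8, 8) + 0.5
pred = 2.0 * gt          # prediction off by a global scale
eps = 1e-4
gt_disp = 1.0 / (gt + eps)
pred_disp = 1.0 / (pred + eps)
scale = (pred_disp * gt_disp).sum() / (pred_disp * pred_disp).sum().clamp(1e-4)
print(float((scale * pred_disp - gt_disp).abs().mean()))  # ~0 after alignment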
@@ -117,6 +129,7 @@ def calc_bce(
     pred_eps: float = 0.01,
     mask: Optional[torch.Tensor] = None,
     lerp_bound: Optional[float] = None,
+    pred_logits: bool = False,
 ) -> torch.Tensor:
     """
     Calculates the binary cross entropy.
@@ -139,9 +152,23 @@ def calc_bce(
         weight = torch.ones_like(gt) * mask

     if lerp_bound is not None:
+        # binary_cross_entropy_lerp requires pred to be in [0, 1]
+        if pred_logits:
+            pred = F.sigmoid(pred)
+
         return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound)
     else:
-        return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight)
+        if pred_logits:
+            loss = F.binary_cross_entropy_with_logits(
+                pred,
+                gt,
+                reduction="none",
+                weight=weight,
+            )
+        else:
+            loss = F.binary_cross_entropy(pred, gt, reduction="none", weight=weight)
+
+        return loss.mean()


 def binary_cross_entropy_lerp(
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from PIL import Image
|
||||
|
||||
_NO_TORCHVISION = False
|
||||
try:
|
||||
import torchvision
|
||||
except ImportError:
|
||||
_NO_TORCHVISION = True
|
||||
|
||||
|
||||
_DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
|
||||
|
||||
matplotlib.use("Agg")
|
||||
@@ -36,6 +45,7 @@ class VideoWriter:
         fps: int = 20,
         output_format: str = "visdom",
         rmdir_allowed: bool = False,
+        use_torchvision_video_writer: bool = False,
         **kwargs,
     ) -> None:
         """
@@ -49,6 +59,8 @@ class VideoWriter:
                 is supported.
             rmdir_allowed: If `True` delete and create `cache_dir` in case
                 it is not empty.
+            use_torchvision_video_writer: If `True` use `torchvision.io.write_video`
+                to write the video
         """
         self.rmdir_allowed = rmdir_allowed
         self.output_format = output_format
@@ -56,10 +68,14 @@ class VideoWriter:
         self.out_path = out_path
         self.cache_dir = cache_dir
         self.ffmpeg_bin = ffmpeg_bin
+        self.use_torchvision_video_writer = use_torchvision_video_writer
         self.frames = []
         self.regexp = "frame_%08d.png"
         self.frame_num = 0

+        if self.use_torchvision_video_writer:
+            assert not _NO_TORCHVISION, "torchvision not available"
+
         if self.cache_dir is not None:
             self.tmp_dir = None
             if os.path.isdir(self.cache_dir):
@@ -114,7 +130,7 @@ class VideoWriter:
             resize = im.size
             # make sure size is divisible by 2
             resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
-            # pyre-fixme[16]: Module `Image` has no attribute `ANTIALIAS`.
+
             im = im.resize(resize, Image.ANTIALIAS)
             im.save(outfile)

@@ -139,38 +155,56 @@ class VideoWriter:
         # got `Optional[str]`.
         regexp = os.path.join(self.cache_dir, self.regexp)

-        if shutil.which(self.ffmpeg_bin) is None:
-            raise ValueError(
-                f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
-                + "Please set FFMPEG in the environment or ffmpeg_bin on this class."
-            )
-
         if self.output_format == "visdom":  # works for ppt too
-            args = [
-                self.ffmpeg_bin,
-                "-r",
-                str(self.fps),
-                "-i",
-                regexp,
-                "-vcodec",
-                "h264",
-                "-f",
-                "mp4",
-                "-y",
-                "-crf",
-                "18",
-                "-b",
-                "2000k",
-                "-pix_fmt",
-                "yuv420p",
-                self.out_path,
-            ]
-            if quiet:
-                subprocess.check_call(
-                    args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
-                )
-            else:
-                subprocess.check_call(args)
+            # Video codec parameters
+            video_codec = "h264"
+            crf = "18"
+            b = "2000k"
+            pix_fmt = "yuv420p"
+
+            if self.use_torchvision_video_writer:
+                torchvision.io.write_video(
+                    self.out_path,
+                    torch.stack(
+                        [torch.from_numpy(np.array(Image.open(f))) for f in self.frames]
+                    ),
+                    fps=self.fps,
+                    video_codec=video_codec,
+                    options={"crf": crf, "b": b, "pix_fmt": pix_fmt},
+                )
+            else:
+                if shutil.which(self.ffmpeg_bin) is None:
+                    raise ValueError(
+                        f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
+                        + "Please set FFMPEG in the environment or ffmpeg_bin on this class."
+                    )
+
+                args = [
+                    self.ffmpeg_bin,
+                    "-r",
+                    str(self.fps),
+                    "-i",
+                    regexp,
+                    "-vcodec",
+                    video_codec,
+                    "-f",
+                    "mp4",
+                    "-y",
+                    "-crf",
+                    crf,
+                    "-b",
+                    b,
+                    "-pix_fmt",
+                    pix_fmt,
+                    self.out_path,
+                ]
+                if quiet:
+                    subprocess.check_call(
+                        args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
+                    )
+                else:
+                    subprocess.check_call(args)
         else:
             raise ValueError("no such output type %s" % str(self.output_format))
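Note: `torchvision.io.write_video` expects a uint8 tensor shaped (T, H, W, C), which is exactly what stacking the opened PIL frames produces; the codec options are passed through as strings. A minimal sketch (hypothetical output path):

import torch
import torchvision

frames = (torch.rand(10, 64, 64, 3) * 255).to(torch.uint8)  # T, H, W, C
torchvision.io.write_video(
    "/tmp/clip.mp4",
    frames,
    fps=20,
    video_codec="h264",
    options={"crf": "18", "pix_fmt": "yuv420p"},
)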