Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-01 03:12:49 +08:00
Updates to Implicitron dataset, metrics and tools
Summary: Update PyTorch3D to be able to run assetgen (see later diffs in the stack)

Reviewed By: shapovalov

Differential Revision: D65942513

fbshipit-source-id: 1d01141c9f7e106608fa591be6e0d3262cb5944f
This commit is contained in:
parent 42a4a7d432
commit 43cd681d4f
pytorch3d/implicitron/dataset/frame_data.py

@@ -48,6 +48,7 @@ from pytorch3d.implicitron.dataset.utils import (
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
 from pytorch3d.renderer.camera_utils import join_cameras_as_batch
 from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
+from pytorch3d.structures.meshes import join_meshes_as_batch, Meshes
 from pytorch3d.structures.pointclouds import join_pointclouds_as_batch, Pointclouds
 
 FrameAnnotationT = types.FrameAnnotation | orm_types.SqlFrameAnnotation
@@ -158,7 +159,7 @@ class FrameData(Mapping[str, Any]):
         new_params = {}
         for field_name in iter(self):
             value = getattr(self, field_name)
-            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
+            if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase, Meshes)):
                 new_params[field_name] = value.to(*args, **kwargs)
             else:
                 new_params[field_name] = value
@@ -420,7 +421,6 @@ class FrameData(Mapping[str, Any]):
             for f in fields(elem):
                 if not f.init:
                     continue
-
                 list_values = override_fields.get(
                     f.name, [getattr(d, f.name) for d in batch]
                 )
@@ -429,7 +429,7 @@ class FrameData(Mapping[str, Any]):
                     if all(list_value is not None for list_value in list_values)
                     else None
                 )
-            return cls(**collated)
+            return type(elem)(**collated)
 
         elif isinstance(elem, Pointclouds):
             return join_pointclouds_as_batch(batch)
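Side note on the `cls` → `type(elem)` change above: `collate` is a classmethod, so `cls(**collated)` always rebuilds the class the method was invoked on, whereas `type(elem)(**collated)` preserves the concrete subclass of the batch elements. A minimal sketch of the difference with hypothetical stand-in classes (not part of the commit):

class Base:
    @classmethod
    def make(cls, elem):
        # type(elem)() keeps the element's own class; cls() would always
        # construct the class the classmethod was invoked on.
        return type(elem)()

class Child(Base):
    pass

assert isinstance(Base.make(Child()), Child)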
@@ -437,6 +437,8 @@ class FrameData(Mapping[str, Any]):
         elif isinstance(elem, CamerasBase):
             # TODO: don't store K; enforce working in NDC space
             return join_cameras_as_batch(batch)
+        elif isinstance(elem, Meshes):
+            return join_meshes_as_batch(batch)
         else:
             return torch.utils.data.dataloader.default_collate(batch)
 
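With the new branch above, mesh-valued fields are collated the same way point clouds and cameras are. A standalone sketch of the underlying call; the `ico_sphere` test meshes are illustrative:

from pytorch3d.structures import join_meshes_as_batch
from pytorch3d.utils import ico_sphere

meshes = [ico_sphere(0), ico_sphere(1)]  # two meshes with different topologies
batch = join_meshes_as_batch(meshes)     # a single Meshes object of length 2
assert len(batch) == 2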
@@ -592,6 +594,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
         fg_mask_np: np.ndarray | None = None
         bbox_xywh: tuple[float, float, float, float] | None = None
         mask_annotation = frame_annotation.mask
+
         if mask_annotation is not None:
             if load_blobs and self.load_masks:
                 fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
pytorch3d/implicitron/dataset/sql_dataset.py

@@ -10,6 +10,7 @@ import hashlib
 import json
 import logging
 import os
+
 import urllib
 from dataclasses import dataclass, Field, field
 from typing import (
@@ -37,12 +38,13 @@ from pytorch3d.implicitron.dataset.frame_data import (
     FrameDataBuilder,  # noqa
     FrameDataBuilderBase,
 )
+
 from pytorch3d.implicitron.tools.config import (
     registry,
     ReplaceableBase,
     run_auto_creation,
 )
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import scoped_session, Session, sessionmaker
 
 from .orm_types import SqlFrameAnnotation, SqlSequenceAnnotation
 
@@ -91,6 +93,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             engine verbatim. Don’t expose it to end users of your application!
         pick_categories: Restrict the dataset to the given list of categories.
         pick_sequences: A Sequence of sequence names to restrict the dataset to.
+        pick_sequences_sql_clause: Custom SQL WHERE clause to constrain sequence annotations.
         exclude_sequences: A Sequence of the names of the sequences to exclude.
         limit_sequences_per_category_to: Limit the dataset to the first up to N
             sequences within each category (applies after all other sequence filters
@@ -105,6 +108,10 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             more frames than that; applied after other frame-level filters.
         seed: The seed of the random generator sampling `n_frames_per_sequence`
             random frames per sequence.
+        preload_metadata: If True, the metadata is preloaded into memory.
+        precompute_seq_to_idx: If True, precomputes the mapping from sequence name to indices.
+        scoped_session: If True, allows different parts of the code to share
+            a global session to access the database.
     """
 
     frame_annotations_type: ClassVar[Type[SqlFrameAnnotation]] = SqlFrameAnnotation
@@ -123,6 +130,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     pick_categories: Tuple[str, ...] = ()
 
     pick_sequences: Tuple[str, ...] = ()
+    pick_sequences_sql_clause: Optional[str] = None
     exclude_sequences: Tuple[str, ...] = ()
     limit_sequences_per_category_to: int = 0
     limit_sequences_to: int = 0
@@ -130,6 +138,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     n_frames_per_sequence: int = -1
     seed: int = 0
     remove_empty_masks_poll_whole_table_threshold: int = 300_000
+    preload_metadata: bool = False
+    precompute_seq_to_idx: bool = False
     # we set it manually in the constructor
     _index: pd.DataFrame = field(init=False, metadata={"omegaconf_ignore": True})
     _sql_engine: sa.engine.Engine = field(
@@ -142,6 +152,8 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
     frame_data_builder: FrameDataBuilderBase  # pyre-ignore[13]
     frame_data_builder_class_type: str = "FrameDataBuilder"
 
+    scoped_session: bool = False
+
     def __post_init__(self) -> None:
         if sa.__version__ < "2.0":
             raise ImportError("This class requires SQL Alchemy 2.0 or later")
@@ -169,6 +181,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             f"sqlite:///file:{urllib.parse.quote(self.sqlite_metadata_file)}?mode=ro&uri=true"
         )
 
+        if self.preload_metadata:
+            self._sql_engine = self._preload_database(self._sql_engine)
+
         sequences = self._get_filtered_sequences_if_any()
 
         if self.subsets:
@@ -192,6 +207,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 
         logger.info(str(self))
 
+        if self.scoped_session:
+            self._session_factory = sessionmaker(bind=self._sql_engine)  # pyre-ignore
+
+        if self.precompute_seq_to_idx:
+            # This is deprecated and will be removed in the future.
+            # After we backport https://github.com/facebookresearch/uco3d/pull/3
+            logger.warning(
+                "Using precompute_seq_to_idx is deprecated and will be removed in the future."
+            )
+            self._index["rowid"] = np.arange(len(self._index))
+            groupby = self._index.groupby("sequence_name", sort=False)["rowid"]
+            self._seq_to_indices = dict(groupby.apply(list))  # pyre-ignore
+            del self._index["rowid"]
+
     def __len__(self) -> int:
         return len(self._index)
 
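The `precompute_seq_to_idx` branch builds a sequence-name-to-row-indices lookup via a pandas groupby. A runnable toy version of the same pattern, with illustrative index contents:

import numpy as np
import pandas as pd

index = pd.DataFrame({"sequence_name": ["a", "a", "b"]})
index["rowid"] = np.arange(len(index))
seq_to_indices = dict(index.groupby("sequence_name", sort=False)["rowid"].apply(list))
assert seq_to_indices == {"a": [0, 1], "b": [2]}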
@@ -252,9 +281,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         seq_stmt = sa.select(self.sequence_annotations_type).where(
             self.sequence_annotations_type.sequence_name == seq
         )
-        with Session(self._sql_engine) as session:
-            entry = session.scalars(stmt).one()
-            seq_metadata = session.scalars(seq_stmt).one()
+        if self.scoped_session:
+            # pyre-ignore
+            with scoped_session(self._session_factory)() as session:
+                entry = session.scalars(stmt).one()
+                seq_metadata = session.scalars(seq_stmt).one()
+        else:
+            with Session(self._sql_engine) as session:
+                entry = session.scalars(stmt).one()
+                seq_metadata = session.scalars(seq_stmt).one()
 
         assert entry.image.path == self._index.loc[(seq, frame), "_image_path"]
 
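For reference, the SQLAlchemy pattern used here: a `sessionmaker` bound to the engine is wrapped in `scoped_session`, so independent call sites on the same thread share one underlying `Session`. A self-contained sketch, assuming an in-memory engine:

import sqlalchemy as sa
from sqlalchemy.orm import scoped_session, sessionmaker

engine = sa.create_engine("sqlite:///:memory:")
session_registry = scoped_session(sessionmaker(bind=engine))

with session_registry() as session:  # the shared, thread-local session
    session.execute(sa.text("SELECT 1"))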
@@ -363,6 +398,20 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
 
         yield from index_slice.itertuples(index=False)
 
+    # override
+    def sequence_indices_in_order(
+        self, seq_name: str, subset_filter: Optional[Sequence[str]] = None
+    ) -> Iterator[int]:
+        """Same as `sequence_frames_in_order` but returns the iterator over
+        only dataset indices.
+        """
+        if self.precompute_seq_to_idx and subset_filter is None:
+            # pyre-ignore
+            yield from self._seq_to_indices[seq_name]
+        else:
+            for _, _, idx in self.sequence_frames_in_order(seq_name, subset_filter):
+                yield idx
+
     # override
     def get_eval_batches(self) -> Optional[List[Any]]:
         """
@@ -396,11 +445,35 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             or self.limit_sequences_to > 0
             or self.limit_sequences_per_category_to > 0
             or len(self.pick_sequences) > 0
+            or self.pick_sequences_sql_clause is not None
             or len(self.exclude_sequences) > 0
             or len(self.pick_categories) > 0
             or self.n_frames_per_sequence > 0
         )
 
+    def _preload_database(
+        self, source_engine: sa.engine.base.Engine
+    ) -> sa.engine.base.Engine:
+        destination_engine = sa.create_engine("sqlite:///:memory:")
+        metadata = sa.MetaData()
+        metadata.reflect(bind=source_engine)
+        metadata.create_all(bind=destination_engine)
+
+        with source_engine.connect() as source_conn:
+            with destination_engine.connect() as destination_conn:
+                for table_obj in metadata.tables.values():
+                    # Select all rows from the source table
+                    source_rows = source_conn.execute(table_obj.select())
+
+                    # Insert rows into the destination table
+                    for row in source_rows:
+                        destination_conn.execute(table_obj.insert().values(row))
+
+                    # Commit the changes for each table
+                    destination_conn.commit()
+
+        return destination_engine
+
     def _get_filtered_sequences_if_any(self) -> Optional[pd.Series]:
         # maximum possible filter (if limit_sequences_per_category_to == 0):
         # WHERE category IN 'self.pick_categories'
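`_preload_database` clones the on-disk SQLite metadata into an in-memory engine by reflecting the schema and copying rows table by table. A standalone sketch of the same reflect/create_all/copy pattern on a toy table (names are illustrative):

import sqlalchemy as sa

src = sa.create_engine("sqlite:///:memory:")
with src.begin() as conn:
    conn.execute(sa.text("CREATE TABLE t (x INTEGER)"))
    conn.execute(sa.text("INSERT INTO t VALUES (1), (2)"))

dst = sa.create_engine("sqlite:///:memory:")
metadata = sa.MetaData()
metadata.reflect(bind=src)     # discover the source schema
metadata.create_all(bind=dst)  # recreate it on the destination engine

with src.connect() as s, dst.connect() as d:
    for table in metadata.tables.values():
        rows = [dict(r._mapping) for r in s.execute(table.select())]
        if rows:
            d.execute(table.insert(), rows)
    d.commit()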
@@ -413,6 +486,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             *self._get_pick_filters(),
             *self._get_exclude_filters(),
         ]
+        if self.pick_sequences_sql_clause:
+            print("Applying the custom SQL clause.")
+            where_conditions.append(sa.text(self.pick_sequences_sql_clause))
 
         def add_where(stmt):
             return stmt.where(*where_conditions) if where_conditions else stmt
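An illustrative value for the new option, assuming the sequence annotation table exposes a `category` column; the fragment below is hypothetical and elides the other required dataset fields:

dataset_args = {
    "sqlite_metadata_file": "metadata.sqlite",  # illustrative path
    "pick_sequences_sql_clause": "category LIKE 'chair%'",
}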
@@ -749,9 +825,15 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             self.frame_annotations_type.sequence_name == seq_name,
             self.frame_annotations_type.frame_number.in_(frames),
         )
+        frame_no_ts = None
 
-        with self._sql_engine.connect() as connection:
-            frame_no_ts = pd.read_sql_query(stmt, connection)
+        if self.scoped_session:
+            stmt_text = str(stmt.compile(compile_kwargs={"literal_binds": True}))
+            with scoped_session(self._session_factory)() as session:  # pyre-ignore
+                frame_no_ts = pd.read_sql_query(stmt_text, session.connection())
+        else:
+            with self._sql_engine.connect() as connection:
+                frame_no_ts = pd.read_sql_query(stmt, connection)
 
         if len(frame_no_ts) != len(index_slice):
             raise ValueError(
pytorch3d/implicitron/dataset/sql_dataset_provider.py

@@ -284,8 +284,14 @@ class SqlIndexDatasetMapProvider(DatasetMapProviderBase):
         logger.info(f"Val dataset: {str(val_dataset)}")
 
         logger.debug("Extracting test dataset.")
-        eval_batches_file = self._get_lists_file("eval_batches")
-        del common_dataset_kwargs["eval_batches_file"]
+        if self.eval_batches_path is None:
+            eval_batches_file = None
+        else:
+            eval_batches_file = self._get_lists_file("eval_batches")
+
+        if "eval_batches_file" in common_dataset_kwargs:
+            common_dataset_kwargs.pop("eval_batches_file", None)
+
         test_dataset = dataset_type(
             **common_dataset_kwargs,
             subsets=self._get_subsets(self.test_subsets, True),
pytorch3d/implicitron/dataset/utils.py

@@ -211,14 +211,21 @@ def resize_image(
     if isinstance(image, np.ndarray):
         image = torch.from_numpy(image)
 
-    if image_height is None or image_width is None:
+    if (
+        image_height is None
+        or image_width is None
+        or image.shape[-2] == 0
+        or image.shape[-1] == 0
+    ):
         # skip the resizing
         return image, 1.0, torch.ones_like(image[:1])
 
     # takes numpy array or tensor, returns pytorch tensor
     minscale = min(
         image_height / image.shape[-2],
         image_width / image.shape[-1],
     )
 
     imre = torch.nn.functional.interpolate(
         image[None],
         scale_factor=minscale,
@@ -226,6 +233,7 @@ def resize_image(
         align_corners=False if mode == "bilinear" else None,
         recompute_scale_factor=True,
     )[0]
+
     imre_ = torch.zeros(image.shape[0], image_height, image_width)
     imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
     mask = torch.zeros(1, image_height, image_width)
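The widened guard lets degenerate zero-sized images pass through unchanged instead of reaching `interpolate`. A hedged check, assuming `resize_image` is imported from this module:

import torch
from pytorch3d.implicitron.dataset.utils import resize_image

image = torch.zeros(3, 0, 10)  # zero-height image
out, scale, mask = resize_image(image, image_height=128, image_width=128)
assert out.shape == image.shape and scale == 1.0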
@@ -238,20 +246,21 @@ def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
     return im.astype(np.float32) / 255.0
 
 
-def load_image(path: str, try_read_alpha: bool = False) -> np.ndarray:
+def load_image(
+    path: str, try_read_alpha: bool = False, pil_format: str = "RGB"
+) -> np.ndarray:
     """
     Load an image from a path and return it as a numpy array.
     If try_read_alpha is True, the image is read as RGBA and the alpha channel is
     returned as the fourth channel.
     Otherwise, the image is read as RGB and a three-channel image is returned.
     """
 
     with Image.open(path) as pil_im:
         # Check if the image has an alpha channel
         if try_read_alpha and pil_im.mode == "RGBA":
             im = np.array(pil_im)
         else:
-            im = np.array(pil_im.convert("RGB"))
+            im = np.array(pil_im.convert(pil_format))
 
     return transpose_normalize_image(im)
 
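With the new `pil_format` parameter, any PIL conversion mode can be requested; for example, a single-channel grayscale read (path and mode are illustrative):

from pytorch3d.implicitron.dataset.utils import load_image

gray = load_image("mask.png", pil_format="L")  # float32 array scaled to [0, 1]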
@@ -389,7 +398,7 @@ def adjust_camera_to_image_scale_(
     )
     camera.focal_length = focal_length_scaled[None]
     # pyre-fixme[16]: `PerspectiveCameras` has no attribute `principal_point`.
-    camera.principal_point = principal_point_scaled[None]
+    camera.principal_point = principal_point_scaled[None]  # pyre-ignore[16]
 
 
 # NOTE this cache is per-worker; they are implemented as processes.
pytorch3d/implicitron/models/metrics.py

@@ -6,7 +6,6 @@
 
 # pyre-unsafe
 
-
 import warnings
 from typing import Any, Dict, Optional
 
@@ -298,9 +297,8 @@ class ViewMetrics(ViewMetricsBase):
             _rgb_metrics(
                 image_rgb,
                 image_rgb_pred,
-                fg_probability,
-                fg_probability_pred,
-                mask_crop,
+                masks=fg_probability,
+                masks_crop=mask_crop,
             )
         )
 
@@ -310,9 +308,21 @@ class ViewMetrics(ViewMetricsBase):
             metrics["mask_neg_iou"] = utils.neg_iou_loss(
                 fg_probability_pred, fg_probability, mask=mask_crop
             )
-            metrics["mask_bce"] = utils.calc_bce(
-                fg_probability_pred, fg_probability, mask=mask_crop
-            )
+            if torch.is_autocast_enabled():
+                # To avoid issues with mixed precision
+                metrics["mask_bce"] = utils.calc_bce(
+                    fg_probability_pred.logit(),
+                    fg_probability,
+                    mask=mask_crop,
+                    pred_logits=True,
+                )
+            else:
+                metrics["mask_bce"] = utils.calc_bce(
+                    fg_probability_pred,
+                    fg_probability,
+                    mask=mask_crop,
+                    pred_logits=False,
+                )
 
         if depth_map is not None and depth_map_pred is not None:
             assert mask_crop is not None
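Context for the branch: PyTorch rejects `F.binary_cross_entropy` inside an autocast region (notably CUDA autocast) as numerically unsafe, while the fused `binary_cross_entropy_with_logits` is allowed. A hedged sketch of the idea, with illustrative shapes:

import torch
import torch.nn.functional as F

probs = torch.rand(2, 1, 8, 8).clamp(0.01, 0.99)
target = (torch.rand(2, 1, 8, 8) > 0.5).float()

with torch.autocast("cpu"):
    # probs.logit() maps probabilities back to logits so the
    # autocast-safe *_with_logits variant can be used.
    loss = F.binary_cross_entropy_with_logits(probs.logit(), target)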
@@ -324,7 +334,11 @@ class ViewMetrics(ViewMetricsBase):
             if fg_probability is not None:
                 mask = fg_probability * mask_crop
                 _, abs_ = utils.eval_depth(
-                    depth_map_pred, depth_map, get_best_scale=True, mask=mask, crop=0
+                    depth_map_pred,
+                    depth_map,
+                    get_best_scale=True,
+                    mask=mask,
+                    crop=0,
                 )
                 metrics["depth_abs_fg"] = abs_.mean()
 
@@ -346,18 +360,26 @@ class ViewMetrics(ViewMetricsBase):
         return metrics
 
 
-def _rgb_metrics(images, images_pred, masks, masks_pred, masks_crop):
+def _rgb_metrics(
+    images,
+    images_pred,
+    masks=None,
+    masks_crop=None,
+    huber_scaling: float = 0.03,
+):
     assert masks_crop is not None
     if images.shape[1] != images_pred.shape[1]:
         raise ValueError(
             f"Network output's RGB images had {images_pred.shape[1]} "
             f"channels. {images.shape[1]} expected."
         )
+    rgb_abs = ((images_pred - images).abs()).mean(dim=1, keepdim=True)
     rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
-    rgb_loss = utils.huber(rgb_squared, scaling=0.03)
+    rgb_loss = utils.huber(rgb_squared, scaling=huber_scaling)
     crop_mass = masks_crop.sum().clamp(1.0)
     results = {
         "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
+        "rgb_l1": (rgb_abs * masks_crop).sum() / crop_mass,
         "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
         "rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
     }
pytorch3d/implicitron/tools/metric_utils.py

@@ -6,12 +6,15 @@
 
 # pyre-unsafe
 
+import logging
 import math
 from typing import Optional, Tuple
 
 import torch
 from torch.nn import functional as F
 
+logger = logging.getLogger(__name__)
+
 
 def eval_depth(
     pred: torch.Tensor,
@@ -21,6 +24,8 @@ def eval_depth(
     get_best_scale: bool = True,
     mask_thr: float = 0.5,
     best_scale_clamp_thr: float = 1e-4,
+    use_disparity: bool = False,
+    disparity_eps: float = 1e-4,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Evaluate the depth error between the prediction `pred` and the ground
|
|||||||
# s.t. we get best possible mse error
|
# s.t. we get best possible mse error
|
||||||
scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr)
|
scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr)
|
||||||
pred = pred * scale_best[:, None, None, None]
|
pred = pred * scale_best[:, None, None, None]
|
||||||
|
if use_disparity:
|
||||||
|
gt = torch.div(1.0, (gt + disparity_eps))
|
||||||
|
pred = torch.div(1.0, (pred + disparity_eps))
|
||||||
|
scale_best = estimate_depth_scale_factor(
|
||||||
|
pred, gt, dmask, best_scale_clamp_thr
|
||||||
|
).detach()
|
||||||
|
pred = pred * scale_best[:, None, None, None]
|
||||||
|
|
||||||
df = gt - pred
|
df = gt - pred
|
||||||
|
|
||||||
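Evaluating in disparity space inverts depth before differencing, which emphasizes near-field error over the far field. A tiny numeric illustration of the transform:

import torch

depth = torch.tensor([1.0, 2.0, 10.0])
eps = 1e-4
disparity = 1.0 / (depth + eps)  # roughly [1.0, 0.5, 0.1]: far points shrink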
@@ -117,6 +129,7 @@ def calc_bce(
     pred_eps: float = 0.01,
     mask: Optional[torch.Tensor] = None,
     lerp_bound: Optional[float] = None,
+    pred_logits: bool = False,
 ) -> torch.Tensor:
     """
     Calculates the binary cross entropy.
@@ -139,9 +152,23 @@ def calc_bce(
         weight = torch.ones_like(gt) * mask
 
     if lerp_bound is not None:
+        # binary_cross_entropy_lerp requires pred to be in [0, 1]
+        if pred_logits:
+            pred = F.sigmoid(pred)
+
         return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound)
     else:
-        return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight)
+        if pred_logits:
+            loss = F.binary_cross_entropy_with_logits(
+                pred,
+                gt,
+                reduction="none",
+                weight=weight,
+            )
+        else:
+            loss = F.binary_cross_entropy(pred, gt, reduction="none", weight=weight)
+
+        return loss.mean()
 
 
 def binary_cross_entropy_lerp(
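A quick equivalence check for the refactor above: weighted BCE with reduction="mean" matches the elementwise loss followed by .mean(), so the non-logits path keeps its previous behavior:

import torch
import torch.nn.functional as F

pred = torch.rand(16).clamp(0.01, 0.99)
gt = (torch.rand(16) > 0.5).float()
weight = torch.rand(16)

a = F.binary_cross_entropy(pred, gt, weight=weight, reduction="mean")
b = F.binary_cross_entropy(pred, gt, weight=weight, reduction="none").mean()
assert torch.allclose(a, b)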
pytorch3d/implicitron/tools/video_writer.py

@@ -16,8 +16,17 @@ from typing import Optional, Tuple, Union
 import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
+import torch
+
 from PIL import Image
 
+_NO_TORCHVISION = False
+try:
+    import torchvision
+except ImportError:
+    _NO_TORCHVISION = True
+
+
 _DEFAULT_FFMPEG = os.environ.get("FFMPEG", "ffmpeg")
 
 matplotlib.use("Agg")
@@ -36,6 +45,7 @@ class VideoWriter:
         fps: int = 20,
         output_format: str = "visdom",
         rmdir_allowed: bool = False,
+        use_torchvision_video_writer: bool = False,
         **kwargs,
     ) -> None:
         """
@@ -49,6 +59,8 @@ class VideoWriter:
             is supported.
             rmdir_allowed: If `True` delete and create `cache_dir` in case
                 it is not empty.
+            use_torchvision_video_writer: If `True` use `torchvision.io.write_video`
+                to write the video
         """
         self.rmdir_allowed = rmdir_allowed
         self.output_format = output_format
@@ -56,10 +68,14 @@ class VideoWriter:
         self.out_path = out_path
         self.cache_dir = cache_dir
         self.ffmpeg_bin = ffmpeg_bin
+        self.use_torchvision_video_writer = use_torchvision_video_writer
         self.frames = []
         self.regexp = "frame_%08d.png"
         self.frame_num = 0
 
+        if self.use_torchvision_video_writer:
+            assert not _NO_TORCHVISION, "torchvision not available"
+
         if self.cache_dir is not None:
             self.tmp_dir = None
             if os.path.isdir(self.cache_dir):
@@ -114,7 +130,7 @@ class VideoWriter:
             resize = im.size
             # make sure size is divisible by 2
             resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
-            # pyre-fixme[16]: Module `Image` has no attribute `ANTIALIAS`.
             im = im.resize(resize, Image.ANTIALIAS)
             im.save(outfile)
 
@@ -139,38 +155,56 @@ class VideoWriter:
         # got `Optional[str]`.
         regexp = os.path.join(self.cache_dir, self.regexp)
 
-        if shutil.which(self.ffmpeg_bin) is None:
-            raise ValueError(
-                f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
-                + "Please set FFMPEG in the environment or ffmpeg_bin on this class."
-            )
-
         if self.output_format == "visdom":  # works for ppt too
-            args = [
-                self.ffmpeg_bin,
-                "-r",
-                str(self.fps),
-                "-i",
-                regexp,
-                "-vcodec",
-                "h264",
-                "-f",
-                "mp4",
-                "-y",
-                "-crf",
-                "18",
-                "-b",
-                "2000k",
-                "-pix_fmt",
-                "yuv420p",
-                self.out_path,
-            ]
-            if quiet:
-                subprocess.check_call(
-                    args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
-                )
+            # Video codec parameters
+            video_codec = "h264"
+            crf = "18"
+            b = "2000k"
+            pix_fmt = "yuv420p"
+
+            if self.use_torchvision_video_writer:
+                torchvision.io.write_video(
+                    self.out_path,
+                    torch.stack(
+                        [torch.from_numpy(np.array(Image.open(f))) for f in self.frames]
+                    ),
+                    fps=self.fps,
+                    video_codec=video_codec,
+                    options={"crf": crf, "b": b, "pix_fmt": pix_fmt},
+                )
             else:
-                subprocess.check_call(args)
+                if shutil.which(self.ffmpeg_bin) is None:
+                    raise ValueError(
+                        f"Cannot find ffmpeg as `{self.ffmpeg_bin}`. "
+                        + "Please set FFMPEG in the environment or ffmpeg_bin on this class."
+                    )
+
+                args = [
+                    self.ffmpeg_bin,
+                    "-r",
+                    str(self.fps),
+                    "-i",
+                    regexp,
+                    "-vcodec",
+                    video_codec,
+                    "-f",
+                    "mp4",
+                    "-y",
+                    "-crf",
+                    crf,
+                    "-b",
+                    b,
+                    "-pix_fmt",
+                    pix_fmt,
+                    self.out_path,
+                ]
+                if quiet:
+                    subprocess.check_call(
+                        args, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
+                    )
+                else:
+                    subprocess.check_call(args)
         else:
             raise ValueError("no such output type %s" % str(self.output_format))
 
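For reference, `torchvision.io.write_video` takes a uint8 tensor shaped (T, H, W, C); a standalone sketch with synthetic frames (the output path is illustrative, and the codec settings mirror the values above):

import torch
import torchvision

frames = torch.randint(0, 256, (10, 64, 64, 3), dtype=torch.uint8)
torchvision.io.write_video("out.mp4", frames, fps=20, video_codec="h264")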