mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Add full-image PSNR metric
Summary: Reports also the PSNR between the unmasked G.T. image and the render. Reviewed By: bottler Differential Revision: D38655943 fbshipit-source-id: 1603a2d02116ea1ce037e5530abe1afc65a2ba93
This commit is contained in:
parent
a91f15f24e
commit
7b985702bb
@ -262,6 +262,10 @@ def eval_batch(
|
|||||||
else torch.ones_like(mask_fg)
|
else torch.ones_like(mask_fg)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# unmasked g.t. image
|
||||||
|
image_rgb = frame_data.image_rgb
|
||||||
|
|
||||||
|
# fg-masked g.t. image
|
||||||
image_rgb_masked = mask_background(
|
image_rgb_masked = mask_background(
|
||||||
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
|
||||||
# `Optional[torch.Tensor]`.
|
# `Optional[torch.Tensor]`.
|
||||||
@ -330,12 +334,26 @@ def eval_batch(
|
|||||||
if break_after_visualising:
|
if break_after_visualising:
|
||||||
breakpoint() # noqa: B601
|
breakpoint() # noqa: B601
|
||||||
|
|
||||||
|
# add the rgb metrics between the render and the unmasked image
|
||||||
|
for rgb_metric_name, rgb_metric_fun in zip(
|
||||||
|
("psnr_full_image", "rgb_l1_full_image"), (calc_psnr, rgb_l1)
|
||||||
|
):
|
||||||
|
results[rgb_metric_name] = rgb_metric_fun(
|
||||||
|
image_render,
|
||||||
|
image_rgb,
|
||||||
|
mask=mask_crop,
|
||||||
|
)
|
||||||
|
|
||||||
if lpips_model is not None:
|
if lpips_model is not None:
|
||||||
im1, im2 = [
|
for gt_image_type in ("_full_image", ""):
|
||||||
2.0 * im.clamp(0.0, 1.0) - 1.0
|
im1, im2 = [
|
||||||
for im in (image_rgb_masked, cloned_render["image_render"])
|
2.0 * im.clamp(0.0, 1.0) - 1.0 # pyre-ignore[16]
|
||||||
]
|
for im in (
|
||||||
results["lpips"] = lpips_model.forward(im1, im2).item()
|
image_rgb_masked if gt_image_type == "" else image_rgb,
|
||||||
|
cloned_render["image_render"],
|
||||||
|
)
|
||||||
|
]
|
||||||
|
results["lpips" + gt_image_type] = lpips_model.forward(im1, im2).item()
|
||||||
|
|
||||||
# convert all metrics to floats
|
# convert all metrics to floats
|
||||||
results = {k: float(v) for k, v in results.items()}
|
results = {k: float(v) for k, v in results.items()}
|
||||||
|
@ -256,10 +256,13 @@ class TestEvaluation(unittest.TestCase):
|
|||||||
lower_better = {
|
lower_better = {
|
||||||
"psnr": False,
|
"psnr": False,
|
||||||
"psnr_fg": False,
|
"psnr_fg": False,
|
||||||
|
"psnr_full_image": False,
|
||||||
"depth_abs_fg": True,
|
"depth_abs_fg": True,
|
||||||
"iou": False,
|
"iou": False,
|
||||||
"rgb_l1": True,
|
"rgb_l1": True,
|
||||||
"rgb_l1_fg": True,
|
"rgb_l1_fg": True,
|
||||||
|
"lpips": True,
|
||||||
|
"lpips_full_image": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
for metric in lower_better:
|
for metric in lower_better:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user