Fix for box_crop=True

Summary: one more bugfix in JsonIndexDataset

Reviewed By: bottler

Differential Revision: D37789138

fbshipit-source-id: 2fb2bda7448674091ff6b279175f0bbd16ff7a62
This commit is contained in:
David Novotny 2022-07-12 10:03:58 -07:00 committed by Facebook GitHub Bot
parent d3b7f5f421
commit af55ba01f8

View File

@ -422,17 +422,15 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
if self.box_crop:
clamp_bbox_xyxy = _get_clamp_bbox(
bbox_xywh,
image_path=entry.image.path,
box_crop_context=self.box_crop_context,
)
crop_box_xyxy = _clamp_box_to_image_bounds_and_round(
clamp_bbox_xyxy,
clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
_get_clamp_bbox(
bbox_xywh,
image_path=entry.image.path,
box_crop_context=self.box_crop_context,
),
image_size_hw=tuple(mask.shape[-2:]),
)
crop_box_xywh = _bbox_xyxy_to_xywh(crop_box_xyxy)
crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
@ -926,9 +924,10 @@ def _clamp_box_to_image_bounds_and_round(
image_size_hw: Tuple[int, int],
) -> torch.LongTensor:
bbox_xyxy = bbox_xyxy.clone()
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0.0, image_size_hw[-1])
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0.0, image_size_hw[-2])
bbox_xyxy = bbox_xyxy.round().long()
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
if not isinstance(bbox_xyxy, torch.LongTensor):
bbox_xyxy = bbox_xyxy.round().long()
return bbox_xyxy # pyre-ignore [7]