mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-23 06:12:48 +08:00
Fix for box_crop=True
Summary: one more bugfix in JsonIndexDataset Reviewed By: bottler Differential Revision: D37789138 fbshipit-source-id: 2fb2bda7448674091ff6b279175f0bbd16ff7a62
This commit is contained in:
parent
d3b7f5f421
commit
af55ba01f8
@ -422,17 +422,15 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
|
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
|
||||||
|
|
||||||
if self.box_crop:
|
if self.box_crop:
|
||||||
clamp_bbox_xyxy = _get_clamp_bbox(
|
clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
|
||||||
|
_get_clamp_bbox(
|
||||||
bbox_xywh,
|
bbox_xywh,
|
||||||
image_path=entry.image.path,
|
image_path=entry.image.path,
|
||||||
box_crop_context=self.box_crop_context,
|
box_crop_context=self.box_crop_context,
|
||||||
)
|
),
|
||||||
|
|
||||||
crop_box_xyxy = _clamp_box_to_image_bounds_and_round(
|
|
||||||
clamp_bbox_xyxy,
|
|
||||||
image_size_hw=tuple(mask.shape[-2:]),
|
image_size_hw=tuple(mask.shape[-2:]),
|
||||||
)
|
)
|
||||||
crop_box_xywh = _bbox_xyxy_to_xywh(crop_box_xyxy)
|
crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
|
||||||
|
|
||||||
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
|
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
|
||||||
|
|
||||||
@ -926,8 +924,9 @@ def _clamp_box_to_image_bounds_and_round(
|
|||||||
image_size_hw: Tuple[int, int],
|
image_size_hw: Tuple[int, int],
|
||||||
) -> torch.LongTensor:
|
) -> torch.LongTensor:
|
||||||
bbox_xyxy = bbox_xyxy.clone()
|
bbox_xyxy = bbox_xyxy.clone()
|
||||||
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0.0, image_size_hw[-1])
|
bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
|
||||||
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0.0, image_size_hw[-2])
|
bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
|
||||||
|
if not isinstance(bbox_xyxy, torch.LongTensor):
|
||||||
bbox_xyxy = bbox_xyxy.round().long()
|
bbox_xyxy = bbox_xyxy.round().long()
|
||||||
return bbox_xyxy # pyre-ignore [7]
|
return bbox_xyxy # pyre-ignore [7]
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user