diff --git a/pytorch3d/implicitron/dataset/json_index_dataset.py b/pytorch3d/implicitron/dataset/json_index_dataset.py
index 6b3d7102..54a5cd3c 100644
--- a/pytorch3d/implicitron/dataset/json_index_dataset.py
+++ b/pytorch3d/implicitron/dataset/json_index_dataset.py
@@ -422,17 +422,15 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
             bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
 
             if self.box_crop:
-                clamp_bbox_xyxy = _get_clamp_bbox(
-                    bbox_xywh,
-                    image_path=entry.image.path,
-                    box_crop_context=self.box_crop_context,
-                )
-
-                crop_box_xyxy = _clamp_box_to_image_bounds_and_round(
-                    clamp_bbox_xyxy,
+                clamp_bbox_xyxy = _clamp_box_to_image_bounds_and_round(
+                    _get_clamp_bbox(
+                        bbox_xywh,
+                        image_path=entry.image.path,
+                        box_crop_context=self.box_crop_context,
+                    ),
                     image_size_hw=tuple(mask.shape[-2:]),
                 )
-                crop_box_xywh = _bbox_xyxy_to_xywh(crop_box_xyxy)
+                crop_box_xywh = _bbox_xyxy_to_xywh(clamp_bbox_xyxy)
 
                 mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
 
@@ -926,9 +924,10 @@ def _clamp_box_to_image_bounds_and_round(
     image_size_hw: Tuple[int, int],
 ) -> torch.LongTensor:
     bbox_xyxy = bbox_xyxy.clone()
-    bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0.0, image_size_hw[-1])
-    bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0.0, image_size_hw[-2])
-    bbox_xyxy = bbox_xyxy.round().long()
+    bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
+    bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
+    if not isinstance(bbox_xyxy, torch.LongTensor):
+        bbox_xyxy = bbox_xyxy.round().long()
     return bbox_xyxy  # pyre-ignore [7]
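
For reference, a minimal standalone sketch of the clamping behavior introduced in the second hunk. The helper name, the (H, W) image size, and the sample box below are illustrative assumptions, not part of the patch or the library API; the point is that x coordinates are clamped to the width, y coordinates to the height, and an already-integer (LongTensor) box now passes through without a redundant round. After the first hunk, this clamped, rounded box is the one used both for `_crop_around_box` and for deriving `crop_box_xywh`.

from typing import Tuple

import torch


def clamp_box_to_image_bounds_and_round(
    bbox_xyxy: torch.Tensor,
    image_size_hw: Tuple[int, int],
) -> torch.Tensor:
    # Mirrors the patched helper: clamp x coords to the image width,
    # y coords to the image height, then round to integer pixel indices.
    bbox_xyxy = bbox_xyxy.clone()
    bbox_xyxy[[0, 2]] = torch.clamp(bbox_xyxy[[0, 2]], 0, image_size_hw[-1])
    bbox_xyxy[[1, 3]] = torch.clamp(bbox_xyxy[[1, 3]], 0, image_size_hw[-2])
    # A box that is already a LongTensor passes through unchanged;
    # a float box is rounded and cast, as in the diff.
    if not isinstance(bbox_xyxy, torch.LongTensor):
        bbox_xyxy = bbox_xyxy.round().long()
    return bbox_xyxy


if __name__ == "__main__":
    # A float xyxy box (e.g. after adding crop context) that overhangs
    # a hypothetical 300 x 400 (H x W) image.
    box = torch.tensor([-12.3, 20.7, 415.0, 310.2])
    clamped = clamp_box_to_image_bounds_and_round(box, (300, 400))
    print(clamped)        # tensor([  0,  21, 400, 300])
    print(clamped.dtype)  # torch.int64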