mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	small fix for iou3d
Summary: A small numerical fix for IoU for 3D boxes, fixes GH #992 * Adds a check for boxes with zero side areas (invalid boxes) * Fixes numerical issue when two boxes have coplanar sides Reviewed By: nikhilaravi Differential Revision: D33195691 fbshipit-source-id: 8a34b4d1f1e5ec2edb6d54143930da44bdde0906
This commit is contained in:
		
							parent
							
								
									069c9fd759
								
							
						
					
					
						commit
						ccfb72cc50
					
				@ -90,7 +90,8 @@ __global__ void IoUBox3DKernel(
 | 
			
		||||
        for (int b2 = 0; b2 < box2_count; ++b2) {
 | 
			
		||||
          const bool is_coplanar =
 | 
			
		||||
              IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
 | 
			
		||||
          if (is_coplanar) {
 | 
			
		||||
          const float area = FaceArea(box1_intersect[b1]);
 | 
			
		||||
          if ((is_coplanar) && (area > kEpsilon)) {
 | 
			
		||||
            tri2_keep[b2].keep = false;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@ -81,7 +81,8 @@ std::tuple<at::Tensor, at::Tensor> IoUBox3DCpu(
 | 
			
		||||
          for (int b2 = 0; b2 < box2_intersect.size(); ++b2) {
 | 
			
		||||
            const bool is_coplanar =
 | 
			
		||||
                IsCoplanarFace(box1_intersect[b1], box2_intersect[b2]);
 | 
			
		||||
            if (is_coplanar) {
 | 
			
		||||
            const float area = FaceArea(box1_intersect[b1]);
 | 
			
		||||
            if ((is_coplanar) && (area > kEpsilon)) {
 | 
			
		||||
              tri2_keep[b2] = 0;
 | 
			
		||||
            }
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
@ -138,6 +138,26 @@ FaceNormal(const float3 v0, const float3 v1, const float3 v2) {
 | 
			
		||||
  return n;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// The area of the face defined by vertices (v0, v1, v2)
 | 
			
		||||
// Define e0 to be the edge connecting (v1, v0)
 | 
			
		||||
// Define e1 to be the edge connecting (v2, v0)
 | 
			
		||||
// Area is the norm of the cross product of e0, e1 divided by 2.0
 | 
			
		||||
//
 | 
			
		||||
// Args
 | 
			
		||||
//    tri: FaceVerts of float3 coordinates of the vertices of the face
 | 
			
		||||
//
 | 
			
		||||
// Returns
 | 
			
		||||
//    float: area for the face
 | 
			
		||||
//
 | 
			
		||||
__device__ inline float FaceArea(const FaceVerts& tri) {
 | 
			
		||||
  // Get verts for face 1
 | 
			
		||||
  const float3 v0 = tri.v0;
 | 
			
		||||
  const float3 v1 = tri.v1;
 | 
			
		||||
  const float3 v2 = tri.v2;
 | 
			
		||||
  const float3 n = cross(v1 - v0, v2 - v0);
 | 
			
		||||
  return norm(n) / 2.0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// The normal of a box plane defined by the verts in `plane` with
 | 
			
		||||
// the centroid of the box given by `center`.
 | 
			
		||||
// Args
 | 
			
		||||
 | 
			
		||||
@ -145,6 +145,26 @@ inline vec3<float> FaceNormal(vec3<float> v0, vec3<float> v1, vec3<float> v2) {
 | 
			
		||||
  return n;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// The area of the face defined by vertices (v0, v1, v2)
 | 
			
		||||
// Define e0 to be the edge connecting (v1, v0)
 | 
			
		||||
// Define e1 to be the edge connecting (v2, v0)
 | 
			
		||||
// Area is the norm of the cross product of e0, e1 divided by 2.0
 | 
			
		||||
//
 | 
			
		||||
// Args
 | 
			
		||||
//    tri: vec3 coordinates of the vertices of the face
 | 
			
		||||
//
 | 
			
		||||
// Returns
 | 
			
		||||
//    float: area for the face
 | 
			
		||||
//
 | 
			
		||||
inline float FaceArea(const std::vector<vec3<float>>& tri) {
 | 
			
		||||
  // Get verts for face
 | 
			
		||||
  const vec3<float> v0 = tri[0];
 | 
			
		||||
  const vec3<float> v1 = tri[1];
 | 
			
		||||
  const vec3<float> v2 = tri[2];
 | 
			
		||||
  const vec3<float> n = cross(v1 - v0, v2 - v0);
 | 
			
		||||
  return norm(n) / 2.0;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// The normal of a box plane defined by the verts in `plane` with
 | 
			
		||||
// the centroid of the box given by `center`.
 | 
			
		||||
// Args
 | 
			
		||||
 | 
			
		||||
@ -69,6 +69,28 @@ def _check_coplanar(boxes: torch.Tensor, eps: float = 1e-4) -> None:
 | 
			
		||||
    return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _check_nonzero(boxes: torch.Tensor, eps: float = 1e-4) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    Checks that the sides of the box have a non zero area
 | 
			
		||||
    """
 | 
			
		||||
    faces = torch.tensor(_box_triangles, dtype=torch.int64, device=boxes.device)
 | 
			
		||||
    # pyre-fixme[16]: `boxes` has no attribute `index_select`.
 | 
			
		||||
    verts = boxes.index_select(index=faces.view(-1), dim=1)
 | 
			
		||||
    B = boxes.shape[0]
 | 
			
		||||
    T, V = faces.shape
 | 
			
		||||
    # (B, T, 3, 3) -> (B, T, 3)
 | 
			
		||||
    v0, v1, v2 = verts.reshape(B, T, V, 3).unbind(2)
 | 
			
		||||
 | 
			
		||||
    normals = torch.cross(v1 - v0, v2 - v0, dim=-1)  # (B, T, 3)
 | 
			
		||||
    face_areas = normals.norm(dim=-1) / 2
 | 
			
		||||
 | 
			
		||||
    if (face_areas < eps).any().item():
 | 
			
		||||
        msg = "Planes have zero areas"
 | 
			
		||||
        raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
    return
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class _box3d_overlap(Function):
 | 
			
		||||
    """
 | 
			
		||||
    Torch autograd Function wrapper for box3d_overlap C++/CUDA implementations.
 | 
			
		||||
@ -138,6 +160,8 @@ def box3d_overlap(
 | 
			
		||||
 | 
			
		||||
    _check_coplanar(boxes1, eps)
 | 
			
		||||
    _check_coplanar(boxes2, eps)
 | 
			
		||||
    _check_nonzero(boxes1, eps)
 | 
			
		||||
    _check_nonzero(boxes2, eps)
 | 
			
		||||
 | 
			
		||||
    # pyre-fixme[16]: `_box3d_overlap` has no attribute `apply`.
 | 
			
		||||
    vol, iou = _box3d_overlap.apply(boxes1, boxes2)
 | 
			
		||||
 | 
			
		||||
@ -111,6 +111,11 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
 | 
			
		||||
        )
 | 
			
		||||
        # symmetry
 | 
			
		||||
        vol, iou = overlap_fn(box2[None], box1[None])
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # 3rd test
 | 
			
		||||
        dd = random.random()
 | 
			
		||||
@ -119,6 +124,11 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
 | 
			
		||||
        )
 | 
			
		||||
        # symmetry
 | 
			
		||||
        vol, _ = overlap_fn(box2[None], box1[None])
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            vol, torch.tensor([[1 - dd]], device=vol.device, dtype=vol.dtype)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # 4th test
 | 
			
		||||
        ddx, ddy, ddz = random.random(), random.random(), random.random()
 | 
			
		||||
@ -132,6 +142,16 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                dtype=vol.dtype,
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        # symmetry
 | 
			
		||||
        vol, _ = overlap_fn(box2[None], box1[None])
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            vol,
 | 
			
		||||
            torch.tensor(
 | 
			
		||||
                [[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
 | 
			
		||||
                device=vol.device,
 | 
			
		||||
                dtype=vol.dtype,
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Also check IoU is 1 when computing overlap with the same shifted box
 | 
			
		||||
        vol, iou = overlap_fn(box2[None], box2[None])
 | 
			
		||||
@ -152,6 +172,16 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
                dtype=vol.dtype,
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
        # symmetry
 | 
			
		||||
        vol, _ = overlap_fn(box2r[None], box1r[None])
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            vol,
 | 
			
		||||
            torch.tensor(
 | 
			
		||||
                [[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
 | 
			
		||||
                device=vol.device,
 | 
			
		||||
                dtype=vol.dtype,
 | 
			
		||||
            ),
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # 6th test
 | 
			
		||||
        ddx, ddy, ddz = random.random(), random.random(), random.random()
 | 
			
		||||
@ -170,6 +200,17 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            ),
 | 
			
		||||
            atol=1e-7,
 | 
			
		||||
        )
 | 
			
		||||
        # symmetry
 | 
			
		||||
        vol, _ = overlap_fn(box2r[None], box1r[None])
 | 
			
		||||
        self.assertClose(
 | 
			
		||||
            vol,
 | 
			
		||||
            torch.tensor(
 | 
			
		||||
                [[(1 - ddx) * (1 - ddy) * (1 - ddz)]],
 | 
			
		||||
                device=vol.device,
 | 
			
		||||
                dtype=vol.dtype,
 | 
			
		||||
            ),
 | 
			
		||||
            atol=1e-7,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # 7th test: hand coded example and test with meshlab output
 | 
			
		||||
 | 
			
		||||
@ -214,6 +255,10 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        vol, iou = overlap_fn(box1r[None], box2r[None])
 | 
			
		||||
        self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1)
 | 
			
		||||
        self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1)
 | 
			
		||||
        # symmetry
 | 
			
		||||
        vol, iou = overlap_fn(box2r[None], box1r[None])
 | 
			
		||||
        self.assertClose(vol, torch.tensor([[vol_inters]], device=device), atol=1e-1)
 | 
			
		||||
        self.assertClose(iou, torch.tensor([[iou_mesh]], device=device), atol=1e-1)
 | 
			
		||||
 | 
			
		||||
        # 8th test: compare with sampling
 | 
			
		||||
        # create box1
 | 
			
		||||
@ -232,7 +277,9 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        iou_sampling = self._box3d_overlap_sampling_batched(
 | 
			
		||||
            box1r[None], box2r[None], num_samples=10000
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        self.assertClose(iou, iou_sampling, atol=1e-2)
 | 
			
		||||
        # symmetry
 | 
			
		||||
        vol, iou = overlap_fn(box2r[None], box1r[None])
 | 
			
		||||
        self.assertClose(iou, iou_sampling, atol=1e-2)
 | 
			
		||||
 | 
			
		||||
        # 9th test: non overlapping boxes, iou = 0.0
 | 
			
		||||
@ -240,6 +287,10 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        vol, iou = overlap_fn(box1[None], box2[None])
 | 
			
		||||
        self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
 | 
			
		||||
        self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
 | 
			
		||||
        # symmetry
 | 
			
		||||
        vol, iou = overlap_fn(box2[None], box1[None])
 | 
			
		||||
        self.assertClose(vol, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
 | 
			
		||||
        self.assertClose(iou, torch.tensor([[0.0]], device=vol.device, dtype=vol.dtype))
 | 
			
		||||
 | 
			
		||||
        # 10th test: Non coplanar verts in a plane
 | 
			
		||||
        box10 = box1 + torch.rand((8, 3), dtype=torch.float32, device=device)
 | 
			
		||||
@ -284,6 +335,56 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        vols, ious = overlap_fn(box_skew_1[None], box_skew_2[None])
 | 
			
		||||
        self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
 | 
			
		||||
        self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
 | 
			
		||||
        # symmetry
 | 
			
		||||
        vols, ious = overlap_fn(box_skew_2[None], box_skew_1[None])
 | 
			
		||||
        self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
 | 
			
		||||
        self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
 | 
			
		||||
 | 
			
		||||
        # 12th test: Zero area bounding box (from GH issue #992)
 | 
			
		||||
        box12a = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [-1.0000, -1.0000, -0.5000],
 | 
			
		||||
                [1.0000, -1.0000, -0.5000],
 | 
			
		||||
                [1.0000, 1.0000, -0.5000],
 | 
			
		||||
                [-1.0000, 1.0000, -0.5000],
 | 
			
		||||
                [-1.0000, -1.0000, 0.5000],
 | 
			
		||||
                [1.0000, -1.0000, 0.5000],
 | 
			
		||||
                [1.0000, 1.0000, 0.5000],
 | 
			
		||||
                [-1.0000, 1.0000, 0.5000],
 | 
			
		||||
            ],
 | 
			
		||||
            device=device,
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        box12b = torch.tensor(
 | 
			
		||||
            [
 | 
			
		||||
                [0.0, 0.0, 0.0],
 | 
			
		||||
                [0.0, 0.0, 0.0],
 | 
			
		||||
                [0.0, 0.0, 0.0],
 | 
			
		||||
                [0.0, 0.0, 0.0],
 | 
			
		||||
                [0.0, 0.0, 0.0],
 | 
			
		||||
                [0.0, 0.0, 0.0],
 | 
			
		||||
                [0.0, 0.0, 0.0],
 | 
			
		||||
                [0.0, 0.0, 0.0],
 | 
			
		||||
            ],
 | 
			
		||||
            device=device,
 | 
			
		||||
            dtype=torch.float32,
 | 
			
		||||
        )
 | 
			
		||||
        msg = "Planes have zero areas"
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, msg):
 | 
			
		||||
            overlap_fn(box12a[None], box12b[None])
 | 
			
		||||
        # symmetry
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, msg):
 | 
			
		||||
            overlap_fn(box12b[None], box12a[None])
 | 
			
		||||
 | 
			
		||||
        # 13th test: From GH issue #992
 | 
			
		||||
        # Zero area coplanar face after intersection
 | 
			
		||||
        ctrs = torch.tensor([[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0]])
 | 
			
		||||
        whl = torch.tensor([[2.0, 2.0, 2.0], [2.0, 2, 2]])
 | 
			
		||||
        box13a = TestIoU3D.create_box(ctrs[0], whl[0])
 | 
			
		||||
        box13b = TestIoU3D.create_box(ctrs[1], whl[1])
 | 
			
		||||
        vol, iou = overlap_fn(box13a[None], box13b[None])
 | 
			
		||||
        self.assertClose(vol, torch.tensor([[2.0]], device=vol.device, dtype=vol.dtype))
 | 
			
		||||
 | 
			
		||||
    def _test_real_boxes(self, overlap_fn, device):
 | 
			
		||||
        data_filename = "./real_boxes.pkl"
 | 
			
		||||
@ -577,6 +678,13 @@ def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor:
 | 
			
		||||
        msg = "Plane vertices are not coplanar"
 | 
			
		||||
        raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
    # Check all faces have non zero area
 | 
			
		||||
    area1 = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
 | 
			
		||||
    area2 = torch.cross(v3 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2
 | 
			
		||||
    if (area1 < eps).any().item() or (area2 < eps).any().item():
 | 
			
		||||
        msg = "Planes have zero areas"
 | 
			
		||||
        raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
    # We can write:  `ctr = v0 + a * e0 + b * e1 + c * n`, (1).
 | 
			
		||||
    # With <e0, n> = 0 and <e1, n> = 0, where <.,.> refers to the dot product,
 | 
			
		||||
    # since that e0 is orthogonal to n. Same for e1.
 | 
			
		||||
@ -607,6 +715,27 @@ def box_planar_dir(box: torch.Tensor, eps=1e-4) -> torch.Tensor:
 | 
			
		||||
    return n
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def tri_verts_area(tri_verts: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Computes the area of the triangle faces in tri_verts
 | 
			
		||||
    Args:
 | 
			
		||||
        tri_verts: tensor of shape (T, 3, 3)
 | 
			
		||||
    Returns:
 | 
			
		||||
        areas: the area of the triangles (T, 1)
 | 
			
		||||
    """
 | 
			
		||||
    add_dim = False
 | 
			
		||||
    if tri_verts.ndim == 2:
 | 
			
		||||
        tri_verts = tri_verts.unsqueeze(0)
 | 
			
		||||
        add_dim = True
 | 
			
		||||
 | 
			
		||||
    v0, v1, v2 = tri_verts.unbind(1)
 | 
			
		||||
    areas = torch.cross(v1 - v0, v2 - v0, dim=-1).norm(dim=-1) / 2.0
 | 
			
		||||
 | 
			
		||||
    if add_dim:
 | 
			
		||||
        areas = areas[0]
 | 
			
		||||
    return areas
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def box_volume(box: torch.Tensor) -> torch.Tensor:
 | 
			
		||||
    """
 | 
			
		||||
    Computes the volume of each box in boxes.
 | 
			
		||||
@ -988,7 +1117,10 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
 | 
			
		||||
    keep2 = torch.ones((tri_verts2.shape[0],), device=device, dtype=torch.bool)
 | 
			
		||||
    for i1 in range(tri_verts1.shape[0]):
 | 
			
		||||
        for i2 in range(tri_verts2.shape[0]):
 | 
			
		||||
            if coplanar_tri_faces(tri_verts1[i1], tri_verts2[i2]):
 | 
			
		||||
            if (
 | 
			
		||||
                coplanar_tri_faces(tri_verts1[i1], tri_verts2[i2])
 | 
			
		||||
                and tri_verts_area(tri_verts1[i1]) > 1e-4
 | 
			
		||||
            ):
 | 
			
		||||
                keep2[i2] = 0
 | 
			
		||||
    keep2 = keep2.nonzero()[:, 0]
 | 
			
		||||
    tri_verts2 = tri_verts2[keep2]
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user