IOU box3d epsilon fix

Summary: The epsilon value is important for determining whether vertices are inside/outside a plane.

Reviewed By: gkioxari

Differential Revision: D31485247

fbshipit-source-id: 5517575de7c02f1afa277d00e0190a81f44f5761
This commit is contained in:
Nikhila Ravi 2021-10-07 18:40:56 -07:00 committed by Facebook GitHub Bot
parent b26f4bc33a
commit 6dfa326922
4 changed files with 28 additions and 9 deletions

View File

@ -11,7 +11,8 @@
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <cstdio> #include <cstdio>
#include "utils/float_math.cuh" #include "utils/float_math.cuh"
#include "utils/geometry_utils.cuh"
const auto kEpsilon = 1e-4;
/* /*
_PLANES and _TRIS define the 4- and 3-connectivity _PLANES and _TRIS define the 4- and 3-connectivity

View File

@ -16,9 +16,9 @@
#include <queue> #include <queue>
#include <tuple> #include <tuple>
#include <type_traits> #include <type_traits>
#include "utils/geometry_utils.h"
#include "utils/vec3.h" #include "utils/vec3.h"
const auto kEpsilon = 1e-4;
/* /*
_PLANES and _TRIS define the 4- and 3-connectivity _PLANES and _TRIS define the 4- and 3-connectivity
of the 8 box corners. of the 8 box corners.

BIN
tests/data/real_boxes.pkl Normal file

Binary file not shown.

View File

@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import pickle
import random import random
import unittest import unittest
from typing import List, Tuple, Union from typing import List, Tuple, Union
@ -15,7 +16,6 @@ from pytorch3d.io import save_obj
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap
from pytorch3d.transforms.rotation_conversions import random_rotation from pytorch3d.transforms.rotation_conversions import random_rotation
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3] OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
DATA_DIR = get_tests_dir() / "data" DATA_DIR = get_tests_dir() / "data"
DEBUG = False DEBUG = False
@ -167,6 +167,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
device=vol.device, device=vol.device,
dtype=vol.dtype, dtype=vol.dtype,
), ),
atol=1e-7,
) )
# 7th test: hand coded example and test with meshlab output # 7th test: hand coded example and test with meshlab output
@ -283,20 +284,36 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1) self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1) self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
def _test_real_boxes(self, overlap_fn, device):
data_filename = "./real_boxes.pkl"
with open(DATA_DIR / data_filename, "rb") as f:
example = pickle.load(f)
verts1 = torch.FloatTensor(example["verts1"])
verts2 = torch.FloatTensor(example["verts2"])
boxes = torch.stack((verts1, verts2)).to(device)
iou_expected = torch.eye(2).to(device)
vol, iou = overlap_fn(boxes, boxes)
self.assertClose(iou, iou_expected)
def test_iou_naive(self): def test_iou_naive(self):
device = get_random_cuda_device() device = get_random_cuda_device()
self._test_iou(self._box3d_overlap_naive_batched, device) self._test_iou(self._box3d_overlap_naive_batched, device)
self._test_compare_objectron(self._box3d_overlap_naive_batched, device) self._test_compare_objectron(self._box3d_overlap_naive_batched, device)
self._test_real_boxes(self._box3d_overlap_naive_batched, device)
def test_iou_cpu(self): def test_iou_cpu(self):
device = torch.device("cpu") device = torch.device("cpu")
self._test_iou(box3d_overlap, device) self._test_iou(box3d_overlap, device)
self._test_compare_objectron(box3d_overlap, device) self._test_compare_objectron(box3d_overlap, device)
self._test_real_boxes(box3d_overlap, device)
def test_iou_cuda(self): def test_iou_cuda(self):
device = torch.device("cuda:0") device = torch.device("cuda:0")
self._test_iou(box3d_overlap, device) self._test_iou(box3d_overlap, device)
self._test_compare_objectron(box3d_overlap, device) self._test_compare_objectron(box3d_overlap, device)
self._test_real_boxes(box3d_overlap, device)
def _test_compare_objectron(self, overlap_fn, device): def _test_compare_objectron(self, overlap_fn, device):
# Load saved objectron data # Load saved objectron data
@ -656,7 +673,7 @@ def is_inside(
n: torch.Tensor, n: torch.Tensor,
points: torch.Tensor, points: torch.Tensor,
return_proj: bool = True, return_proj: bool = True,
eps: float = 1e-6, eps: float = 1e-4,
): ):
""" """
Computes whether point is "inside" the plane. Computes whether point is "inside" the plane.
@ -778,17 +795,17 @@ def clip_tri_by_plane_oneout(
vout is "outside" the plane and vin1, vin2 are "inside" vout is "outside" the plane and vin1, vin2 are "inside"
Returns: Returns:
verts: tensor of shape (4, 3) containing the new vertices formed after clipping the verts: tensor of shape (4, 3) containing the new vertices formed after clipping the
original intersectiong triangle (vout, vin1, vin2) original intersecting triangle (vout, vin1, vin2)
faces: tensor of shape (2, 3) defining the vertex indices forming the two new triangles faces: tensor of shape (2, 3) defining the vertex indices forming the two new triangles
which are "inside" the plane formed after clipping which are "inside" the plane formed after clipping
""" """
device = plane.device device = plane.device
# point of intersection between plane and (vin1, vout) # point of intersection between plane and (vin1, vout)
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin1, vout) pint1, a1 = plane_edge_point_of_intersection(plane, n, vin1, vout)
assert a1 >= eps and a1 <= 1.0, a1 assert a1 >= -eps and a1 <= 1.0 + eps, a1
# point of intersection between plane and (vin2, vout) # point of intersection between plane and (vin2, vout)
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin2, vout) pint2, a2 = plane_edge_point_of_intersection(plane, n, vin2, vout)
assert a2 >= 0.0 and a2 <= 1.0, a2 assert a2 >= -eps and a2 <= 1.0 + eps, a2
verts = torch.stack((vin1, pint1, pint2, vin2), dim=0) # 4x3 verts = torch.stack((vin1, pint1, pint2, vin2), dim=0) # 4x3
faces = torch.tensor( faces = torch.tensor(
@ -823,10 +840,10 @@ def clip_tri_by_plane_twoout(
device = plane.device device = plane.device
# point of intersection between plane and (vin, vout1) # point of intersection between plane and (vin, vout1)
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin, vout1) pint1, a1 = plane_edge_point_of_intersection(plane, n, vin, vout1)
assert a1 >= eps and a1 <= 1.0, a1 assert a1 >= -eps and a1 <= 1.0 + eps, a1
# point of intersection between plane and (vin, vout2) # point of intersection between plane and (vin, vout2)
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin, vout2) pint2, a2 = plane_edge_point_of_intersection(plane, n, vin, vout2)
assert a2 >= eps and a2 <= 1.0, a2 assert a2 >= -eps and a2 <= 1.0 + eps, a2
verts = torch.stack((vin, pint1, pint2), dim=0) # 3x3 verts = torch.stack((vin, pint1, pint2), dim=0) # 3x3
faces = torch.tensor( faces = torch.tensor(
@ -917,6 +934,7 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
`iou = vol / (vol1 + vol2 - vol)` `iou = vol / (vol1 + vol2 - vol)`
""" """
device = box1.device device = box1.device
# For boxes1 we compute the unit directions n1 corresponding to quad_faces # For boxes1 we compute the unit directions n1 corresponding to quad_faces
n1 = box_planar_dir(box1) # (6, 3) n1 = box_planar_dir(box1) # (6, 3)
# For boxes2 we compute the unit directions n2 corresponding to quad_faces # For boxes2 we compute the unit directions n2 corresponding to quad_faces