mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 20:32:51 +08:00
IOU box3d epsilon fix
Summary: The epsilon value is important for determining whether vertices are inside/outside a plane. Reviewed By: gkioxari Differential Revision: D31485247 fbshipit-source-id: 5517575de7c02f1afa277d00e0190a81f44f5761
This commit is contained in:
parent
b26f4bc33a
commit
6dfa326922
@ -11,7 +11,8 @@
|
|||||||
#include <thrust/device_vector.h>
|
#include <thrust/device_vector.h>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include "utils/float_math.cuh"
|
#include "utils/float_math.cuh"
|
||||||
#include "utils/geometry_utils.cuh"
|
|
||||||
|
const auto kEpsilon = 1e-4;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
_PLANES and _TRIS define the 4- and 3-connectivity
|
_PLANES and _TRIS define the 4- and 3-connectivity
|
||||||
|
@ -16,9 +16,9 @@
|
|||||||
#include <queue>
|
#include <queue>
|
||||||
#include <tuple>
|
#include <tuple>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include "utils/geometry_utils.h"
|
|
||||||
#include "utils/vec3.h"
|
#include "utils/vec3.h"
|
||||||
|
|
||||||
|
const auto kEpsilon = 1e-4;
|
||||||
/*
|
/*
|
||||||
_PLANES and _TRIS define the 4- and 3-connectivity
|
_PLANES and _TRIS define the 4- and 3-connectivity
|
||||||
of the 8 box corners.
|
of the 8 box corners.
|
||||||
|
BIN
tests/data/real_boxes.pkl
Normal file
BIN
tests/data/real_boxes.pkl
Normal file
Binary file not shown.
@ -4,6 +4,7 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import pickle
|
||||||
import random
|
import random
|
||||||
import unittest
|
import unittest
|
||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
@ -15,7 +16,6 @@ from pytorch3d.io import save_obj
|
|||||||
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap
|
from pytorch3d.ops.iou_box3d import _box_planes, _box_triangles, box3d_overlap
|
||||||
from pytorch3d.transforms.rotation_conversions import random_rotation
|
from pytorch3d.transforms.rotation_conversions import random_rotation
|
||||||
|
|
||||||
|
|
||||||
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
|
OBJECTRON_TO_PYTORCH3D_FACE_IDX = [0, 4, 6, 2, 1, 5, 7, 3]
|
||||||
DATA_DIR = get_tests_dir() / "data"
|
DATA_DIR = get_tests_dir() / "data"
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
@ -167,6 +167,7 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
|||||||
device=vol.device,
|
device=vol.device,
|
||||||
dtype=vol.dtype,
|
dtype=vol.dtype,
|
||||||
),
|
),
|
||||||
|
atol=1e-7,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 7th test: hand coded example and test with meshlab output
|
# 7th test: hand coded example and test with meshlab output
|
||||||
@ -283,20 +284,36 @@ class TestIoU3D(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
|
self.assertClose(vols, torch.tensor([[vol_inters]], device=device), atol=1e-1)
|
||||||
self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
|
self.assertClose(ious, torch.tensor([[iou]], device=device), atol=1e-1)
|
||||||
|
|
||||||
|
def _test_real_boxes(self, overlap_fn, device):
|
||||||
|
data_filename = "./real_boxes.pkl"
|
||||||
|
with open(DATA_DIR / data_filename, "rb") as f:
|
||||||
|
example = pickle.load(f)
|
||||||
|
|
||||||
|
verts1 = torch.FloatTensor(example["verts1"])
|
||||||
|
verts2 = torch.FloatTensor(example["verts2"])
|
||||||
|
boxes = torch.stack((verts1, verts2)).to(device)
|
||||||
|
|
||||||
|
iou_expected = torch.eye(2).to(device)
|
||||||
|
vol, iou = overlap_fn(boxes, boxes)
|
||||||
|
self.assertClose(iou, iou_expected)
|
||||||
|
|
||||||
def test_iou_naive(self):
|
def test_iou_naive(self):
|
||||||
device = get_random_cuda_device()
|
device = get_random_cuda_device()
|
||||||
self._test_iou(self._box3d_overlap_naive_batched, device)
|
self._test_iou(self._box3d_overlap_naive_batched, device)
|
||||||
self._test_compare_objectron(self._box3d_overlap_naive_batched, device)
|
self._test_compare_objectron(self._box3d_overlap_naive_batched, device)
|
||||||
|
self._test_real_boxes(self._box3d_overlap_naive_batched, device)
|
||||||
|
|
||||||
def test_iou_cpu(self):
|
def test_iou_cpu(self):
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
self._test_iou(box3d_overlap, device)
|
self._test_iou(box3d_overlap, device)
|
||||||
self._test_compare_objectron(box3d_overlap, device)
|
self._test_compare_objectron(box3d_overlap, device)
|
||||||
|
self._test_real_boxes(box3d_overlap, device)
|
||||||
|
|
||||||
def test_iou_cuda(self):
|
def test_iou_cuda(self):
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
self._test_iou(box3d_overlap, device)
|
self._test_iou(box3d_overlap, device)
|
||||||
self._test_compare_objectron(box3d_overlap, device)
|
self._test_compare_objectron(box3d_overlap, device)
|
||||||
|
self._test_real_boxes(box3d_overlap, device)
|
||||||
|
|
||||||
def _test_compare_objectron(self, overlap_fn, device):
|
def _test_compare_objectron(self, overlap_fn, device):
|
||||||
# Load saved objectron data
|
# Load saved objectron data
|
||||||
@ -656,7 +673,7 @@ def is_inside(
|
|||||||
n: torch.Tensor,
|
n: torch.Tensor,
|
||||||
points: torch.Tensor,
|
points: torch.Tensor,
|
||||||
return_proj: bool = True,
|
return_proj: bool = True,
|
||||||
eps: float = 1e-6,
|
eps: float = 1e-4,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Computes whether point is "inside" the plane.
|
Computes whether point is "inside" the plane.
|
||||||
@ -778,17 +795,17 @@ def clip_tri_by_plane_oneout(
|
|||||||
vout is "outside" the plane and vin1, vin2 are "inside"
|
vout is "outside" the plane and vin1, vin2 are "inside"
|
||||||
Returns:
|
Returns:
|
||||||
verts: tensor of shape (4, 3) containing the new vertices formed after clipping the
|
verts: tensor of shape (4, 3) containing the new vertices formed after clipping the
|
||||||
original intersectiong triangle (vout, vin1, vin2)
|
original intersecting triangle (vout, vin1, vin2)
|
||||||
faces: tensor of shape (2, 3) defining the vertex indices forming the two new triangles
|
faces: tensor of shape (2, 3) defining the vertex indices forming the two new triangles
|
||||||
which are "inside" the plane formed after clipping
|
which are "inside" the plane formed after clipping
|
||||||
"""
|
"""
|
||||||
device = plane.device
|
device = plane.device
|
||||||
# point of intersection between plane and (vin1, vout)
|
# point of intersection between plane and (vin1, vout)
|
||||||
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin1, vout)
|
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin1, vout)
|
||||||
assert a1 >= eps and a1 <= 1.0, a1
|
assert a1 >= -eps and a1 <= 1.0 + eps, a1
|
||||||
# point of intersection between plane and (vin2, vout)
|
# point of intersection between plane and (vin2, vout)
|
||||||
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin2, vout)
|
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin2, vout)
|
||||||
assert a2 >= 0.0 and a2 <= 1.0, a2
|
assert a2 >= -eps and a2 <= 1.0 + eps, a2
|
||||||
|
|
||||||
verts = torch.stack((vin1, pint1, pint2, vin2), dim=0) # 4x3
|
verts = torch.stack((vin1, pint1, pint2, vin2), dim=0) # 4x3
|
||||||
faces = torch.tensor(
|
faces = torch.tensor(
|
||||||
@ -823,10 +840,10 @@ def clip_tri_by_plane_twoout(
|
|||||||
device = plane.device
|
device = plane.device
|
||||||
# point of intersection between plane and (vin, vout1)
|
# point of intersection between plane and (vin, vout1)
|
||||||
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin, vout1)
|
pint1, a1 = plane_edge_point_of_intersection(plane, n, vin, vout1)
|
||||||
assert a1 >= eps and a1 <= 1.0, a1
|
assert a1 >= -eps and a1 <= 1.0 + eps, a1
|
||||||
# point of intersection between plane and (vin, vout2)
|
# point of intersection between plane and (vin, vout2)
|
||||||
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin, vout2)
|
pint2, a2 = plane_edge_point_of_intersection(plane, n, vin, vout2)
|
||||||
assert a2 >= eps and a2 <= 1.0, a2
|
assert a2 >= -eps and a2 <= 1.0 + eps, a2
|
||||||
|
|
||||||
verts = torch.stack((vin, pint1, pint2), dim=0) # 3x3
|
verts = torch.stack((vin, pint1, pint2), dim=0) # 3x3
|
||||||
faces = torch.tensor(
|
faces = torch.tensor(
|
||||||
@ -917,6 +934,7 @@ def box3d_overlap_naive(box1: torch.Tensor, box2: torch.Tensor):
|
|||||||
`iou = vol / (vol1 + vol2 - vol)`
|
`iou = vol / (vol1 + vol2 - vol)`
|
||||||
"""
|
"""
|
||||||
device = box1.device
|
device = box1.device
|
||||||
|
|
||||||
# For boxes1 we compute the unit directions n1 corresponding to quad_faces
|
# For boxes1 we compute the unit directions n1 corresponding to quad_faces
|
||||||
n1 = box_planar_dir(box1) # (6, 3)
|
n1 = box_planar_dir(box1) # (6, 3)
|
||||||
# For boxes2 we compute the unit directions n2 corresponding to quad_faces
|
# For boxes2 we compute the unit directions n2 corresponding to quad_faces
|
||||||
|
Loading…
x
Reference in New Issue
Block a user