mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
face areas backward
Summary: Added backward for mesh face areas & normals. Exposed it as a layer. Replaced the computation with the new op in Meshes and in Sample Points. Current issue: Circular imports. I moved the import of the op in meshes inside the function scope. Reviewed By: jcjohnson Differential Revision: D19920082 fbshipit-source-id: d213226d5e1d19a0c8452f4d32771d07e8b91c0a
This commit is contained in:
committed by
Facebook Github Bot
parent
9ca5489107
commit
a3baa367e3
@@ -11,19 +11,19 @@ from test_face_areas_normals import TestFaceAreasNormals
|
||||
|
||||
def bm_face_areas_normals() -> None:
|
||||
kwargs_list = []
|
||||
backend_cuda = [False]
|
||||
backend = ["cpu"]
|
||||
if torch.cuda.is_available():
|
||||
backend_cuda.append(True)
|
||||
backend.append("cuda:0")
|
||||
|
||||
num_meshes = [2, 10, 32]
|
||||
num_verts = [100, 1000]
|
||||
num_faces = [300, 3000]
|
||||
|
||||
test_cases = product(num_meshes, num_verts, num_faces, backend_cuda)
|
||||
test_cases = product(num_meshes, num_verts, num_faces, backend)
|
||||
for case in test_cases:
|
||||
n, v, f, c = case
|
||||
n, v, f, d = case
|
||||
kwargs_list.append(
|
||||
{"num_meshes": n, "num_verts": v, "num_faces": f, "cuda": c}
|
||||
{"num_meshes": n, "num_verts": v, "num_faces": f, "device": d}
|
||||
)
|
||||
benchmark(
|
||||
TestFaceAreasNormals.face_areas_normals_with_init,
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.ops import mesh_face_areas_normals
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
@@ -28,7 +28,10 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
||||
faces_list = []
|
||||
for _ in range(num_meshes):
|
||||
verts = torch.rand(
|
||||
(num_verts, 3), dtype=torch.float32, device=device
|
||||
(num_verts, 3),
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
requires_grad=True,
|
||||
)
|
||||
faces = torch.randint(
|
||||
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
|
||||
@@ -40,10 +43,12 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
||||
return meshes
|
||||
|
||||
@staticmethod
|
||||
def face_areas_normals(verts, faces):
|
||||
def face_areas_normals_python(verts, faces):
|
||||
"""
|
||||
Pytorch implementation for face areas & normals.
|
||||
"""
|
||||
# TODO(gkioxari) Change cast to floats once we add support for doubles.
|
||||
verts = verts.float()
|
||||
vertices_faces = verts[faces] # (F, 3, 3)
|
||||
# vector pointing from v0 to v1
|
||||
v01 = vertices_faces[:, 1] - vertices_faces[:, 0]
|
||||
@@ -56,24 +61,41 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
return face_areas, face_normals
|
||||
|
||||
def _test_face_areas_normals_helper(self, device):
|
||||
def _test_face_areas_normals_helper(self, device, dtype=torch.float32):
|
||||
"""
|
||||
Check the results from face_areas cuda/cpp and PyTorch implementation are
|
||||
the same.
|
||||
"""
|
||||
meshes = self.init_meshes(10, 1000, 3000, device=device)
|
||||
verts = meshes.verts_packed()
|
||||
faces = meshes.faces_packed()
|
||||
meshes = self.init_meshes(10, 200, 400, device=device)
|
||||
# make them leaf nodes
|
||||
verts = meshes.verts_packed().detach().clone().to(dtype)
|
||||
verts.requires_grad = True
|
||||
faces = meshes.faces_packed().detach().clone()
|
||||
|
||||
areas_torch, normals_torch = self.face_areas_normals(verts, faces)
|
||||
areas, normals = _C.face_areas_normals(verts, faces)
|
||||
# forward
|
||||
areas, normals = mesh_face_areas_normals(verts, faces)
|
||||
verts_torch = verts.detach().clone().to(dtype)
|
||||
verts_torch.requires_grad = True
|
||||
faces_torch = faces.detach().clone()
|
||||
areas_torch, normals_torch = TestFaceAreasNormals.face_areas_normals_python(
|
||||
verts_torch, faces_torch
|
||||
)
|
||||
self.assertClose(areas_torch, areas, atol=1e-7)
|
||||
# normals get normalized by area thus sensitivity increases as areas
|
||||
# in our tests can be arbitrarily small. Thus we compare normals after
|
||||
# multiplying with areas
|
||||
unnormals = normals * areas.view(-1, 1)
|
||||
unnormals_torch = normals_torch * areas_torch.view(-1, 1)
|
||||
self.assertClose(unnormals_torch, unnormals, atol=1e-7)
|
||||
self.assertClose(unnormals_torch, unnormals, atol=1e-6)
|
||||
|
||||
# backward
|
||||
grad_areas = torch.rand(areas.shape, device=device, dtype=dtype)
|
||||
grad_normals = torch.rand(normals.shape, device=device, dtype=dtype)
|
||||
areas.backward((grad_areas, grad_normals))
|
||||
grad_verts = verts.grad
|
||||
areas_torch.backward((grad_areas, grad_normals))
|
||||
grad_verts_torch = verts_torch.grad
|
||||
self.assertClose(grad_verts_torch, grad_verts, atol=1e-6)
|
||||
|
||||
def test_face_areas_normals_cpu(self):
|
||||
self._test_face_areas_normals_helper("cpu")
|
||||
@@ -81,11 +103,16 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
||||
def test_face_areas_normals_cuda(self):
|
||||
self._test_face_areas_normals_helper("cuda:0")
|
||||
|
||||
def test_nonfloats_cpu(self):
|
||||
self._test_face_areas_normals_helper("cpu", dtype=torch.double)
|
||||
|
||||
def test_nonfloats_cuda(self):
|
||||
self._test_face_areas_normals_helper("cuda:0", dtype=torch.double)
|
||||
|
||||
@staticmethod
|
||||
def face_areas_normals_with_init(
|
||||
num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True
|
||||
num_meshes: int, num_verts: int, num_faces: int, device: str = "cpu"
|
||||
):
|
||||
device = "cuda:0" if cuda else "cpu"
|
||||
meshes = TestFaceAreasNormals.init_meshes(
|
||||
num_meshes, num_verts, num_faces, device
|
||||
)
|
||||
@@ -94,16 +121,15 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def face_areas_normals():
|
||||
_C.face_areas_normals(verts, faces)
|
||||
mesh_face_areas_normals(verts, faces)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return face_areas_normals
|
||||
|
||||
@staticmethod
|
||||
def face_areas_normals_with_init_torch(
|
||||
num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True
|
||||
num_meshes: int, num_verts: int, num_faces: int, device: str = "cpu"
|
||||
):
|
||||
device = "cuda:0" if cuda else "cpu"
|
||||
meshes = TestFaceAreasNormals.init_meshes(
|
||||
num_meshes, num_verts, num_faces, device
|
||||
)
|
||||
@@ -112,7 +138,7 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def face_areas_normals():
|
||||
TestFaceAreasNormals.face_areas_normals(verts, faces)
|
||||
TestFaceAreasNormals.face_areas_normals_python(verts, faces)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return face_areas_normals
|
||||
|
||||
@@ -6,8 +6,7 @@ import unittest
|
||||
from pathlib import Path
|
||||
import torch
|
||||
|
||||
from pytorch3d import _C
|
||||
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
|
||||
from pytorch3d.ops import sample_points_from_meshes
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||
|
||||
|
||||
Reference in New Issue
Block a user