face areas backward

Summary:
Added backward for mesh face areas & normals. Exposed it as a layer. Replaced the computation with the new op in Meshes and in Sample Points.

Current issue: Circular imports. I moved the import of the op in meshes inside the function scope.

Reviewed By: jcjohnson

Differential Revision: D19920082

fbshipit-source-id: d213226d5e1d19a0c8452f4d32771d07e8b91c0a
This commit is contained in:
Georgia Gkioxari
2020-02-20 11:10:04 -08:00
committed by Facebook Github Bot
parent 9ca5489107
commit a3baa367e3
11 changed files with 513 additions and 63 deletions

View File

@@ -11,19 +11,19 @@ from test_face_areas_normals import TestFaceAreasNormals
def bm_face_areas_normals() -> None:
kwargs_list = []
backend_cuda = [False]
backend = ["cpu"]
if torch.cuda.is_available():
backend_cuda.append(True)
backend.append("cuda:0")
num_meshes = [2, 10, 32]
num_verts = [100, 1000]
num_faces = [300, 3000]
test_cases = product(num_meshes, num_verts, num_faces, backend_cuda)
test_cases = product(num_meshes, num_verts, num_faces, backend)
for case in test_cases:
n, v, f, c = case
n, v, f, d = case
kwargs_list.append(
{"num_meshes": n, "num_verts": v, "num_faces": f, "cuda": c}
{"num_meshes": n, "num_verts": v, "num_faces": f, "device": d}
)
benchmark(
TestFaceAreasNormals.face_areas_normals_with_init,

View File

@@ -5,7 +5,7 @@
import unittest
import torch
from pytorch3d import _C
from pytorch3d.ops import mesh_face_areas_normals
from pytorch3d.structures.meshes import Meshes
from common_testing import TestCaseMixin
@@ -28,7 +28,10 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
faces_list = []
for _ in range(num_meshes):
verts = torch.rand(
(num_verts, 3), dtype=torch.float32, device=device
(num_verts, 3),
dtype=torch.float32,
device=device,
requires_grad=True,
)
faces = torch.randint(
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
@@ -40,10 +43,12 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
return meshes
@staticmethod
def face_areas_normals(verts, faces):
def face_areas_normals_python(verts, faces):
"""
Pytorch implementation for face areas & normals.
"""
# TODO(gkioxari) Change cast to floats once we add support for doubles.
verts = verts.float()
vertices_faces = verts[faces] # (F, 3, 3)
# vector pointing from v0 to v1
v01 = vertices_faces[:, 1] - vertices_faces[:, 0]
@@ -56,24 +61,41 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
)
return face_areas, face_normals
def _test_face_areas_normals_helper(self, device):
def _test_face_areas_normals_helper(self, device, dtype=torch.float32):
"""
Check the results from face_areas cuda/cpp and PyTorch implementation are
the same.
"""
meshes = self.init_meshes(10, 1000, 3000, device=device)
verts = meshes.verts_packed()
faces = meshes.faces_packed()
meshes = self.init_meshes(10, 200, 400, device=device)
# make them leaf nodes
verts = meshes.verts_packed().detach().clone().to(dtype)
verts.requires_grad = True
faces = meshes.faces_packed().detach().clone()
areas_torch, normals_torch = self.face_areas_normals(verts, faces)
areas, normals = _C.face_areas_normals(verts, faces)
# forward
areas, normals = mesh_face_areas_normals(verts, faces)
verts_torch = verts.detach().clone().to(dtype)
verts_torch.requires_grad = True
faces_torch = faces.detach().clone()
areas_torch, normals_torch = TestFaceAreasNormals.face_areas_normals_python(
verts_torch, faces_torch
)
self.assertClose(areas_torch, areas, atol=1e-7)
# normals get normalized by area thus sensitivity increases as areas
# in our tests can be arbitrarily small. Thus we compare normals after
# multiplying with areas
unnormals = normals * areas.view(-1, 1)
unnormals_torch = normals_torch * areas_torch.view(-1, 1)
self.assertClose(unnormals_torch, unnormals, atol=1e-7)
self.assertClose(unnormals_torch, unnormals, atol=1e-6)
# backward
grad_areas = torch.rand(areas.shape, device=device, dtype=dtype)
grad_normals = torch.rand(normals.shape, device=device, dtype=dtype)
areas.backward((grad_areas, grad_normals))
grad_verts = verts.grad
areas_torch.backward((grad_areas, grad_normals))
grad_verts_torch = verts_torch.grad
self.assertClose(grad_verts_torch, grad_verts, atol=1e-6)
def test_face_areas_normals_cpu(self):
self._test_face_areas_normals_helper("cpu")
@@ -81,11 +103,16 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
def test_face_areas_normals_cuda(self):
self._test_face_areas_normals_helper("cuda:0")
def test_nonfloats_cpu(self):
self._test_face_areas_normals_helper("cpu", dtype=torch.double)
def test_nonfloats_cuda(self):
self._test_face_areas_normals_helper("cuda:0", dtype=torch.double)
@staticmethod
def face_areas_normals_with_init(
num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True
num_meshes: int, num_verts: int, num_faces: int, device: str = "cpu"
):
device = "cuda:0" if cuda else "cpu"
meshes = TestFaceAreasNormals.init_meshes(
num_meshes, num_verts, num_faces, device
)
@@ -94,16 +121,15 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
torch.cuda.synchronize()
def face_areas_normals():
_C.face_areas_normals(verts, faces)
mesh_face_areas_normals(verts, faces)
torch.cuda.synchronize()
return face_areas_normals
@staticmethod
def face_areas_normals_with_init_torch(
num_meshes: int, num_verts: int, num_faces: int, cuda: bool = True
num_meshes: int, num_verts: int, num_faces: int, device: str = "cpu"
):
device = "cuda:0" if cuda else "cpu"
meshes = TestFaceAreasNormals.init_meshes(
num_meshes, num_verts, num_faces, device
)
@@ -112,7 +138,7 @@ class TestFaceAreasNormals(TestCaseMixin, unittest.TestCase):
torch.cuda.synchronize()
def face_areas_normals():
TestFaceAreasNormals.face_areas_normals(verts, faces)
TestFaceAreasNormals.face_areas_normals_python(verts, faces)
torch.cuda.synchronize()
return face_areas_normals

View File

@@ -6,8 +6,7 @@ import unittest
from pathlib import Path
import torch
from pytorch3d import _C
from pytorch3d.ops.sample_points_from_meshes import sample_points_from_meshes
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils.ico_sphere import ico_sphere