mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Initial commit
fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
This commit is contained in:
113
tests/test_mesh_edge_loss.py
Normal file
113
tests/test_mesh_edge_loss.py
Normal file
@@ -0,0 +1,113 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from pytorch3d.loss import mesh_edge_loss
|
||||
from pytorch3d.structures import Meshes
|
||||
|
||||
from test_sample_points_from_meshes import TestSamplePoints
|
||||
|
||||
|
||||
class TestMeshEdgeLoss(unittest.TestCase):
|
||||
def test_empty_meshes(self):
|
||||
device = torch.device("cuda:0")
|
||||
target_length = 0
|
||||
N = 10
|
||||
V = 32
|
||||
verts_list = []
|
||||
faces_list = []
|
||||
for _ in range(N):
|
||||
vn = torch.randint(3, high=V, size=(1,))[0].item()
|
||||
verts = torch.rand((vn, 3), dtype=torch.float32, device=device)
|
||||
faces = torch.tensor([], dtype=torch.int64, device=device)
|
||||
verts_list.append(verts)
|
||||
faces_list.append(faces)
|
||||
mesh = Meshes(verts=verts_list, faces=faces_list)
|
||||
loss = mesh_edge_loss(mesh, target_length=target_length)
|
||||
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
loss, torch.tensor([0.0], dtype=torch.float32, device=device)
|
||||
)
|
||||
)
|
||||
self.assertTrue(loss.requires_grad)
|
||||
|
||||
@staticmethod
|
||||
def mesh_edge_loss_naive(meshes, target_length: float = 0.0):
|
||||
"""
|
||||
Naive iterative implementation of mesh loss calculation.
|
||||
"""
|
||||
edges_packed = meshes.edges_packed()
|
||||
verts_packed = meshes.verts_packed()
|
||||
edge_to_mesh = meshes.edges_packed_to_mesh_idx()
|
||||
N = len(meshes)
|
||||
device = meshes.device
|
||||
valid = meshes.valid
|
||||
predlosses = torch.zeros((N,), dtype=torch.float32, device=device)
|
||||
|
||||
for b in range(N):
|
||||
if valid[b] == 0:
|
||||
continue
|
||||
mesh_edges = edges_packed[edge_to_mesh == b]
|
||||
verts_edges = verts_packed[mesh_edges]
|
||||
num_edges = mesh_edges.size(0)
|
||||
for e in range(num_edges):
|
||||
v0, v1 = verts_edges[e, 0], verts_edges[e, 1]
|
||||
predlosses[b] += (
|
||||
(v0 - v1).norm(dim=0, p=2) - target_length
|
||||
) ** 2.0
|
||||
|
||||
if num_edges > 0:
|
||||
predlosses[b] = predlosses[b] / num_edges
|
||||
|
||||
return predlosses.mean()
|
||||
|
||||
def test_mesh_edge_loss_output(self):
|
||||
"""
|
||||
Check outputs of tensorized and iterative implementations are the same.
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
target_length = 0.5
|
||||
num_meshes = 10
|
||||
num_verts = 32
|
||||
num_faces = 64
|
||||
|
||||
verts_list = []
|
||||
faces_list = []
|
||||
valid = torch.randint(2, size=(num_meshes,))
|
||||
|
||||
for n in range(num_meshes):
|
||||
if valid[n]:
|
||||
vn = torch.randint(3, high=num_verts, size=(1,))[0].item()
|
||||
fn = torch.randint(vn, high=num_faces, size=(1,))[0].item()
|
||||
verts = torch.rand((vn, 3), dtype=torch.float32, device=device)
|
||||
faces = torch.randint(
|
||||
vn, size=(fn, 3), dtype=torch.int64, device=device
|
||||
)
|
||||
else:
|
||||
verts = torch.tensor([], dtype=torch.float32, device=device)
|
||||
faces = torch.tensor([], dtype=torch.int64, device=device)
|
||||
verts_list.append(verts)
|
||||
faces_list.append(faces)
|
||||
meshes = Meshes(verts=verts_list, faces=faces_list)
|
||||
loss = mesh_edge_loss(meshes, target_length=target_length)
|
||||
|
||||
predloss = TestMeshEdgeLoss.mesh_edge_loss_naive(meshes, target_length)
|
||||
self.assertTrue(torch.allclose(loss, predloss))
|
||||
|
||||
@staticmethod
|
||||
def mesh_edge_loss(
|
||||
num_meshes: int = 10, max_v: int = 100, max_f: int = 300
|
||||
):
|
||||
meshes = TestSamplePoints.init_meshes(
|
||||
num_meshes, max_v, max_f, device="cuda:0"
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def compute_loss():
|
||||
mesh_edge_loss(meshes, target_length=0.0)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return compute_loss
|
||||
Reference in New Issue
Block a user