mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 06:40:35 +08:00
Address black + isort fbsource linter warnings
Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff) Reviewed By: nikhilaravi Differential Revision: D20558373 fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
eb512ffde3
commit
d57daa6f85
@@ -1,12 +1,11 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.loss import mesh_edge_loss
|
||||
from pytorch3d.structures import Meshes
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
from test_sample_points_from_meshes import TestSamplePoints
|
||||
|
||||
|
||||
@@ -27,9 +26,7 @@ class TestMeshEdgeLoss(TestCaseMixin, unittest.TestCase):
|
||||
mesh = Meshes(verts=verts_list, faces=faces_list)
|
||||
loss = mesh_edge_loss(mesh, target_length=target_length)
|
||||
|
||||
self.assertClose(
|
||||
loss, torch.tensor([0.0], dtype=torch.float32, device=device)
|
||||
)
|
||||
self.assertClose(loss, torch.tensor([0.0], dtype=torch.float32, device=device))
|
||||
self.assertTrue(loss.requires_grad)
|
||||
|
||||
@staticmethod
|
||||
@@ -53,9 +50,7 @@ class TestMeshEdgeLoss(TestCaseMixin, unittest.TestCase):
|
||||
num_edges = mesh_edges.size(0)
|
||||
for e in range(num_edges):
|
||||
v0, v1 = verts_edges[e, 0], verts_edges[e, 1]
|
||||
predlosses[b] += (
|
||||
(v0 - v1).norm(dim=0, p=2) - target_length
|
||||
) ** 2.0
|
||||
predlosses[b] += ((v0 - v1).norm(dim=0, p=2) - target_length) ** 2.0
|
||||
|
||||
if num_edges > 0:
|
||||
predlosses[b] = predlosses[b] / num_edges
|
||||
@@ -96,12 +91,8 @@ class TestMeshEdgeLoss(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(loss, predloss)
|
||||
|
||||
@staticmethod
|
||||
def mesh_edge_loss(
|
||||
num_meshes: int = 10, max_v: int = 100, max_f: int = 300
|
||||
):
|
||||
meshes = TestSamplePoints.init_meshes(
|
||||
num_meshes, max_v, max_f, device="cuda:0"
|
||||
)
|
||||
def mesh_edge_loss(num_meshes: int = 10, max_v: int = 100, max_f: int = 300):
|
||||
meshes = TestSamplePoints.init_meshes(num_meshes, max_v, max_f, device="cuda:0")
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def compute_loss():
|
||||
|
||||
Reference in New Issue
Block a user