mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
Address black + isort fbsource linter warnings
Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff) Reviewed By: nikhilaravi Differential Revision: D20558373 fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
eb512ffde3
commit
d57daa6f85
@@ -2,22 +2,18 @@
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.ops.vert_align import vert_align
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestVertAlign(TestCaseMixin, unittest.TestCase):
|
||||
@staticmethod
|
||||
def vert_align_naive(
|
||||
feats,
|
||||
verts_or_meshes,
|
||||
return_packed: bool = False,
|
||||
align_corners: bool = True,
|
||||
feats, verts_or_meshes, return_packed: bool = False, align_corners: bool = True
|
||||
):
|
||||
"""
|
||||
Naive implementation of vert_align.
|
||||
@@ -60,16 +56,13 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
|
||||
return out_feats
|
||||
|
||||
@staticmethod
|
||||
def init_meshes(
|
||||
num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000
|
||||
):
|
||||
def init_meshes(num_meshes: int = 10, num_verts: int = 1000, num_faces: int = 3000):
|
||||
device = torch.device("cuda:0")
|
||||
verts_list = []
|
||||
faces_list = []
|
||||
for _ in range(num_meshes):
|
||||
verts = (
|
||||
torch.rand((num_verts, 3), dtype=torch.float32, device=device)
|
||||
* 2.0
|
||||
torch.rand((num_verts, 3), dtype=torch.float32, device=device) * 2.0
|
||||
- 1.0
|
||||
) # verts in the space of [-1, 1]
|
||||
faces = torch.randint(
|
||||
@@ -82,15 +75,11 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
|
||||
return meshes
|
||||
|
||||
@staticmethod
|
||||
def init_feats(
|
||||
batch_size: int = 10, num_channels: int = 256, device: str = "cuda"
|
||||
):
|
||||
def init_feats(batch_size: int = 10, num_channels: int = 256, device: str = "cuda"):
|
||||
H, W = [14, 28], [14, 28]
|
||||
feats = []
|
||||
for (h, w) in zip(H, W):
|
||||
feats.append(
|
||||
torch.rand((batch_size, num_channels, h, w), device=device)
|
||||
)
|
||||
feats.append(torch.rand((batch_size, num_channels, h, w), device=device))
|
||||
return feats
|
||||
|
||||
def test_vert_align_with_meshes(self):
|
||||
@@ -102,16 +91,12 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
# feats in list
|
||||
out = vert_align(feats, meshes, return_packed=True)
|
||||
naive_out = TestVertAlign.vert_align_naive(
|
||||
feats, meshes, return_packed=True
|
||||
)
|
||||
naive_out = TestVertAlign.vert_align_naive(feats, meshes, return_packed=True)
|
||||
self.assertClose(out, naive_out)
|
||||
|
||||
# feats as tensor
|
||||
out = vert_align(feats[0], meshes, return_packed=True)
|
||||
naive_out = TestVertAlign.vert_align_naive(
|
||||
feats[0], meshes, return_packed=True
|
||||
)
|
||||
naive_out = TestVertAlign.vert_align_naive(feats[0], meshes, return_packed=True)
|
||||
self.assertClose(out, naive_out)
|
||||
|
||||
def test_vert_align_with_verts(self):
|
||||
@@ -120,30 +105,21 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
|
||||
"""
|
||||
feats = TestVertAlign.init_feats(10, 256)
|
||||
verts = (
|
||||
torch.rand(
|
||||
(10, 100, 3), dtype=torch.float32, device=feats[0].device
|
||||
)
|
||||
* 2.0
|
||||
torch.rand((10, 100, 3), dtype=torch.float32, device=feats[0].device) * 2.0
|
||||
- 1.0
|
||||
)
|
||||
|
||||
# feats in list
|
||||
out = vert_align(feats, verts, return_packed=True)
|
||||
naive_out = TestVertAlign.vert_align_naive(
|
||||
feats, verts, return_packed=True
|
||||
)
|
||||
naive_out = TestVertAlign.vert_align_naive(feats, verts, return_packed=True)
|
||||
self.assertClose(out, naive_out)
|
||||
|
||||
# feats as tensor
|
||||
out = vert_align(feats[0], verts, return_packed=True)
|
||||
naive_out = TestVertAlign.vert_align_naive(
|
||||
feats[0], verts, return_packed=True
|
||||
)
|
||||
naive_out = TestVertAlign.vert_align_naive(feats[0], verts, return_packed=True)
|
||||
self.assertClose(out, naive_out)
|
||||
|
||||
out2 = vert_align(
|
||||
feats[0], verts, return_packed=True, align_corners=False
|
||||
)
|
||||
out2 = vert_align(feats[0], verts, return_packed=True, align_corners=False)
|
||||
naive_out2 = TestVertAlign.vert_align_naive(
|
||||
feats[0], verts, return_packed=True, align_corners=False
|
||||
)
|
||||
@@ -158,9 +134,7 @@ class TestVertAlign(TestCaseMixin, unittest.TestCase):
|
||||
verts_list = []
|
||||
faces_list = []
|
||||
for _ in range(num_meshes):
|
||||
verts = torch.rand(
|
||||
(num_verts, 3), dtype=torch.float32, device=device
|
||||
)
|
||||
verts = torch.rand((num_verts, 3), dtype=torch.float32, device=device)
|
||||
faces = torch.randint(
|
||||
num_verts, size=(num_faces, 3), dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user