Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-12-23 07:40:34 +08:00
Address black + isort fbsource linter warnings
Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff)

Reviewed By: nikhilaravi

Differential Revision: D20558373

fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
Committed by: Facebook GitHub Bot
Parent: eb512ffde3
Commit: d57daa6f85
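The hunks below are pure re-formatting of the GraphConv test module: black collapses calls and imports that fit within its 88-character line limit onto a single line, and isort re-groups the imports. A minimal sketch of reproducing the same formatting programmatically, assuming stock black/isort defaults rather than the exact fbsource linter configuration:

# Hypothetical local reproduction; the fbsource linter invocation itself is not part of this commit.
import black
import isort

src = """\
from pytorch3d.ops.graph_conv import (
    GraphConv,
    gather_scatter,
    gather_scatter_python
)
"""

# isort.code() re-sorts and re-groups the imports; black.format_str() then collapses
# the import because the one-line form fits in 88 characters. (Recent black releases
# treat a pre-existing trailing comma as "magic" and would keep the import exploded,
# so the trailing comma is omitted in this sample.)
formatted = black.format_str(isort.code(src), mode=black.Mode())
print(formatted)
# from pytorch3d.ops.graph_conv import GraphConv, gather_scatter, gather_scatter_python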
@@ -1,20 +1,15 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
 import unittest
 
 import torch
 import torch.nn as nn
+from common_testing import TestCaseMixin
 from pytorch3d import _C
-from pytorch3d.ops.graph_conv import (
-    GraphConv,
-    gather_scatter,
-    gather_scatter_python,
-)
+from pytorch3d.ops.graph_conv import GraphConv, gather_scatter, gather_scatter_python
 from pytorch3d.structures.meshes import Meshes
 from pytorch3d.utils import ico_sphere
-
-from common_testing import TestCaseMixin
 
 
 class TestGraphConv(TestCaseMixin, unittest.TestCase):
     def test_undirected(self):
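For context on the re-imported symbols: GraphConv is a single graph-convolution layer that operates on per-vertex features plus an edge list, while gather_scatter and gather_scatter_python are the compiled and pure-Python implementations of its neighbor aggregation. A minimal usage sketch (the shapes and values are illustrative only):

import torch
from pytorch3d.ops.graph_conv import GraphConv

# Three vertices with 3-d features and two directed edges 0->1 and 0->2.
verts = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
edges = torch.tensor([[0, 1], [0, 2]])

conv = GraphConv(input_dim=3, output_dim=1, directed=True)
out = conv(verts, edges)  # shape (num_verts, output_dim) == (3, 1)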
@@ -89,8 +84,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
         w1 = torch.tensor([[-1, -1, -1]], dtype=dtype)
 
         expected_y = torch.tensor(
-            [[1 + 2 + 3 - 4 - 5 - 6 - 7 - 8 - 9], [4 + 5 + 6], [7 + 8 + 9]],
-            dtype=dtype,
+            [[1 + 2 + 3 - 4 - 5 - 6 - 7 - 8 - 9], [4 + 5 + 6], [7 + 8 + 9]], dtype=dtype
         )
 
         conv = GraphConv(3, 1, directed=True).to(dtype)
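Note on the expected values: with w0 = [[1, 1, 1]], w1 = [[-1, -1, -1]], verts = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] and directed edges 0->1 and 0->2 (w0, verts and edges are set outside this hunk and assumed here), each output row is the w0 term for the vertex itself plus a w1 term for every vertex it aggregates, so only vertex 0 picks up neighbor contributions. A small restatement of the arithmetic:

import torch

# Assumed fixture, mirroring the test but not shown in this hunk.
verts = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])
w0 = torch.tensor([[1.0, 1.0, 1.0]])      # weight applied to each vertex itself
w1 = torch.tensor([[-1.0, -1.0, -1.0]])   # weight applied to aggregated neighbors
edges = torch.tensor([[0, 1], [0, 2]])    # vertex 0 aggregates vertices 1 and 2

y = verts @ w0.t()  # self term: [[6.], [15.], [24.]]
for dst, src in edges.tolist():
    y[dst] += (verts[src] * w1[0]).sum()  # neighbor term along each directed edge

print(y)  # [[-33.], [15.], [24.]] == [[1+2+3-4-5-6-7-8-9], [4+5+6], [7+8+9]]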
@@ -126,17 +120,13 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
     def test_cpu_cuda_tensor_error(self):
         device = torch.device("cuda:0")
         verts = torch.tensor(
-            [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
-            dtype=torch.float32,
-            device=device,
+            [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device
         )
         edges = torch.tensor([[0, 1], [0, 2]])
         conv = GraphConv(3, 1, directed=True).to(torch.float32)
         with self.assertRaises(Exception) as err:
             conv(verts, edges)
-        self.assertTrue(
-            "tensors must be on the same device." in str(err.exception)
-        )
+        self.assertTrue("tensors must be on the same device." in str(err.exception))
 
     def test_gather_scatter(self):
         """
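The test above pins down the error raised when the vertex features live on a CUDA device while the GraphConv parameters stay on the CPU. In normal use the layer is simply moved to the same device first, e.g. (illustrative sketch):

import torch
from pytorch3d.ops.graph_conv import GraphConv

device = torch.device("cuda:0")
verts = torch.tensor(
    [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device
)
edges = torch.tensor([[0, 1], [0, 2]], device=device)

# Moving the layer to the same device avoids the
# "tensors must be on the same device." error exercised above.
conv = GraphConv(3, 1, directed=True).to(device)
out = conv(verts, edges)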
@@ -178,12 +168,10 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
         backend: str = "cuda",
     ):
         device = torch.device("cuda") if backend == "cuda" else "cpu"
-        verts_list = torch.tensor(
-            num_verts * [[0.11, 0.22, 0.33]], device=device
-        ).view(-1, 3)
-        faces_list = torch.tensor(num_faces * [[1, 2, 3]], device=device).view(
+        verts_list = torch.tensor(num_verts * [[0.11, 0.22, 0.33]], device=device).view(
             -1, 3
         )
+        faces_list = torch.tensor(num_faces * [[1, 2, 3]], device=device).view(-1, 3)
         meshes = Meshes(num_meshes * [verts_list], num_meshes * [faces_list])
         gconv = GraphConv(gconv_dim, gconv_dim, directed=directed)
         gconv.to(device)
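For reference, Meshes is constructed here from a list of per-mesh vertex tensors and a list of per-mesh face-index tensors; verts_packed() (used just below) concatenates the vertices of all meshes into a single (sum(V_i), 3) tensor, which is what the benchmark sizes its feature tensor against. A small sketch with made-up sizes:

import torch
from pytorch3d.structures.meshes import Meshes

# Two identical toy meshes with 4 vertices and 2 triangular faces each.
verts = torch.rand(4, 3)
faces = torch.tensor([[0, 1, 2], [1, 2, 3]])
meshes = Meshes(verts=[verts, verts], faces=[faces, faces])

total_verts = meshes.verts_packed().shape[0]  # 2 * 4 == 8
edges = meshes.edges_packed()                 # (E, 2) edges over the packed vertices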
@@ -191,9 +179,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
         total_verts = meshes.verts_packed().shape[0]
 
         # Features.
-        x = torch.randn(
-            total_verts, gconv_dim, device=device, requires_grad=True
-        )
+        x = torch.randn(total_verts, gconv_dim, device=device, requires_grad=True)
         torch.cuda.synchronize()
 
         def run_graph_conv():
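The body of run_graph_conv is cut off in this view; a closure like this typically runs one forward and one backward pass through the layer and synchronizes the GPU so that all queued kernels are included in the timing. An illustrative, self-contained sketch of the same pattern (not the actual benchmark body):

import torch
from pytorch3d.ops.graph_conv import GraphConv
from pytorch3d.utils import ico_sphere

device = torch.device("cuda")
mesh = ico_sphere(level=2, device=device)  # small sphere mesh to get verts and edges
edges = mesh.edges_packed()
gconv = GraphConv(16, 16).to(device)
x = torch.randn(mesh.verts_packed().shape[0], 16, device=device, requires_grad=True)

def run_graph_conv():
    out = gconv(x, edges)       # forward pass over the packed vertex features
    out.sum().backward()        # backward pass through the layer and the features
    torch.cuda.synchronize()    # include all queued CUDA work in the timing

run_graph_conv()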