Address black + isort fbsource linter warnings

Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff)

Reviewed By: nikhilaravi

Differential Revision: D20558373

fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
This commit is contained in:
Patrick Labatut
2020-03-29 14:46:33 -07:00
committed by Facebook GitHub Bot
parent eb512ffde3
commit d57daa6f85
110 changed files with 705 additions and 1850 deletions

View File

@@ -1,20 +1,15 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
import torch.nn as nn
from common_testing import TestCaseMixin
from pytorch3d import _C
from pytorch3d.ops.graph_conv import (
GraphConv,
gather_scatter,
gather_scatter_python,
)
from pytorch3d.ops.graph_conv import GraphConv, gather_scatter, gather_scatter_python
from pytorch3d.structures.meshes import Meshes
from pytorch3d.utils import ico_sphere
from common_testing import TestCaseMixin
class TestGraphConv(TestCaseMixin, unittest.TestCase):
def test_undirected(self):
@@ -89,8 +84,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
w1 = torch.tensor([[-1, -1, -1]], dtype=dtype)
expected_y = torch.tensor(
[[1 + 2 + 3 - 4 - 5 - 6 - 7 - 8 - 9], [4 + 5 + 6], [7 + 8 + 9]],
dtype=dtype,
[[1 + 2 + 3 - 4 - 5 - 6 - 7 - 8 - 9], [4 + 5 + 6], [7 + 8 + 9]], dtype=dtype
)
conv = GraphConv(3, 1, directed=True).to(dtype)
@@ -126,17 +120,13 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
def test_cpu_cuda_tensor_error(self):
device = torch.device("cuda:0")
verts = torch.tensor(
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
dtype=torch.float32,
device=device,
[[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device
)
edges = torch.tensor([[0, 1], [0, 2]])
conv = GraphConv(3, 1, directed=True).to(torch.float32)
with self.assertRaises(Exception) as err:
conv(verts, edges)
self.assertTrue(
"tensors must be on the same device." in str(err.exception)
)
self.assertTrue("tensors must be on the same device." in str(err.exception))
def test_gather_scatter(self):
"""
@@ -178,12 +168,10 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
backend: str = "cuda",
):
device = torch.device("cuda") if backend == "cuda" else "cpu"
verts_list = torch.tensor(
num_verts * [[0.11, 0.22, 0.33]], device=device
).view(-1, 3)
faces_list = torch.tensor(num_faces * [[1, 2, 3]], device=device).view(
verts_list = torch.tensor(num_verts * [[0.11, 0.22, 0.33]], device=device).view(
-1, 3
)
faces_list = torch.tensor(num_faces * [[1, 2, 3]], device=device).view(-1, 3)
meshes = Meshes(num_meshes * [verts_list], num_meshes * [faces_list])
gconv = GraphConv(gconv_dim, gconv_dim, directed=directed)
gconv.to(device)
@@ -191,9 +179,7 @@ class TestGraphConv(TestCaseMixin, unittest.TestCase):
total_verts = meshes.verts_packed().shape[0]
# Features.
x = torch.randn(
total_verts, gconv_dim, device=device, requires_grad=True
)
x = torch.randn(total_verts, gconv_dim, device=device, requires_grad=True)
torch.cuda.synchronize()
def run_graph_conv():