mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Summary: Applies new import merging and sorting from µsort v1.0. When merging imports, µsort will make a best-effort to move associated comments to match merged elements, but there are known limitations due to the dynamic nature of Python and developer tooling. These changes should not produce any dangerous runtime changes, but may require touch-ups to satisfy linters and other tooling. Note that µsort uses case-insensitive, lexicographical sorting, which results in a different ordering compared to isort. This provides a more consistent sorting order, matching the case-insensitive order used when sorting import statements by module name, and ensures that "frog", "FROG", and "Frog" always sort next to each other. For details on µsort's sorting and merging semantics, see the user guide: https://usort.readthedocs.io/en/stable/guide.html#sorting Reviewed By: bottler Differential Revision: D35553814 fbshipit-source-id: be49bdb6a4c25264ff8d4db3a601f18736d17be1
204 lines
7.2 KiB
Python
204 lines
7.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import unittest
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from common_testing import get_random_cuda_device, TestCaseMixin
|
|
from pytorch3d import _C
|
|
from pytorch3d.ops.graph_conv import gather_scatter, gather_scatter_python, GraphConv
|
|
from pytorch3d.structures.meshes import Meshes
|
|
from pytorch3d.utils import ico_sphere
|
|
|
|
|
|
class TestGraphConv(TestCaseMixin, unittest.TestCase):
    """Tests for GraphConv and gather_scatter from pytorch3d.ops.graph_conv."""

    def test_undirected(self):
        """
        With directed=False every edge contributes in both directions: each
        vertex gets w0 applied to itself plus w1 applied to each neighbour.
        """
        dtype = torch.float32
        device = get_random_cuda_device()
        verts = torch.tensor(
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype, device=device
        )
        edges = torch.tensor([[0, 1], [0, 2]], device=device)
        w0 = torch.tensor([[1, 1, 1]], dtype=dtype, device=device)
        w1 = torch.tensor([[-1, -1, -1]], dtype=dtype, device=device)

        expected_y = torch.tensor(
            [
                [1 + 2 + 3 - 4 - 5 - 6 - 7 - 8 - 9],
                [4 + 5 + 6 - 1 - 2 - 3],
                [7 + 8 + 9 - 1 - 2 - 3],
            ],
            dtype=dtype,
            device=device,
        )

        conv = GraphConv(3, 1, directed=False).to(device)
        conv.w0.weight.data.copy_(w0)
        conv.w0.bias.data.zero_()
        conv.w1.weight.data.copy_(w1)
        conv.w1.bias.data.zero_()

        y = conv(verts, edges)
        self.assertClose(y, expected_y)

    def test_no_edges(self):
        """With an empty edge list the output reduces to the w0 term alone."""
        dtype = torch.float32
        verts = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)
        edges = torch.zeros(0, 2, dtype=torch.int64)
        w0 = torch.tensor([[1, -1, -2]], dtype=dtype)
        expected_y = torch.tensor(
            [[1 - 2 - 2 * 3], [4 - 5 - 2 * 6], [7 - 8 - 2 * 9]], dtype=dtype
        )
        conv = GraphConv(3, 1).to(dtype)
        conv.w0.weight.data.copy_(w0)
        conv.w0.bias.data.zero_()

        y = conv(verts, edges)
        self.assertClose(y, expected_y)

    def test_no_verts_and_edges(self):
        """
        Empty vertex and edge inputs yield an empty (0, output_dim) result that
        is still part of the autograd graph (requires_grad is preserved).
        """
        dtype = torch.float32
        verts = torch.tensor([], dtype=dtype, requires_grad=True)
        edges = torch.tensor([], dtype=dtype)
        w0 = torch.tensor([[1, -1, -2]], dtype=dtype)

        conv = GraphConv(3, 1).to(dtype)
        conv.w0.weight.data.copy_(w0)
        conv.w0.bias.data.zero_()
        y = conv(verts, edges)
        self.assertClose(y, torch.zeros((0, 1)))
        self.assertTrue(y.requires_grad)

        conv2 = GraphConv(3, 2).to(dtype)
        conv2.w0.weight.data.copy_(w0.repeat(2, 1))
        conv2.w0.bias.data.zero_()
        y = conv2(verts, edges)
        self.assertClose(y, torch.zeros((0, 2)))
        self.assertTrue(y.requires_grad)

    def test_directed(self):
        """
        With directed=True an edge (i, j) only feeds x_j into vertex i, so
        vertices 1 and 2 (never an edge source here) see only their w0 term.
        """
        dtype = torch.float32
        verts = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype)
        edges = torch.tensor([[0, 1], [0, 2]])
        w0 = torch.tensor([[1, 1, 1]], dtype=dtype)
        w1 = torch.tensor([[-1, -1, -1]], dtype=dtype)

        expected_y = torch.tensor(
            [[1 + 2 + 3 - 4 - 5 - 6 - 7 - 8 - 9], [4 + 5 + 6], [7 + 8 + 9]], dtype=dtype
        )

        conv = GraphConv(3, 1, directed=True).to(dtype)
        conv.w0.weight.data.copy_(w0)
        conv.w0.bias.data.zero_()
        conv.w1.weight.data.copy_(w1)
        conv.w1.bias.data.zero_()

        y = conv(verts, edges)
        self.assertClose(y, expected_y)

    def test_backward(self):
        """
        Gradients through gather_scatter agree between the CUDA, CPU and
        pure-Python implementations for the same upstream gradient.
        """
        device = get_random_cuda_device()
        mesh = ico_sphere()
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()
        verts_cpu = verts.clone()
        edges_cpu = edges.clone()
        verts_cuda = verts.clone().to(device)
        edges_cuda = edges.clone().to(device)
        verts.requires_grad = True
        verts_cpu.requires_grad = True
        verts_cuda.requires_grad = True

        neighbor_sums_cuda = gather_scatter(verts_cuda, edges_cuda, False)
        neighbor_sums_cpu = gather_scatter(verts_cpu, edges_cpu, False)
        neighbor_sums = gather_scatter_python(verts, edges, False)
        # One shared random upstream gradient so the three backward passes
        # are directly comparable.
        randoms = torch.rand_like(neighbor_sums)
        (neighbor_sums_cuda * randoms.to(device)).sum().backward()
        (neighbor_sums_cpu * randoms).sum().backward()
        (neighbor_sums * randoms).sum().backward()

        self.assertClose(verts.grad, verts_cuda.grad.cpu())
        self.assertClose(verts.grad, verts_cpu.grad)

    def test_repr(self):
        """repr(GraphConv) reports input/output dims and the directed flag."""
        conv = GraphConv(32, 64, directed=True)
        self.assertEqual(repr(conv), "GraphConv(32 -> 64, directed=True)")

    def test_cpu_cuda_tensor_error(self):
        """Mixing CUDA vertex features with CPU edges raises an error."""
        device = get_random_cuda_device()
        verts = torch.tensor(
            [[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32, device=device
        )
        edges = torch.tensor([[0, 1], [0, 2]])
        conv = GraphConv(3, 1, directed=True).to(torch.float32)
        with self.assertRaises(Exception) as err:
            conv(verts, edges)
        self.assertIn("tensors must be on the same device.", str(err.exception))

    def test_gather_scatter(self):
        """
        Check that the compiled _C.gather_scatter gives the same results as
        gather_scatter_python on both CUDA and CPU tensors, for undirected
        and directed edges.
        """
        device = get_random_cuda_device()
        mesh = ico_sphere()
        verts = mesh.verts_packed()
        edges = mesh.edges_packed()
        w0 = nn.Linear(3, 1)
        # Named `features` rather than `input` to avoid shadowing the builtin.
        features = w0(verts)

        # undirected
        output_python = gather_scatter_python(features, edges, False)
        output_cuda = _C.gather_scatter(
            features.to(device=device), edges.to(device=device), False, False
        )
        self.assertClose(output_cuda.cpu(), output_python)

        output_cpu = _C.gather_scatter(features.cpu(), edges.cpu(), False, False)
        self.assertClose(output_cpu, output_python)

        # directed
        output_python = gather_scatter_python(features, edges, True)
        output_cuda = _C.gather_scatter(
            features.to(device=device), edges.to(device=device), True, False
        )
        self.assertClose(output_cuda.cpu(), output_python)
        output_cpu = _C.gather_scatter(features.cpu(), edges.cpu(), True, False)
        self.assertClose(output_cpu, output_python)

    @staticmethod
    def graph_conv_forward_backward(
        gconv_dim,
        num_meshes,
        num_verts,
        num_faces,
        directed: bool,
        backend: str = "cuda",
    ):
        """
        Build a benchmark closure that runs one forward and backward pass of
        GraphConv over a batch of identical synthetic meshes.

        Args:
            gconv_dim: Input and output feature dimension of the GraphConv.
            num_meshes: Number of (identical) meshes in the batch.
            num_verts: Number of vertices per mesh.
            num_faces: Number of faces per mesh; every face is [1, 2, 3]
                (NOTE(review): presumably requires num_verts >= 4 — confirm).
            directed: Passed through to GraphConv.
            backend: "cuda" to run on the GPU; any other value runs on the CPU.

        Returns:
            A zero-argument callable performing the forward/backward pass.
        """
        device = torch.device("cuda" if backend == "cuda" else "cpu")
        # Only CUDA work is asynchronous; synchronizing on the CPU backend is
        # unnecessary and fails outright on a torch build without CUDA.
        use_cuda = device.type == "cuda"

        verts_list = torch.tensor(num_verts * [[0.11, 0.22, 0.33]], device=device).view(
            -1, 3
        )
        faces_list = torch.tensor(num_faces * [[1, 2, 3]], device=device).view(-1, 3)
        meshes = Meshes(num_meshes * [verts_list], num_meshes * [faces_list])
        gconv = GraphConv(gconv_dim, gconv_dim, directed=directed)
        gconv.to(device)
        edges = meshes.edges_packed()
        total_verts = meshes.verts_packed().shape[0]

        # Features.
        x = torch.randn(total_verts, gconv_dim, device=device, requires_grad=True)
        if use_cuda:
            torch.cuda.synchronize()

        def run_graph_conv():
            y1 = gconv(x, edges)
            y1.sum().backward()
            # Wait for queued kernels so a timer measures the whole pass.
            if use_cuda:
                torch.cuda.synchronize()

        return run_graph_conv
|