Address black + isort fbsource linter warnings

Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff)

Reviewed By: nikhilaravi

Differential Revision: D20558373

fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
Author:    Patrick Labatut
Date:      2020-03-29 14:46:33 -07:00
Committer: Facebook GitHub Bot
Parent:    eb512ffde3
Commit:    d57daa6f85

110 changed files with 705 additions and 1850 deletions
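
For reference, changes like the ones below are generated by the formatters themselves rather than by hand. The snippet that follows is a minimal sketch of reproducing two of the edits programmatically; it assumes a black release exposing black.format_str / black.Mode and isort >= 5 with isort.code, and is not itself part of this diff.

    # Sketch only: assumes black.format_str/black.Mode and isort.code are available.
    import black
    import isort

    # Pre-diff form of the first chamfer change below: a signature wrapped to fit
    # a 79-character limit. The "pass" body is a stand-in for the real one.
    wrapped = (
        "def _validate_chamfer_reduction_inputs(\n"
        "    batch_reduction: str, point_reduction: str\n"
        "):\n"
        "    pass\n"
    )
    print(black.format_str(wrapped, mode=black.Mode()))
    # def _validate_chamfer_reduction_inputs(batch_reduction: str, point_reduction: str):
    #     pass

    # isort groups standard-library imports before third-party ones, separated by
    # a blank line, as in the import hunk of the normal-consistency file below.
    print(isort.code("import torch\nfrom itertools import islice\n"))
    # from itertools import islice
    #
    # import torch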

View File

@@ -6,4 +6,5 @@ from .mesh_edge_loss import mesh_edge_loss
 from .mesh_laplacian_smoothing import mesh_laplacian_smoothing
 from .mesh_normal_consistency import mesh_normal_consistency
+
 __all__ = [k for k in globals().keys() if not k.startswith("_")]
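
The __all__ line above uses the usual re-export idiom: publish every name already bound in the module that does not start with an underscore. A tiny self-contained illustration (the imported names here are made up, not the PyTorch3D ones):

    from math import pi as public_constant  # public: ends up in __all__
    from math import tau as _private_constant  # underscore-prefixed: excluded

    # Collect every non-underscore name currently in the module namespace.
    __all__ = [k for k in globals().keys() if not k.startswith("_")]
    print(__all__)  # ['public_constant']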

View File

@@ -2,13 +2,10 @@
 import torch
 import torch.nn.functional as F
 from pytorch3d.ops.nearest_neighbor_points import nn_points_idx
 
 
-def _validate_chamfer_reduction_inputs(
-    batch_reduction: str, point_reduction: str
-):
+def _validate_chamfer_reduction_inputs(batch_reduction: str, point_reduction: str):
     """Check the requested reductions are valid.
     Args:
@@ -18,17 +15,11 @@ def _validate_chamfer_reduction_inputs(
             points, can be one of ["none", "mean", "sum"].
     """
     if batch_reduction not in ["none", "mean", "sum"]:
-        raise ValueError(
-            'batch_reduction must be one of ["none", "mean", "sum"]'
-        )
+        raise ValueError('batch_reduction must be one of ["none", "mean", "sum"]')
     if point_reduction not in ["none", "mean", "sum"]:
-        raise ValueError(
-            'point_reduction must be one of ["none", "mean", "sum"]'
-        )
+        raise ValueError('point_reduction must be one of ["none", "mean", "sum"]')
     if batch_reduction == "none" and point_reduction == "none":
-        raise ValueError(
-            'batch_reduction and point_reduction cannot both be "none".'
-        )
+        raise ValueError('batch_reduction and point_reduction cannot both be "none".')
 
 
 def chamfer_distance(
@@ -87,10 +78,7 @@ def chamfer_distance(
                     (x.sum((1, 2)) * weights).sum() * 0.0,
                     (x.sum((1, 2)) * weights).sum() * 0.0,
                 )
-            return (
-                (x.sum((1, 2)) * weights) * 0.0,
-                (x.sum((1, 2)) * weights) * 0.0,
-            )
+            return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)
 
     return_normals = x_normals is not None and y_normals is not None
     cham_norm_x = x.new_zeros(())
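
Taken together, the first two hunks above leave the validation helper in the shape sketched below. This is a reconstruction from the diff context, with the docstring argument text paraphrased, not a verbatim copy of the file:

    def _validate_chamfer_reduction_inputs(batch_reduction: str, point_reduction: str):
        """Check the requested reductions are valid.

        Args:
            batch_reduction: reduction across the batch, one of ["none", "mean", "sum"].
            point_reduction: reduction across the points, can be one of
                ["none", "mean", "sum"].
        """
        if batch_reduction not in ["none", "mean", "sum"]:
            raise ValueError('batch_reduction must be one of ["none", "mean", "sum"]')
        if point_reduction not in ["none", "mean", "sum"]:
            raise ValueError('point_reduction must be one of ["none", "mean", "sum"]')
        if batch_reduction == "none" and point_reduction == "none":
            raise ValueError('batch_reduction and point_reduction cannot both be "none".')


    # Example: both reductions set to "none" is rejected.
    try:
        _validate_chamfer_reduction_inputs("none", "none")
    except ValueError as err:
        print(err)  # batch_reduction and point_reduction cannot both be "none".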

View File

@@ -2,6 +2,7 @@
 from itertools import islice
+
 import torch
@@ -76,10 +77,7 @@ def mesh_normal_consistency(meshes):
     with torch.no_grad():
         edge_idx = face_to_edge.reshape(F * 3)  # (3 * F,) indexes into edges
         vert_idx = (
-            faces_packed.view(1, F, 3)
-            .expand(3, F, 3)
-            .transpose(0, 1)
-            .reshape(3 * F, 3)
+            faces_packed.view(1, F, 3).expand(3, F, 3).transpose(0, 1).reshape(3 * F, 3)
         )
         edge_idx, edge_sort_idx = edge_idx.sort()
         vert_idx = vert_idx[edge_sort_idx]
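
The reflowed vert_idx expression above is behaviour-preserving: it still repeats every packed face three times so each copy can be paired with one of the face's three edges after the sort. A tiny sketch with a made-up faces_packed tensor:

    import torch

    # Two hypothetical faces, each a triple of vertex indices (F = 2).
    faces_packed = torch.tensor([[0, 1, 2], [2, 1, 3]])
    F = faces_packed.shape[0]

    # Same chained expression as in the diff: (F, 3) -> (3 * F, 3), with each
    # face row repeated three consecutive times.
    vert_idx = (
        faces_packed.view(1, F, 3).expand(3, F, 3).transpose(0, 1).reshape(3 * F, 3)
    )
    print(vert_idx.shape)  # torch.Size([6, 3])
    print(vert_idx[:3])    # the first face, repeated three times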
@@ -132,9 +130,7 @@ def mesh_normal_consistency(meshes):
     loss = 1 - torch.cosine_similarity(n0, n1, dim=1)
     verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[vert_idx[:, 0]]
-    verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[
-        vert_edge_pair_idx[:, 0]
-    ]
+    verts_packed_to_mesh_idx = verts_packed_to_mesh_idx[vert_edge_pair_idx[:, 0]]
     num_normals = verts_packed_to_mesh_idx.bincount(minlength=N)
     weights = 1.0 / num_normals[verts_packed_to_mesh_idx].float()
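
The last two lines above are an inverse-count weighting: each normal-consistency term is scaled by one over the number of terms its mesh contributes, which normalizes each mesh's contribution to the batch. A small sketch with made-up indices:

    import torch

    # Hypothetical mapping of loss terms to meshes: three terms from mesh 0,
    # one term from mesh 1, with N = 2 meshes in the batch.
    verts_packed_to_mesh_idx = torch.tensor([0, 0, 0, 1])
    N = 2

    num_normals = verts_packed_to_mesh_idx.bincount(minlength=N)  # tensor([3, 1])
    weights = 1.0 / num_normals[verts_packed_to_mesh_idx].float()
    print(weights)  # tensor([0.3333, 0.3333, 0.3333, 1.0000]); sums to 1.0 per mesh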