lint fixes

Summary:
Ran the linter.
TODO: need to update the linter as per D21353065.

Reviewed By: bottler

Differential Revision: D21362270

fbshipit-source-id: ad0e781de0a29f565ad25c43bc94a19b1828c020
Nikhila Ravi 2020-05-04 09:55:03 -07:00 committed by Facebook GitHub Bot
parent 0c595dcf5b
commit 0eca74fa5f
15 changed files with 73 additions and 57 deletions


@@ -80,7 +80,7 @@ def make_mesh_texture_atlas(
         faces_material_ind = torch.from_numpy(face_material_names == material_name).to(
             faces_verts_uvs.device
         )
-        if (faces_material_ind).sum() > 0:
+        if faces_material_ind.sum() > 0:
             # For these faces, update the base color to the
             # diffuse material color.
             if "diffuse_color" not in props:


@@ -117,8 +117,8 @@ def chamfer_distance(
     P2 = y.shape[1]

     # Check if inputs are heterogeneous and create a lengths mask.
-    is_x_heterogeneous = ~(x_lengths == P1).all()
-    is_y_heterogeneous = ~(y_lengths == P2).all()
+    is_x_heterogeneous = (x_lengths != P1).any()
+    is_y_heterogeneous = (y_lengths != P2).any()
     x_mask = (
         torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
     )  # shape [N, P1]
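
Note: the new predicate is the De Morgan dual of the old one; "not all lengths equal P" is exactly "some length differs from P". A standalone sketch with invented lengths:

    import torch

    x_lengths, P1 = torch.tensor([3, 5, 5]), 5
    # ~(all equal) and (any unequal) agree; the new form states directly what it checks.
    assert (~(x_lengths == P1).all()) == ((x_lengths != P1).any())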


@@ -259,7 +259,7 @@ def rasterize_meshes_python(
     N = len(meshes)
     # Assume only square images.
     # TODO(T52813608) extend support for non-square images.
-    H, W, = image_size, image_size
+    H, W = image_size, image_size
     K = faces_per_pixel
     device = meshes.device

@@ -479,7 +479,7 @@ def point_line_distance(p, v0, v1):
     if l2 <= kEpsilon:
         return (p - v1).dot(p - v1)  # v0 == v1

-    t = (v1v0).dot(p - v0) / l2
+    t = v1v0.dot(p - v0) / l2
     t = torch.clamp(t, min=0.0, max=1.0)
     p_proj = v0 + t * v1v0
     delta_p = p_proj - p
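
Note: both fixes remove tokens with no effect. The trailing comma in "H, W, = ..." is legal in an assignment target list but flagged as misleading, and the parentheses in (v1v0).dot(...) are redundant as above. A standalone illustration of the trailing-comma point:

    image_size = 64
    # The trailing comma is allowed in a target list and changes nothing here,
    # but it visually suggests single-element unpacking, hence the lint warning.
    H, W, = image_size, image_size
    assert (H, W) == (64, 64)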


@@ -308,7 +308,7 @@ def convert_to_tensors_and_broadcast(*args, dtype=torch.float32, device: str = "
     args_Nd = []
     for c in args_1d:
         if c.shape[0] != 1 and c.shape[0] != N:
-            msg = "Got non-broadcastable sizes %r" % (sizes)
+            msg = "Got non-broadcastable sizes %r" % sizes
             raise ValueError(msg)

         # Expand broadcast dim and keep non broadcast dims the same size
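
Note: (sizes) is not a tuple, only (sizes,) would be, so the parentheses after % are dead weight. A standalone sketch with an invented value:

    sizes = [2, 3]
    # A parenthesized name is not a tuple; both spellings format the same value.
    assert ("sizes %r" % (sizes)) == ("sizes %r" % sizes) == "sizes [2, 3]"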


@@ -926,8 +926,8 @@ class Meshes(object):
             self._num_verts_per_mesh = torch.zeros(
                 (0,), dtype=torch.int64, device=self.device
             )
-            self._faces_packed = -torch.ones(
-                (0, 3), dtype=torch.int64, device=self.device
+            self._faces_packed = -(
+                torch.ones((0, 3), dtype=torch.int64, device=self.device)
             )
             self._faces_packed_to_mesh_idx = torch.zeros(
                 (0,), dtype=torch.int64, device=self.device

@@ -977,8 +977,8 @@ class Meshes(object):
             return

         if self.isempty():
-            self._edges_packed = -torch.ones(
-                (0, 2), dtype=torch.int64, device=self.device
+            self._edges_packed = torch.full(
+                (0, 2), fill_value=-1, dtype=torch.int64, device=self.device
             )
             self._edges_packed_to_mesh_idx = torch.zeros(
                 (0,), dtype=torch.int64, device=self.device
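
Note: torch.full states the fill value directly instead of negating a tensor of ones; for these placeholder tensors the two spellings are interchangeable. A standalone sketch:

    import torch

    a = -torch.ones((0, 2), dtype=torch.int64)
    b = torch.full((0, 2), fill_value=-1, dtype=torch.int64)
    # Identical shape, dtype, and (empty) contents, without the unary minus.
    assert a.shape == b.shape and a.dtype == b.dtype and torch.equal(a, b)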


@@ -261,7 +261,9 @@ class TestBlending(unittest.TestCase):
         # of the image with surrounding padded values.
         N, S, K = 1, 8, 2
         device = torch.device("cuda")
-        pix_to_face = -torch.ones((N, S, S, K), dtype=torch.int64, device=device)
+        pix_to_face = torch.full(
+            (N, S, S, K), fill_value=-1, dtype=torch.int64, device=device
+        )
         h = int(S / 2)
         pix_to_face_full = torch.randint(
             size=(N, h, h, K), low=0, high=100, device=device


@@ -203,7 +203,7 @@ class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
             + torch.cos(elev) * torch.cos(azim)
         )
         grad_elev = (
-            -torch.sin(elev) * torch.sin(azim)
+            -(torch.sin(elev)) * torch.sin(azim)
             + torch.cos(elev)
             - torch.sin(elev) * torch.cos(azim)
         )

@@ -260,7 +260,7 @@ class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
             + torch.cos(elev) * torch.cos(azim)
         )
         grad_elev = (
-            -torch.sin(elev) * torch.sin(azim)
+            -(torch.sin(elev)) * torch.sin(azim)
             + torch.cos(elev)
             - torch.sin(elev) * torch.cos(azim)
         )

@@ -395,8 +395,8 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
                 cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
             else:
                 cam_params["top"] = torch.rand(batch_size) * 0.2 + 0.9
-                cam_params["bottom"] = -torch.rand(batch_size) * 0.2 - 0.9
-                cam_params["left"] = -torch.rand(batch_size) * 0.2 - 0.9
+                cam_params["bottom"] = -(torch.rand(batch_size)) * 0.2 - 0.9
+                cam_params["left"] = -(torch.rand(batch_size)) * 0.2 - 0.9
                 cam_params["right"] = torch.rand(batch_size) * 0.2 + 0.9
         elif cam_type in (SfMOrthographicCameras, SfMPerspectiveCameras):
             cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1

@@ -532,7 +532,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
         P = cameras.get_projection_transform()
         vertices = torch.tensor([1, 2, 10], dtype=torch.float32)
         z1 = 1.0  # vertices at far clipping plane so z = 1.0
-        z2 = (20.0 / (20.0 - 1.0) * 10.0 + -(20.0) / (20.0 - 1.0)) / 10.0
+        z2 = (20.0 / (20.0 - 1.0) * 10.0 + -20.0 / (20.0 - 1.0)) / 10.0
         projected_verts = torch.tensor(
             [
                 [np.sqrt(3) / 10.0, 2 * np.sqrt(3) / 10.0, z1],

@@ -660,7 +660,7 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
         cameras = OpenGLOrthographicCameras(znear=near, zfar=far)
         P = cameras.get_projection_transform()
         vertices = torch.tensor([1.0, 2.0, 10.0], dtype=torch.float32)
-        z2 = 1.0 / (20.0 - 1.0) * 10.0 + -(1.0) / (20.0 - 1.0)
+        z2 = 1.0 / (20.0 - 1.0) * 10.0 + -1.0 / (20.0 - 1.0)
         projected_verts = torch.tensor(
             [[1.0, 2.0, 1.0], [1.0, 2.0, z2]], dtype=torch.float32
         )
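
Note: the parentheses shuffled in this file are purely cosmetic. Unary minus binds tighter than * in Python, so -torch.sin(elev) * torch.sin(azim) already negates only the first factor, and -(20.0) is the same expression as -20.0. A standalone check:

    import torch

    elev, azim = torch.tensor(0.3), torch.tensor(0.7)
    # Unary minus applies to torch.sin(elev) before the multiplication either way.
    assert torch.equal(
        -torch.sin(elev) * torch.sin(azim), -(torch.sin(elev)) * torch.sin(azim)
    )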


@@ -35,6 +35,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         low = 0 if allow_empty else 1
         p1_lengths = torch.randint(low, P1, size=(N,), dtype=torch.int64, device=device)
         p2_lengths = torch.randint(low, P2, size=(N,), dtype=torch.int64, device=device)
+        P1 = p1_lengths.max().item()
+        P2 = p2_lengths.max().item()
         weights = torch.rand((N,), dtype=torch.float32, device=device)

         # list of points and normals tensors

@@ -109,9 +111,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
             torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
         )  # shape [N, P2]
-        is_x_heterogeneous = ~(x_lengths == P1).all()
-        is_y_heterogeneous = ~(y_lengths == P2).all()
+        is_x_heterogeneous = (x_lengths != P1).any()
+        is_y_heterogeneous = (y_lengths != P2).any()
         # Only calculate the distances for the points which are not masked
         for n in range(N):
             for i1 in range(x_lengths[n]):


@@ -180,7 +180,7 @@ class TestMeshNormalConsistency(unittest.TestCase):
         # mesh1: normal consistency computation
         n0 = (verts1[1] - verts1[2]).cross(verts1[3] - verts1[2])
         n1 = (verts1[1] - verts1[2]).cross(verts1[0] - verts1[2])
-        loss1 = 1.0 - torch.cosine_similarity(n0.view(1, 3), -n1.view(1, 3))
+        loss1 = 1.0 - torch.cosine_similarity(n0.view(1, 3), -(n1.view(1, 3)))

         # mesh2: normal consistency computation
         # In the cube mesh, 6 edges are shared with coplanar faces (loss=0),

@@ -193,9 +193,9 @@ class TestMeshNormalConsistency(unittest.TestCase):
         n2 = (verts3[1] - verts3[2]).cross(verts3[4] - verts3[2])
         loss3 = (
             3.0
-            - torch.cosine_similarity(n0.view(1, 3), -n1.view(1, 3))
-            - torch.cosine_similarity(n0.view(1, 3), -n2.view(1, 3))
-            - torch.cosine_similarity(n1.view(1, 3), -n2.view(1, 3))
+            - torch.cosine_similarity(n0.view(1, 3), -(n1.view(1, 3)))
+            - torch.cosine_similarity(n0.view(1, 3), -(n2.view(1, 3)))
+            - torch.cosine_similarity(n1.view(1, 3), -(n2.view(1, 3)))
         )
         loss3 /= 3.0


@@ -52,11 +52,11 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
         )
         self.assertTrue(torch.all(verts == expected_verts))
         self.assertTrue(torch.all(faces.verts_idx == expected_faces))
-        padded_vals = -torch.ones_like(faces.verts_idx)
+        padded_vals = -(torch.ones_like(faces.verts_idx))
         self.assertTrue(torch.all(faces.normals_idx == padded_vals))
         self.assertTrue(torch.all(faces.textures_idx == padded_vals))
         self.assertTrue(
-            torch.all(faces.materials_idx == -torch.ones(len(expected_faces)))
+            torch.all(faces.materials_idx == -(torch.ones(len(expected_faces))))
         )
         self.assertTrue(normals is None)
         self.assertTrue(textures is None)

@@ -124,10 +124,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
             [[0.749279, 0.501284], [0.999110, 0.501077], [0.999455, 0.750380]],
             dtype=torch.float32,
         )
-        expected_faces_normals_idx = -torch.ones_like(expected_faces, dtype=torch.int64)
+        expected_faces_normals_idx = -(
+            torch.ones_like(expected_faces, dtype=torch.int64)
+        )
         expected_faces_normals_idx[4, :] = torch.tensor([1, 1, 1], dtype=torch.int64)
-        expected_faces_textures_idx = -torch.ones_like(
-            expected_faces, dtype=torch.int64
+        expected_faces_textures_idx = -(
+            torch.ones_like(expected_faces, dtype=torch.int64)
         )
         expected_faces_textures_idx[4, :] = torch.tensor([0, 0, 1], dtype=torch.int64)

@@ -207,7 +209,7 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
         self.assertClose(expected_textures, textures)
         self.assertClose(expected_verts, verts)
         self.assertTrue(
-            torch.all(faces.normals_idx == -torch.ones_like(faces.textures_idx))
+            torch.all(faces.normals_idx == -(torch.ones_like(faces.textures_idx)))
         )
         self.assertTrue(normals is None)
         self.assertTrue(materials is None)


@@ -431,7 +431,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
         # Naive implementation: forward & backward
         edges_packed = meshes.edges_packed()
         edges_list = packed_to_list(edges_packed, meshes.num_edges_per_mesh().tolist())
-        loss_naive = torch.zeros((N), dtype=torch.float32, device=device)
+        loss_naive = torch.zeros(N, dtype=torch.float32, device=device)
         for i in range(N):
             points = pcls.points_list()[i]
             verts = meshes.verts_list()[i]

@@ -461,7 +461,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
         self.assertClose(loss_op, loss_naive)

         # Compare backward pass
-        rand_val = torch.rand((1)).item()
+        rand_val = torch.rand(1).item()
         grad_dist = torch.tensor(rand_val, dtype=torch.float32, device=device)
         loss_naive.backward(grad_dist)

@@ -707,7 +707,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
         pcls_op = Pointclouds(points_op)

         # naive implementation
-        loss_naive = torch.zeros((N), dtype=torch.float32, device=device)
+        loss_naive = torch.zeros(N, dtype=torch.float32, device=device)
         for i in range(N):
             points = pcls.points_list()[i]
             verts = meshes.verts_list()[i]

@@ -735,7 +735,7 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
         self.assertClose(loss_op, loss_naive)

         # Compare backward pass
-        rand_val = torch.rand((1)).item()
+        rand_val = torch.rand(1).item()
         grad_dist = torch.tensor(rand_val, dtype=torch.float32, device=device)
         loss_naive.backward(grad_dist)
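
Note: (N) and (1) are parenthesized integers, not tuples, so torch.zeros((N)) was already receiving a plain int; the lint drops the noise. A standalone sketch:

    import torch

    N = 4
    # (N) is just N; only (N,) is a tuple, and torch.zeros accepts either form.
    assert torch.zeros((N)).shape == torch.zeros(N).shape == torch.zeros((N,)).shape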


@@ -112,11 +112,13 @@ class TestICP(TestCaseMixin, unittest.TestCase):
         ]

         # run full icp
-        converged, _, Xt, (
-            R,
-            T,
-            s,
-        ), t_hist = points_alignment.iterative_closest_point(
+        (
+            converged,
+            _,
+            Xt,
+            (R, T, s),
+            t_hist,
+        ) = points_alignment.iterative_closest_point(
             X,
             Y,
             estimate_scale=False,

@@ -130,11 +132,13 @@ class TestICP(TestCaseMixin, unittest.TestCase):
         t_init = t_hist[min(2, len(t_hist) - 1)]

         # rerun the ICP
-        converged_init, _, Xt_init, (
-            R_init,
-            T_init,
-            s_init,
-        ), t_hist_init = points_alignment.iterative_closest_point(
+        (
+            converged_init,
+            _,
+            Xt_init,
+            (R_init, T_init, s_init),
+            t_hist_init,
+        ) = points_alignment.iterative_closest_point(
             X,
             Y,
             init_transform=t_init,

@@ -182,11 +186,13 @@ class TestICP(TestCaseMixin, unittest.TestCase):
         n_points_Y = Y_pcl.num_points_per_cloud()

         # run icp with Pointlouds inputs
-        _, _, Xt_pcl, (
-            R_pcl,
-            T_pcl,
-            s_pcl,
-        ), _ = points_alignment.iterative_closest_point(
+        (
+            _,
+            _,
+            Xt_pcl,
+            (R_pcl, T_pcl, s_pcl),
+            _,
+        ) = points_alignment.iterative_closest_point(
             X_pcl,
             Y_pcl,
             estimate_scale=estimate_scale,

@@ -263,11 +269,13 @@ class TestICP(TestCaseMixin, unittest.TestCase):
         ]

         # run the icp algorithm
-        converged, _, _, (
-            R_ours,
-            T_ours,
-            s_ours,
-        ), _ = points_alignment.iterative_closest_point(
+        (
+            converged,
+            _,
+            _,
+            (R_ours, T_ours, s_ours),
+            _,
+        ) = points_alignment.iterative_closest_point(
             X,
             Y,
             estimate_scale=estimate_scale,
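
Note: these rewrites wrap the whole assignment target in parentheses so the formatter can split it one name per line; the nested (R, T, s) target still unpacks the 3-tuple element in place. A standalone sketch, where fake_icp is an invented stand-in for the return structure, not a library function:

    def fake_icp():
        # Mimics only the shape of the real return value.
        return True, None, "Xt", ("R", "T", "s"), ["t0", "t1"]

    # The parenthesized target list spans lines cleanly; (R, T, s) unpacks the
    # nested tuple exactly as the single-line form would.
    (
        converged,
        _,
        Xt,
        (R, T, s),
        t_hist,
    ) = fake_icp()
    assert converged and (R, T, s) == ("R", "T", "s")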


@@ -74,7 +74,10 @@ class TestPCLNormals(TestCaseMixin, unittest.TestCase):
         # check for both disambiguation options
         for disambiguate_directions in (True, False):
-            curvatures, local_coord_frames = estimate_pointcloud_local_coord_frames(
+            (
+                curvatures,
+                local_coord_frames,
+            ) = estimate_pointcloud_local_coord_frames(
                 pcl,
                 neighborhood_size=neighborhood_size,
                 disambiguate_directions=disambiguate_directions,


@@ -448,7 +448,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
         ], dtype=torch.int64, device=device)
         # fmt: on
-        pix_to_face_padded = -torch.ones_like(pix_to_face_frontface)
+        pix_to_face_padded = -(torch.ones_like(pix_to_face_frontface))
         # Run with and without culling
         # Without culling, for k=0, the front face (i.e. face 2) is
         # rasterized and for k=1, the back face (i.e. face 3) is


@@ -95,7 +95,7 @@ class TestSamplePoints(TestCaseMixin, unittest.TestCase):
         x, y, z = samples[1, :].unbind(1)
         radius = torch.sqrt(x ** 2 + y ** 2 + z ** 2)
-        self.assertClose(radius, torch.ones((num_samples)))
+        self.assertClose(radius, torch.ones(num_samples))

         # Pyramid: points shoudl lie on one of the faces.
         pyramid_verts = samples[2, :]