Fix inferred typing

Summary: D35513897 (4b94649f7b) was a pyre infer job which got some things wrong. Correct by adding the correct types, so these things shouldn't need worrying about again.

Reviewed By: patricklabatut

Differential Revision: D35546144

fbshipit-source-id: 89f6ea2b67be27aa0b0b14afff4347cccf23feb7
This commit is contained in:
Jeremy Reizenstein
2022-04-13 04:40:56 -07:00
committed by Facebook GitHub Bot
parent 78fd5af1a6
commit df08ea8eb4
4 changed files with 41 additions and 34 deletions

View File

@@ -167,7 +167,7 @@ def estimate_pointcloud_local_coord_frames(
return curvatures, local_coord_frames
def _disambiguate_vector_directions(pcl, knns, vecs: float) -> float:
def _disambiguate_vector_directions(pcl, knns, vecs: torch.Tensor) -> torch.Tensor:
"""
Disambiguates normal directions according to [1].
@@ -181,7 +181,6 @@ def _disambiguate_vector_directions(pcl, knns, vecs: float) -> float:
# each element of the neighborhood
df = knns - pcl[:, :, None]
# projection of the difference on the principal direction
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
proj = (vecs[:, :, None] * df).sum(3)
# check how many projections are positive
n_pos = (proj > 0).type_as(knns).sum(2, keepdim=True)

View File

@@ -261,7 +261,7 @@ class SubdivideMeshes(nn.Module):
# Calculate the indices needed to group the new and existing verts
# for each mesh.
verts_sort_idx = create_verts_index(
verts_sort_idx = _create_verts_index(
num_verts_per_mesh, num_edges_per_mesh, meshes.device
) # (sum(V_n)+sum(E_n),)
@@ -282,7 +282,9 @@ class SubdivideMeshes(nn.Module):
# Calculate the indices needed to group the existing and new faces
# for each mesh.
face_sort_idx = create_faces_index(num_faces_per_mesh, device=meshes.device)
face_sort_idx = _create_faces_index(
num_faces_per_mesh, device=meshes.device
)
# Reorder the faces to sequentially group existing and new faces
# for each mesh.
@@ -329,7 +331,7 @@ class SubdivideMeshes(nn.Module):
return new_meshes, new_feats
def create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
"""
Helper function to group the vertex indices for each mesh. New vertices are
stacked at the end of the original verts tensor, so in order to have
@@ -400,7 +402,7 @@ def create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
return verts_idx
def create_faces_index(faces_per_mesh: int, device=None):
def _create_faces_index(faces_per_mesh: torch.Tensor, device=None):
"""
Helper function to group the faces indices for each mesh. New faces are
stacked at the end of the original faces tensor, so in order to have
@@ -417,9 +419,7 @@ def create_faces_index(faces_per_mesh: int, device=None):
"""
# e.g. faces_per_mesh = [2, 5, 3]
# pyre-fixme[16]: `int` has no attribute `sum`.
F = faces_per_mesh.sum() # e.g. 10
# pyre-fixme[16]: `int` has no attribute `cumsum`.
faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0) # (N,) e.g. (2, 7, 10)
switch1_idx = faces_per_mesh_cumsum.clone()