diff --git a/pytorch3d/renderer/mesh/shading.py b/pytorch3d/renderer/mesh/shading.py index 1fac0b9e..a3b20fa5 100644 --- a/pytorch3d/renderer/mesh/shading.py +++ b/pytorch3d/renderer/mesh/shading.py @@ -153,8 +153,21 @@ def flat_shading(meshes, fragments, lights, cameras, materials, texels) -> torch face_normals = meshes.faces_normals_packed() # (V, 3) faces_verts = verts[faces] face_coords = faces_verts.mean(dim=-2) # (F, 3, XYZ) mean xyz across verts - pixel_coords = face_coords[fragments.pix_to_face] - pixel_normals = face_normals[fragments.pix_to_face] + + # Replace empty pixels in pix_to_face with 0 in order to interpolate. + mask = fragments.pix_to_face == -1 + pix_to_face = fragments.pix_to_face.clone() + pix_to_face[mask] = 0 + + N, H, W, K = pix_to_face.shape + idx = pix_to_face.view(N * H * W * K, 1).expand(N * H * W * K, 3) + + # gather pixel coords + pixel_coords = face_coords.gather(0, idx).view(N, H, W, K, 3) + pixel_coords[mask] = 0.0 + # gather pixel normals + pixel_normals = face_normals.gather(0, idx).view(N, H, W, K, 3) + pixel_normals[mask] = 0.0 # Calculate the illumination at each face ambient, diffuse, specular = _apply_lighting(