diff --git a/pytorch3d/renderer/mesh/shading.py b/pytorch3d/renderer/mesh/shading.py index 2a142f28..05cb66ad 100644 --- a/pytorch3d/renderer/mesh/shading.py +++ b/pytorch3d/renderer/mesh/shading.py @@ -55,6 +55,46 @@ def _apply_lighting( return ambient_color, diffuse_color, specular_color +def _phong_shading_with_pixels( + meshes, fragments, lights, cameras, materials, texels +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply per pixel shading. First interpolate the vertex normals and + vertex coordinates using the barycentric coordinates to get the position + and normal at each pixel. Then compute the illumination for each pixel. + The pixel color is obtained by multiplying the pixel textures by the ambient + and diffuse illumination and adding the specular component. + + Args: + meshes: Batch of meshes + fragments: Fragments named tuple with the outputs of rasterization + lights: Lights class containing a batch of lights + cameras: Cameras class containing a batch of cameras + materials: Materials class containing a batch of material properties + texels: texture per pixel of shape (N, H, W, K, 3) + + Returns: + colors: (N, H, W, K, 3) + pixel_coords: (N, H, W, K, 3), camera coordinates of each intersection. + """ + verts = meshes.verts_packed() # (V, 3) + faces = meshes.faces_packed() # (F, 3) + vertex_normals = meshes.verts_normals_packed() # (V, 3) + faces_verts = verts[faces] + faces_normals = vertex_normals[faces] + pixel_coords_in_camera = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_verts + ) + pixel_normals = interpolate_face_attributes( + fragments.pix_to_face, fragments.bary_coords, faces_normals + ) + ambient, diffuse, specular = _apply_lighting( + pixel_coords_in_camera, pixel_normals, lights, cameras, materials + ) + colors = (ambient + diffuse) * texels + specular + return colors, pixel_coords_in_camera + + def phong_shading( meshes, fragments, lights, cameras, materials, texels ) -> torch.Tensor: @@ -76,21 +116,9 @@ def phong_shading( Returns: colors: (N, H, W, K, 3) """ - verts = meshes.verts_packed() # (V, 3) - faces = meshes.faces_packed() # (F, 3) - vertex_normals = meshes.verts_normals_packed() # (V, 3) - faces_verts = verts[faces] - faces_normals = vertex_normals[faces] - pixel_coords = interpolate_face_attributes( - fragments.pix_to_face, fragments.bary_coords, faces_verts + colors, _ = _phong_shading_with_pixels( + meshes, fragments, lights, cameras, materials, texels ) - pixel_normals = interpolate_face_attributes( - fragments.pix_to_face, fragments.bary_coords, faces_normals - ) - ambient, diffuse, specular = _apply_lighting( - pixel_coords, pixel_normals, lights, cameras, materials - ) - colors = (ambient + diffuse) * texels + specular return colors