@dataclass(frozen=True)
class Fragments:
    """
    Immutable container for the per-pixel outputs of mesh rasterization.

    In the descriptions below, N is the batch size, K is ``faces_per_pixel``
    and ``face_verts`` is the packed tensor of per-face vertex positions that
    was rasterized.

    Members:
        pix_to_face:
            LongTensor of shape (N, image_size, image_size, K) giving the
            indices of the nearest faces at each pixel, sorted in ascending
            z-order. Concretely ``pix_to_face[n, y, x, k] = f`` means that
            ``face_verts[f]`` is the kth closest face (in the z-direction) to
            pixel (y, x). Pixels that are hit by fewer than K faces are padded
            with -1.

        zbuf:
            FloatTensor of shape (N, image_size, image_size, K) giving the NDC
            z-coordinates of the nearest faces at each pixel, sorted in
            ascending z-order. Concretely, if ``pix_to_face[n, y, x, k] = f``
            then ``zbuf[n, y, x, k]`` is the NDC z-coordinate of face
            ``face_verts[f]`` at pixel (y, x). Pixels hit by fewer than K faces
            are padded with -1.

        bary_coords:
            FloatTensor of shape (N, image_size, image_size, K, 3) giving the
            barycentric coordinates in NDC units of the nearest faces at each
            pixel, sorted in ascending z-order. Concretely, if
            ``pix_to_face[n, y, x, k] = f`` then
            ``[w0, w1, w2] = bary_coords[n, y, x, k]`` gives the barycentric
            coords for pixel (y, x) relative to the face defined by
            ``face_verts[f]``. Pixels hit by fewer than K faces are padded
            with -1.

        dists:
            FloatTensor of shape (N, image_size, image_size, K) giving, in the
            x/y plane and in NDC units, the signed squared Euclidean distance
            from each pixel to the closest point on each of its nearest faces.
            Concretely, if ``pix_to_face[n, y, x, k] = f`` then
            ``dists[n, y, x, k]`` is the signed squared distance between pixel
            (y, x) and the face given by vertices ``face_verts[f]``. Pixels
            hit by fewer than K faces are padded with -1. May be ``None`` when
            no distances were computed; ``detach()`` preserves that.
            NOTE(review): the sign convention (presumably negative inside the
            face) is defined by ``rasterize_meshes`` -- confirm there.
    """

    pix_to_face: torch.Tensor
    zbuf: torch.Tensor
    bary_coords: torch.Tensor
    dists: Optional[torch.Tensor]

    def detach(self) -> "Fragments":
        """
        Return a new ``Fragments`` whose floating point members are detached
        from the autograd graph.

        ``self`` is left unchanged. ``pix_to_face`` is passed through as the
        same tensor object rather than detached, since it is documented as an
        integer index tensor. ``dists`` stays ``None`` if it was ``None``.

        Returns:
            A new ``Fragments`` instance.
        """
        dists = self.dists
        return Fragments(
            pix_to_face=self.pix_to_face,
            zbuf=self.zbuf.detach(),
            bary_coords=self.bary_coords.detach(),
            # Guard against a missing distance buffer instead of raising
            # AttributeError on ``None.detach()``.
            dists=dists.detach() if dists is not None else None,
        )