Add Fragments.detach()

Summary: Add a capability to detach all detachable tensors in Fragments.

Reviewed By: bottler

Differential Revision: D35918133

fbshipit-source-id: 03b5d4491a3a6791b0a7bc9119f26c1a7aa43196
This commit is contained in:
Krzysztof Chalupka 2022-05-11 18:50:24 -07:00 committed by Facebook GitHub Bot
parent d737a05e55
commit c21ba144e7

View File

@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import NamedTuple, Optional, Tuple, Union
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
@ -13,13 +13,56 @@ import torch.nn as nn
from .rasterize_meshes import rasterize_meshes
# Class to store the outputs of mesh rasterization
class Fragments(NamedTuple):
@dataclass(frozen=True)
class Fragments:
"""
Members:
pix_to_face:
LongTensor of shape (N, image_size, image_size, faces_per_pixel) giving
the indices of the nearest faces at each pixel, sorted in ascending
z-order. Concretely ``pix_to_face[n, y, x, k] = f`` means that
``faces_verts[f]`` is the kth closest face (in the z-direction) to pixel
(y, x). Pixels that are hit by fewer than faces_per_pixel are padded with
-1.
zbuf:
FloatTensor of shape (N, image_size, image_size, faces_per_pixel) giving
the NDC z-coordinates of the nearest faces at each pixel, sorted in
ascending z-order. Concretely, if ``pix_to_face[n, y, x, k] = f`` then
``zbuf[n, y, x, k] = face_verts[f, 2]``. Pixels hit by fewer than
faces_per_pixel are padded with -1.
bary_coords:
FloatTensor of shape (N, image_size, image_size, faces_per_pixel, 3)
giving the barycentric coordinates in NDC units of the nearest faces at
each pixel, sorted in ascending z-order. Concretely, if ``pix_to_face[n,
y, x, k] = f`` then ``[w0, w1, w2] = barycentric[n, y, x, k]`` gives the
barycentric coords for pixel (y, x) relative to the face defined by
``face_verts[f]``. Pixels hit by fewer than faces_per_pixel are padded
with -1.
dists:
FloatTensor of shape (N, image_size, image_size, faces_per_pixel) giving
the signed Euclidean distance (in NDC units) in the x/y plane of each
point closest to the pixel. Concretely if ``pix_to_face[n, y, x, k] = f``
then ``pix_dists[n, y, x, k]`` is the squared distance between the pixel
(y, x) and the face given by vertices ``face_verts[f]``. Pixels hit with
fewer than ``faces_per_pixel`` are padded with -1.
"""
pix_to_face: torch.Tensor
zbuf: torch.Tensor
bary_coords: torch.Tensor
dists: torch.Tensor
def detach(self) -> "Fragments":
return Fragments(
pix_to_face=self.pix_to_face,
zbuf=self.zbuf.detach(),
bary_coords=self.bary_coords.detach(),
dists=self.dists.detach(),
)
@dataclass
class RasterizationSettings: