mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Add Fragments.detach()
Summary: Add a capability to detach all detachable tensors in Fragments. Reviewed By: bottler Differential Revision: D35918133 fbshipit-source-id: 03b5d4491a3a6791b0a7bc9119f26c1a7aa43196
This commit is contained in:
parent
d737a05e55
commit
c21ba144e7
@ -5,7 +5,7 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import NamedTuple, Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -13,13 +13,56 @@ import torch.nn as nn
|
||||
from .rasterize_meshes import rasterize_meshes
|
||||
|
||||
|
||||
# Class to store the outputs of mesh rasterization
|
||||
class Fragments(NamedTuple):
|
||||
@dataclass(frozen=True)
|
||||
class Fragments:
|
||||
"""
|
||||
Members:
|
||||
pix_to_face:
|
||||
LongTensor of shape (N, image_size, image_size, faces_per_pixel) giving
|
||||
the indices of the nearest faces at each pixel, sorted in ascending
|
||||
z-order. Concretely ``pix_to_face[n, y, x, k] = f`` means that
|
||||
``faces_verts[f]`` is the kth closest face (in the z-direction) to pixel
|
||||
(y, x). Pixels that are hit by fewer than faces_per_pixel are padded with
|
||||
-1.
|
||||
|
||||
zbuf:
|
||||
FloatTensor of shape (N, image_size, image_size, faces_per_pixel) giving
|
||||
the NDC z-coordinates of the nearest faces at each pixel, sorted in
|
||||
ascending z-order. Concretely, if ``pix_to_face[n, y, x, k] = f`` then
|
||||
``zbuf[n, y, x, k] = face_verts[f, 2]``. Pixels hit by fewer than
|
||||
faces_per_pixel are padded with -1.
|
||||
|
||||
bary_coords:
|
||||
FloatTensor of shape (N, image_size, image_size, faces_per_pixel, 3)
|
||||
giving the barycentric coordinates in NDC units of the nearest faces at
|
||||
each pixel, sorted in ascending z-order. Concretely, if ``pix_to_face[n,
|
||||
y, x, k] = f`` then ``[w0, w1, w2] = barycentric[n, y, x, k]`` gives the
|
||||
barycentric coords for pixel (y, x) relative to the face defined by
|
||||
``face_verts[f]``. Pixels hit by fewer than faces_per_pixel are padded
|
||||
with -1.
|
||||
|
||||
dists:
|
||||
FloatTensor of shape (N, image_size, image_size, faces_per_pixel) giving
|
||||
the signed Euclidean distance (in NDC units) in the x/y plane of each
|
||||
point closest to the pixel. Concretely if ``pix_to_face[n, y, x, k] = f``
|
||||
then ``pix_dists[n, y, x, k]`` is the squared distance between the pixel
|
||||
(y, x) and the face given by vertices ``face_verts[f]``. Pixels hit with
|
||||
fewer than ``faces_per_pixel`` are padded with -1.
|
||||
"""
|
||||
|
||||
pix_to_face: torch.Tensor
|
||||
zbuf: torch.Tensor
|
||||
bary_coords: torch.Tensor
|
||||
dists: torch.Tensor
|
||||
|
||||
def detach(self) -> "Fragments":
|
||||
return Fragments(
|
||||
pix_to_face=self.pix_to_face,
|
||||
zbuf=self.zbuf.detach(),
|
||||
bary_coords=self.bary_coords.detach(),
|
||||
dists=self.dists.detach(),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RasterizationSettings:
|
||||
|
Loading…
x
Reference in New Issue
Block a user