mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Utils for converting rasterization fragments of clipped meshes back to unclipped
Summary: This diff adds utils functions for converting rasterization fragments of the clipped mesh into fragments expressed in terms of the original unclipped mesh. The face indices and barycentric coordinates are converted in this step. The pixel to triangle distances are handled in the rasterizer which is updated in the next diff in the stack. Reviewed By: jcjohnson Differential Revision: D26169539 fbshipit-source-id: ba451d3facd60ef88a8ffaf25fd04ca07b449ceb
This commit is contained in:
parent
23279c5f1d
commit
39f49c22cd
@ -1,7 +1,12 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
|
||||||
from .clip import ClipFrustum, ClippedFaces, clip_faces
|
from .clip import (
|
||||||
|
ClipFrustum,
|
||||||
|
ClippedFaces,
|
||||||
|
clip_faces,
|
||||||
|
convert_clipped_rasterization_to_original_faces,
|
||||||
|
)
|
||||||
from .rasterize_meshes import rasterize_meshes
|
from .rasterize_meshes import rasterize_meshes
|
||||||
from .rasterizer import MeshRasterizer, RasterizationSettings
|
from .rasterizer import MeshRasterizer, RasterizationSettings
|
||||||
from .renderer import MeshRenderer
|
from .renderer import MeshRenderer
|
||||||
|
@ -598,3 +598,117 @@ def clip_faces(
|
|||||||
clipped_faces_neighbor_idx=clipped_faces_neighbor_idx,
|
clipped_faces_neighbor_idx=clipped_faces_neighbor_idx,
|
||||||
)
|
)
|
||||||
return clipped_faces
|
return clipped_faces
|
||||||
|
|
||||||
|
|
||||||
|
def convert_clipped_rasterization_to_original_faces(
|
||||||
|
pix_to_face_clipped, bary_coords_clipped, clipped_faces: ClippedFaces
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Convert rasterization Fragments (expressed as pix_to_face_clipped,
|
||||||
|
bary_coords_clipped, dists_clipped) of clipped Meshes computed using clip_faces()
|
||||||
|
to the corresponding rasterization Fragments where barycentric coordinates and
|
||||||
|
face indices are in terms of the original unclipped Meshes. The distances are
|
||||||
|
handled in the rasterizer C++/CUDA kernels (i.e. for Cases 1/3 the distance
|
||||||
|
can be used directly and for Case 4 triangles the distance of the pixel to
|
||||||
|
the closest of the two subdivided triangles is used).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pix_to_face_clipped: LongTensor of shape (N, image_size, image_size,
|
||||||
|
faces_per_pixel) giving the indices of the nearest faces at each pixel,
|
||||||
|
sorted in ascending z-order. Concretely
|
||||||
|
``pix_to_face_clipped[n, y, x, k] = f`` means that ``faces_verts_clipped[f]``
|
||||||
|
is the kth closest face (in the z-direction) to pixel (y, x). Pixels that
|
||||||
|
are hit by fewer than faces_per_pixel are padded with -1.
|
||||||
|
bary_coords_clipped: FloatTensor of shape
|
||||||
|
(N, image_size, image_size, faces_per_pixel, 3) giving the barycentric
|
||||||
|
coordinates in world coordinates of the nearest faces at each pixel, sorted
|
||||||
|
in ascending z-order. Concretely, if ``pix_to_face_clipped[n, y, x, k] = f``
|
||||||
|
then ``[w0, w1, w2] = bary_coords_clipped[n, y, x, k]`` gives the
|
||||||
|
barycentric coords for pixel (y, x) relative to the face defined by
|
||||||
|
``unproject(face_verts_clipped[f])``. Pixels hit by fewer than
|
||||||
|
faces_per_pixel are padded with -1.
|
||||||
|
clipped_faces: an instance of ClippedFaces class giving the auxillary variables
|
||||||
|
for converting rasterization outputs from clipped to unclipped Meshes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
3-tuple: (pix_to_face_unclipped, bary_coords_unclipped, dists_unclipped) that
|
||||||
|
have the same definition as (pix_to_face_clipped, bary_coords_clipped,
|
||||||
|
dists_clipped) except that they pertain to faces_verts_unclipped instead of
|
||||||
|
faces_verts_clipped (i.e the original meshes as opposed to the modified meshes)
|
||||||
|
"""
|
||||||
|
faces_clipped_to_unclipped_idx = clipped_faces.faces_clipped_to_unclipped_idx
|
||||||
|
|
||||||
|
# If no clipping or culling then return inputs
|
||||||
|
if faces_clipped_to_unclipped_idx is None:
|
||||||
|
return pix_to_face_clipped, bary_coords_clipped
|
||||||
|
|
||||||
|
device = pix_to_face_clipped.device
|
||||||
|
|
||||||
|
# Convert pix_to_face indices to now refer to the faces in the unclipped Meshes.
|
||||||
|
# Init empty tensor to fill in all the background values which have pix_to_face=-1.
|
||||||
|
empty = torch.full(pix_to_face_clipped.shape, -1, device=device, dtype=torch.int64)
|
||||||
|
pix_to_face_unclipped = torch.where(
|
||||||
|
pix_to_face_clipped != -1,
|
||||||
|
faces_clipped_to_unclipped_idx[pix_to_face_clipped],
|
||||||
|
empty,
|
||||||
|
)
|
||||||
|
|
||||||
|
# For triangles that were clipped into smaller triangle(s), convert barycentric
|
||||||
|
# coordinates from being in terms of the clipped triangle to being in terms of the
|
||||||
|
# original unclipped triangle.
|
||||||
|
|
||||||
|
# barycentric_conversion is a (T, 3, 3) tensor such that
|
||||||
|
# alpha_unclipped[i, :] = barycentric_conversion[i, :, :]*alpha_clipped[i, :]
|
||||||
|
barycentric_conversion = clipped_faces.barycentric_conversion
|
||||||
|
|
||||||
|
# faces_clipped_to_conversion_idx is an (F_clipped,) shape tensor mapping each output
|
||||||
|
# face to the applicable row of barycentric_conversion (or set to -1 if conversion is
|
||||||
|
# not needed)
|
||||||
|
faces_clipped_to_conversion_idx = clipped_faces.faces_clipped_to_conversion_idx
|
||||||
|
|
||||||
|
if barycentric_conversion is not None:
|
||||||
|
bary_coords_unclipped = bary_coords_clipped.clone()
|
||||||
|
|
||||||
|
# Select the subset of faces that require conversion, where N is the sum
|
||||||
|
# number of case3/case4 triangles that are in the closest k triangles to some
|
||||||
|
# rasterized pixel.
|
||||||
|
pix_to_conversion_idx = torch.where(
|
||||||
|
pix_to_face_clipped != -1,
|
||||||
|
faces_clipped_to_conversion_idx[pix_to_face_clipped],
|
||||||
|
empty,
|
||||||
|
)
|
||||||
|
faces_to_convert_mask = pix_to_conversion_idx != -1
|
||||||
|
N = faces_to_convert_mask.sum().item()
|
||||||
|
|
||||||
|
# Expand to (N, H, W, K, 3) to be the same shape as barycentric coordinates
|
||||||
|
faces_to_convert_mask_expanded = faces_to_convert_mask[:, :, :, :, None].expand(
|
||||||
|
-1, -1, -1, -1, 3
|
||||||
|
)
|
||||||
|
|
||||||
|
# An (N,) dim tensor of indices into barycentric_conversion
|
||||||
|
conversion_idx_subset = pix_to_conversion_idx[faces_to_convert_mask]
|
||||||
|
|
||||||
|
# An (N, 3, 1) tensor of barycentric coordinates in terms of the clipped triangles
|
||||||
|
bary_coords_clipped_subset = bary_coords_clipped[faces_to_convert_mask_expanded]
|
||||||
|
bary_coords_clipped_subset = bary_coords_clipped_subset.reshape((N, 3, 1))
|
||||||
|
|
||||||
|
# An (N, 3, 3) tensor storing matrices to convert from clipped to unclipped
|
||||||
|
# barycentric coordinates
|
||||||
|
bary_conversion_subset = barycentric_conversion[conversion_idx_subset]
|
||||||
|
|
||||||
|
# An (N, 3, 1) tensor of barycentric coordinates in terms of the unclipped triangle
|
||||||
|
bary_coords_unclipped_subset = bary_conversion_subset.bmm(
|
||||||
|
bary_coords_clipped_subset
|
||||||
|
)
|
||||||
|
|
||||||
|
bary_coords_unclipped_subset = bary_coords_unclipped_subset.reshape([N * 3])
|
||||||
|
bary_coords_unclipped[
|
||||||
|
faces_to_convert_mask_expanded
|
||||||
|
] = bary_coords_unclipped_subset
|
||||||
|
|
||||||
|
# dists for case 4 faces will be handled in the rasterizer
|
||||||
|
# so no need to modify them here.
|
||||||
|
else:
|
||||||
|
bary_coords_unclipped = bary_coords_clipped
|
||||||
|
|
||||||
|
return pix_to_face_unclipped, bary_coords_unclipped
|
||||||
|
Loading…
x
Reference in New Issue
Block a user