Utils for clipping mesh faces partially behind the image plane

Summary:
Instead of culling faces behind the camera, partially clip them if they intersect with the image plane.

This diff implements the utility functions for clipping.

There are 4 cases for the mesh faces, all of which are handled:

```
Case 1: the triangle is completely in front of the clipping plane (it is left
        unchanged)
Case 2: the triangle is completely behind the clipping plane (it is culled)
Case 3: the triangle has exactly two vertices behind the clipping plane (it is
        clipped into a smaller triangle)
Case 4: the triangle has exactly one vertex behind the clipping plane (it is clipped
        into a smaller quadrilateral and divided into two triangular faces)
```
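
For reference, a minimal usage sketch (mirroring the unit tests in this diff) of how the new utilities are driven; the single-triangle mesh is illustrative only:

```
import torch
from pytorch3d.renderer.mesh import ClipFrustum, clip_faces
from pytorch3d.structures.meshes import Meshes

# A triangle straddling the clipping plane with two verts behind it (Case 3).
verts = torch.tensor([[-1.0, 0.0, -1.0], [0.0, 0.0, 1.0], [1.0, 0.0, -1.0]])
faces = torch.tensor([[0, 1, 2]], dtype=torch.int64)
meshes = Meshes(verts=[verts], faces=[faces])

face_verts = meshes.verts_packed()[meshes.faces_packed()]
frustum = ClipFrustum(left=-1, right=1, top=-1, bottom=1, z_clip_value=1e-2, cull=True)
clipped = clip_faces(
    face_verts,
    meshes.mesh_to_faces_packed_first_idx(),
    meshes.num_faces_per_mesh(),
    frustum,
)
# clipped.face_verts now holds the smaller clipped triangle, and
# clipped.barycentric_conversion maps its barycentric coords back to the original face.
```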

Reviewed By: jcjohnson

Differential Revision: D23108673

fbshipit-source-id: 550a8b6a982d06065dff10aba10d47e8b144ae52
Nikhila Ravi 2021-02-05 18:26:14 -08:00 committed by Facebook GitHub Bot
parent db6fbfad90
commit 23279c5f1d
3 changed files with 953 additions and 0 deletions


@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .clip import ClipFrustum, ClippedFaces, clip_faces
from .rasterize_meshes import rasterize_meshes
from .rasterizer import MeshRasterizer, RasterizationSettings
from .renderer import MeshRenderer


@@ -0,0 +1,600 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Any, List, Optional, Tuple
import torch
"""
Mesh clipping is done before rasterization and is implemented using 4 cases
(these will be referred to throughout the functions below)
Case 1: the triangle is completely in front of the clipping plane (it is left
unchanged)
Case 2: the triangle is completely behind the clipping plane (it is culled)
Case 3: the triangle has exactly two vertices behind the clipping plane (it is
clipped into a smaller triangle)
Case 4: the triangle has exactly one vertex behind the clipping plane (it is clipped
into a smaller quadrilateral and divided into two triangular faces)
After rasterization, the Fragments from the clipped/modified triangles
are mapped back to the triangles in the original mesh. The indices,
barycentric coordinates and distances are all relative to original mesh triangles.
NOTE: It is assumed that all z-coordinates are in world coordinates (not NDC
coordinates), while x and y coordinates may be in NDC/screen coordinates
(i.e after applying a projective transform e.g. cameras.transform_points(points)).
"""
class ClippedFaces:
"""
Helper class to store the data for the clipped version of a Meshes object
(face_verts, mesh_to_face_first_idx, num_faces_per_mesh) along with
conversion information (faces_clipped_to_unclipped_idx, barycentric_conversion,
faces_clipped_to_conversion_idx, clipped_faces_neighbor_idx) required to convert
barycentric coordinates from rasterization of the clipped Meshes to barycentric
coordinates in terms of the unclipped Meshes.
Args:
face_verts: FloatTensor of shape (F_clipped, 3, 3) giving the verts of
each of the clipped faces
mesh_to_face_first_idx: a tensor of shape (N,), where N is the number of meshes
in the batch. The ith element stores the index into face_verts
of the first face of the ith mesh.
num_faces_per_mesh: a tensor of shape (N,) storing the number of faces in each mesh.
faces_clipped_to_unclipped_idx: (F_clipped,) shaped LongTensor mapping each clipped
face back to the face in faces_unclipped (i.e. the faces in the original meshes
obtained using meshes.faces_packed())
barycentric_conversion: (T, 3, 3) FloatTensor, where barycentric_conversion[i, :, k]
stores the barycentric weights in terms of the world coordinates of the original
(big) unclipped triangle for the kth vertex in the clipped (small) triangle.
If the rasterizer then expresses some NDC coordinate in terms of barycentric
world coordinates for the clipped (small) triangle as alpha_clipped[i,:],
alpha_unclipped[i, :] = barycentric_conversion[i, :, :]*alpha_clipped[i, :]
faces_clipped_to_conversion_idx: (F_clipped,) shaped LongTensor mapping each clipped
face to the applicable row of barycentric_conversion (or set to -1 if conversion is
not needed).
clipped_faces_neighbor_idx: LongTensor of shape (F_clipped,) giving the index of the
neighboring face for each case 4 triangle, e.g. for a case 4 face f split
into two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx.
Faces which are not clipped and subdivided are set to -1 (i.e. cases 1/2/3).
"""
__slots__ = [
"face_verts",
"mesh_to_face_first_idx",
"num_faces_per_mesh",
"faces_clipped_to_unclipped_idx",
"barycentric_conversion",
"faces_clipped_to_conversion_idx",
"clipped_faces_neighbor_idx",
]
def __init__(
self,
face_verts: torch.Tensor,
mesh_to_face_first_idx: torch.Tensor,
num_faces_per_mesh: torch.Tensor,
faces_clipped_to_unclipped_idx: Optional[torch.Tensor] = None,
barycentric_conversion: Optional[torch.Tensor] = None,
faces_clipped_to_conversion_idx: Optional[torch.Tensor] = None,
clipped_faces_neighbor_idx: Optional[torch.Tensor] = None,
):
self.face_verts = face_verts
self.mesh_to_face_first_idx = mesh_to_face_first_idx
self.num_faces_per_mesh = num_faces_per_mesh
self.faces_clipped_to_unclipped_idx = faces_clipped_to_unclipped_idx
self.barycentric_conversion = barycentric_conversion
self.faces_clipped_to_conversion_idx = faces_clipped_to_conversion_idx
self.clipped_faces_neighbor_idx = clipped_faces_neighbor_idx
class ClipFrustum:
"""
Helper class to store the information needed to represent a view frustum
(left, right, top, bottom, znear, zfar), which is used to clip or cull triangles.
Values left as None mean that culling should not be performed for that axis.
The parameters perspective_correct, cull, and z_clip_value are used to define
behavior for clipping triangles to the frustum.
Args:
left: NDC coordinate of the left clipping plane (along x axis)
right: NDC coordinate of the right clipping plane (along x axis)
top: NDC coordinate of the top clipping plane (along y axis)
bottom: NDC coordinate of the bottom clipping plane (along y axis)
znear: world space z coordinate of the near clipping plane
zfar: world space z coordinate of the far clipping plane
perspective_correct: should be set to True for a perspective camera
cull: if True, triangles outside the frustum should be culled
z_clip_value: if not None, then triangles should be clipped (possibly into
smaller triangles) such that z >= z_clip_value. This avoids projections
that go to infinity as z->0
"""
__slots__ = [
"left",
"right",
"top",
"bottom",
"znear",
"zfar",
"perspective_correct",
"cull",
"z_clip_value",
]
def __init__(
self,
left: Optional[float] = None,
right: Optional[float] = None,
top: Optional[float] = None,
bottom: Optional[float] = None,
znear: Optional[float] = None,
zfar: Optional[float] = None,
perspective_correct: bool = False,
cull: bool = True,
z_clip_value: Optional[float] = None,
):
self.left = left
self.right = right
self.top = top
self.bottom = bottom
self.znear = znear
self.zfar = zfar
self.perspective_correct = perspective_correct
self.cull = cull
self.z_clip_value = z_clip_value
def _get_culled_faces(face_verts: torch.Tensor, frustum: ClipFrustum) -> torch.Tensor:
"""
Helper function used to find all the faces in Meshes which are
fully outside the view frustum. A face is culled if all 3 vertices are outside
the same axis of the view frustum.
Args:
face_verts: An (F,3,3) tensor, where F is the number of faces in
the packed representation of Meshes. The 2nd dimension represents the 3 vertices
of a triangle, and the 3rd dimension stores the xyz locations of each
vertex.
frustum: An instance of the ClipFrustum class with the information on the
position of the clipping planes.
Returns:
faces_culled: A boolean tensor of shape (F,) indicating whether each face should be
culled.
"""
clipping_planes = (
(frustum.left, 0, "<"),
(frustum.right, 0, ">"),
(frustum.top, 1, "<"),
(frustum.bottom, 1, ">"),
(frustum.znear, 2, "<"),
(frustum.zfar, 2, ">"),
)
faces_culled = torch.zeros(
[face_verts.shape[0]], dtype=torch.bool, device=face_verts.device
)
for plane in clipping_planes:
clip_value, axis, op = plane
# If clip_value is None then don't clip along that plane
if frustum.cull and clip_value is not None:
# Compare the coordinate along `axis` for all 3 verts of each face against the plane
if op == "<":
verts_clipped = face_verts[:, :, axis] < clip_value
else:
verts_clipped = face_verts[:, :, axis] > clip_value
# If all verts are clipped then face is outside the frustum
faces_culled |= verts_clipped.sum(1) == 3
return faces_culled
def _find_verts_intersecting_clipping_plane(
face_verts: torch.Tensor,
p1_face_ind: torch.Tensor,
clip_value: float,
perspective_correct: bool,
) -> Tuple[Tuple[Any, Any, Any, Any, Any], List[Any]]:
r"""
Helper function to find the vertices used to form a new triangle for case 3/case 4 faces.
Given a list of triangles that are already known to intersect the clipping plane,
solve for the two vertices p4 and p5 where the edges of the triangle intersects the
clipping plane.
                  p1
                  /\
                 /  \
                / t  \
   _________p4 /______\ p5__________ clip_value
              /        \
             /____      \
           p2     ---____\ p3
Args:
face_verts: An (F,3,3) tensor, where F is the number of faces in
the packed representation of the Meshes, the 2nd dimension represents
the 3 vertices of the face, and the 3rd dimension stores the xyz locations of each
vertex. The z-coordinates must be represented in world coordinates, while
the xy-coordinates may be in NDC/screen coordinates (i.e. after projection).
p1_face_ind: A tensor of shape (F,) with values in the range of 0 to 2. In each
case 3/case 4 triangle, two vertices are on the same side of the
clipping plane and the 3rd is on the other side. p1_face_ind stores the index of
the vertex that is not on the same side as any other vertex in the triangle.
clip_value: Float, the z-value defining where to clip the triangle.
perspective_correct: Bool, Should be set to true if a perspective camera was
used and xy-coordinates of face_verts_unclipped are in NDC/screen coordinates.
Returns:
A 2-tuple:
p: (p1, p2, p3, p4, p5)
p_barycentric: (p1_bary, p2_bary, p3_bary, p4_bary, p5_bary)
Each of p1...p5 is an (F, 3) tensor of the xyz locations of the 5 points in the
diagram above for case 3/case 4 faces. Each p1_bary...p5_bary is an (F, 3) tensor
storing the barycentric weights used to encode p1...p5 in terms of the original
unclipped triangle.
"""
# Let T be the number of triangles in face_verts (note that these correspond to the
# subset of case 3 or case 4 triangles). p1_face_ind, p2_face_ind, and p3_face_ind are (T,)
# tensors with values in the range of 0 to 2. p1_face_ind stores the index of the
# vertex that is not on the same side as any other vertex in the triangle, and
# p2_face_ind and p3_face_ind are the indices of the other two vertices preserving
# the same counterclockwise or clockwise ordering
T = face_verts.shape[0]
p2_face_ind = torch.remainder(p1_face_ind + 1, 3)
p3_face_ind = torch.remainder(p1_face_ind + 2, 3)
# p1, p2, p3 are (T, 3) tensors storing the corresponding (x, y, z) coordinates
# of p1_face_ind, p2_face_ind, p3_face_ind
# pyre-ignore[16]
p1 = face_verts.gather(1, p1_face_ind[:, None, None].expand(-1, -1, 3)).squeeze(1)
p2 = face_verts.gather(1, p2_face_ind[:, None, None].expand(-1, -1, 3)).squeeze(1)
p3 = face_verts.gather(1, p3_face_ind[:, None, None].expand(-1, -1, 3)).squeeze(1)
##################################
# Solve for intersection point p4
##################################
# p4 is a (T, 3) tensor representing the point on the segment between p1 and p2 that
# intersects the clipping plane.
# Solve for the weight w2 such that p1.z*(1-w2) + p2.z*w2 = clip_value.
# Then interpolate p4 = p1*(1-w2) + p2*w2 where it is assumed that z-coordinates
# are expressed in world coordinates (since we want to clip z in world coordinates).
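# Worked example (numbers match the unit tests in this diff): with p1.z = 1.0,
# p2.z = -1.0 and clip_value = 0.01, w2 = (1.0 - 0.01) / (1.0 - (-1.0)) = 0.495,
# so p4 lies just in front of the clipping plane at z = clip_value = 0.01.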
w2 = (p1[:, 2] - clip_value) / (p1[:, 2] - p2[:, 2])
p4 = p1 * (1 - w2[:, None]) + p2 * w2[:, None]
if perspective_correct:
# It is assumed that all z-coordinates are in world coordinates (not NDC
# coordinates), while x and y coordinates may be in NDC/screen coordinates.
# If x and y are in NDC/screen coordinates and a projective transform was used
# in a perspective camera, then we effectively want to:
# 1. Convert back to world coordinates (by multiplying by z)
# 2. Interpolate using w2
# 3. Convert back to NDC/screen coordinates (by dividing by the new z=clip_value)
p1_world = p1[:, :2] * p1[:, 2:3]
p2_world = p2[:, :2] * p2[:, 2:3]
p4[:, :2] = (p1_world * (1 - w2[:, None]) + p2_world * w2[:, None]) / clip_value
##################################
# Solve for intersection point p5
##################################
# p5 is a (T, 3) tensor representing the point on the segment between p1 and p3 that
# intersects the clipping plane.
# Solve for the weight w3 such that p1.z * (1-w3) + p3.z * w3 = clip_value,
# and then interpolate p5 = p1 * (1-w3) + p3 * w3
w3 = (p1[:, 2] - clip_value) / (p1[:, 2] - p3[:, 2])
w3 = w3.detach()
p5 = p1 * (1 - w3[:, None]) + p3 * w3[:, None]
if perspective_correct:
# Again if using a perspective camera, convert back to world coordinates
# interpolate and convert back
p1_world = p1[:, :2] * p1[:, 2:3]
p3_world = p3[:, :2] * p3[:, 2:3]
p5[:, :2] = (p1_world * (1 - w3[:, None]) + p3_world * w3[:, None]) / clip_value
# Set the barycentric coordinates of p1,p2,p3,p4,p5 in terms of the original
# unclipped triangle in face_verts.
T_idx = torch.arange(T, device=face_verts.device)
p_barycentric = [torch.zeros((T, 3), device=face_verts.device) for i in range(5)]
p_barycentric[0][(T_idx, p1_face_ind)] = 1
p_barycentric[1][(T_idx, p2_face_ind)] = 1
p_barycentric[2][(T_idx, p3_face_ind)] = 1
p_barycentric[3][(T_idx, p1_face_ind)] = 1 - w2
p_barycentric[3][(T_idx, p2_face_ind)] = w2
p_barycentric[4][(T_idx, p1_face_ind)] = 1 - w3
p_barycentric[4][(T_idx, p3_face_ind)] = w3
p = (p1, p2, p3, p4, p5)
return p, p_barycentric
###################
# Main Entry point
###################
def clip_faces(
face_verts_unclipped: torch.Tensor,
mesh_to_face_first_idx: torch.Tensor,
num_faces_per_mesh: torch.Tensor,
frustum: ClipFrustum,
) -> ClippedFaces:
"""
Clip a mesh to the portion contained within a view frustum and with z > z_clip_value.
There are two types of clipping:
1) Cull triangles that are completely outside the view frustum. This is purely
to save computation by reducing the number of triangles that need to be
rasterized.
2) Clip triangles into the portion of the triangle where z > z_clip_value. The
clipped region may be a quadrilateral, which results in splitting a triangle
into two triangles. This does not save computation, but is necessary to
correctly rasterize using perspective cameras for triangles that pass through
z <= 0, because NDC/screen coordinates go to infinity at z=0.
Args:
face_verts_unclipped: An (F, 3, 3) tensor, where F is the number of faces in
the packed representation of Meshes, the 2nd dimension represents the 3 vertices
of the triangle, and the 3rd dimension stores the xyz locations of each
vertex. The z-coordinates must be represented in world coordinates, while
the xy-coordinates may be in NDC/screen coordinates
mesh_to_face_first_idx: a tensor of shape (N,), where N is the number of meshes
in the batch. The ith element stores the index into face_verts_unclipped
of the first face of the ith mesh.
num_faces_per_mesh: a tensor of shape (N,) storing the number of faces in each mesh.
frustum: a ClipFrustum object defining the frustum used to cull faces.
Returns:
clipped_faces: ClippedFaces object storing a clipped version of the Meshes
along with tensors that can be used to convert barycentric coordinates
returned by rasterization of the clipped meshes into barycentric
coordinates for the unclipped meshes.
"""
F = face_verts_unclipped.shape[0]
device = face_verts_unclipped.device
# Triangles completely outside the view frustum will be culled
# faces_culled is of shape (F, )
faces_culled = _get_culled_faces(face_verts_unclipped, frustum)
# Triangles that are partially behind the z clipping plane will be clipped to
# smaller triangles
z_clip_value = frustum.z_clip_value
perspective_correct = frustum.perspective_correct
if z_clip_value is not None:
# (F, 3) tensor (where F is the number of triangles) marking whether each vertex
# in a triangle is behind the clipping plane
faces_clipped_verts = face_verts_unclipped[:, :, 2] < z_clip_value
# (F) dim tensor containing the number of clipped vertices in each triangle
faces_num_clipped_verts = faces_clipped_verts.sum(1)
else:
faces_num_clipped_verts = torch.zeros([F], device=device)
# If no triangles need to be clipped or culled, avoid unnecessary computation
# and return early
if faces_num_clipped_verts.sum().item() == 0 and faces_culled.sum().item() == 0:
return ClippedFaces(
face_verts=face_verts_unclipped,
mesh_to_face_first_idx=mesh_to_face_first_idx,
num_faces_per_mesh=num_faces_per_mesh,
)
#####################################################################################
# Classify faces into the 4 relevant cases:
# 1) The triangle is completely in front of the clipping plane (it is left
# unchanged)
# 2) The triangle is completely behind the clipping plane (it is culled)
# 3) The triangle has exactly two vertices behind the clipping plane (it is
# clipped into a smaller triangle)
# 4) The triangle has exactly one vertex behind the clipping plane (it is clipped
# into a smaller quadrilateral and split into two triangles)
#####################################################################################
# pyre-ignore[16]:
faces_unculled = ~faces_culled
# Case 1: no clipped verts or culled faces
cases1_unclipped = torch.logical_and(faces_num_clipped_verts == 0, faces_unculled)
case1_unclipped_idx = cases1_unclipped.nonzero(as_tuple=True)[0]
# Case 2: all verts clipped
case2_unclipped = torch.logical_or(faces_num_clipped_verts == 3, faces_culled)
# Case 3: two verts clipped
case3_unclipped = torch.logical_and(faces_num_clipped_verts == 2, faces_unculled)
case3_unclipped_idx = case3_unclipped.nonzero(as_tuple=True)[0]
# Case 4: one vert clipped
case4_unclipped = torch.logical_and(faces_num_clipped_verts == 1, faces_unculled)
case4_unclipped_idx = case4_unclipped.nonzero(as_tuple=True)[0]
# faces_unclipped_to_clipped_idx is an (F,) dim tensor mapping each face to the
# index of the corresponding face in face_verts_clipped.
# Each case 2 triangle will be culled (deleted from face_verts_clipped),
# while each case 4 triangle will be split into two smaller triangles
# (replaced by two consecutive triangles in face_verts_clipped)
# case2_unclipped is an (F,) dim 0/1 tensor of all the case2 faces
# case4_unclipped is an (F,) dim 0/1 tensor of all the case4 faces
faces_delta = case4_unclipped.int() - case2_unclipped.int()
# faces_delta_cum gives the per face change in index. Faces which are
# clipped in the original mesh are mapped to the closest non clipped face
# in face_verts_clipped (this doesn't matter as they are not used
# during rasterization anyway).
faces_delta_cum = faces_delta.cumsum(0) - faces_delta
delta = 1 + case4_unclipped.int() - case2_unclipped.int()
# pyre-ignore[16]
faces_unclipped_to_clipped_idx = delta.cumsum(0) - delta
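# Worked example (illustrative): for per-face cases [1, 2, 4, 3],
# delta = [1, 0, 2, 1] and faces_unclipped_to_clipped_idx = [0, 1, 1, 3]:
# the case 2 face gets an unused placeholder index (it is culled), while the
# case 4 face reserves two consecutive slots in face_verts_clipped.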
###########################################
# Allocate tensors for the output Meshes.
# These will then be filled in for each case.
###########################################
F_clipped = (
F + faces_delta_cum[-1].item() + faces_delta[-1].item()
) # Total number of faces in the new Meshes
face_verts_clipped = torch.zeros(
(F_clipped, 3, 3), dtype=face_verts_unclipped.dtype, device=device
)
faces_clipped_to_unclipped_idx = torch.zeros(
[F_clipped], dtype=torch.int64, device=device
)
# Update version of mesh_to_face_first_idx and num_faces_per_mesh applicable to
# face_verts_clipped
mesh_to_face_first_idx_clipped = faces_unclipped_to_clipped_idx[
mesh_to_face_first_idx
]
F_clipped_t = torch.full([1], F_clipped, dtype=torch.int64, device=device)
num_faces_next = torch.cat((mesh_to_face_first_idx_clipped[1:], F_clipped_t))
num_faces_per_mesh_clipped = num_faces_next - mesh_to_face_first_idx_clipped
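# Worked example (matches the mixed-case unit test below): if mesh 0 contributes
# clipped faces 0..5 and mesh 1 contributes faces 6..7, then
# mesh_to_face_first_idx_clipped = [0, 6] and num_faces_per_mesh_clipped = [6, 2].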
################# Start Case 1 ########################################
# Case 1: Triangles are fully visible, copy unchanged triangles into the
# appropriate position in the new list of faces
case1_clipped_idx = faces_unclipped_to_clipped_idx[case1_unclipped_idx]
face_verts_clipped[case1_clipped_idx] = face_verts_unclipped[case1_unclipped_idx]
faces_clipped_to_unclipped_idx[case1_clipped_idx] = case1_unclipped_idx
# If no triangles need to be clipped but some triangles were culled, avoid
# unnecessary clipping computation
if case3_unclipped_idx.shape[0] + case4_unclipped_idx.shape[0] == 0:
return ClippedFaces(
face_verts=face_verts_clipped,
mesh_to_face_first_idx=mesh_to_face_first_idx_clipped,
num_faces_per_mesh=num_faces_per_mesh_clipped,
faces_clipped_to_unclipped_idx=faces_clipped_to_unclipped_idx,
)
################# End Case 1 ##########################################
################# Start Case 3 ########################################
# Case 3: exactly two vertices are behind the clipping plane, so the triangle is
# clipped into a smaller triangle. In the diagram below, we clip the bottom part of
# the triangle and add new vertices p4 and p5 where its edges intersect the clipping
# plane. The updated triangle is the triangle (p4, p5, p1).
#
#                   p1 (unclipped vertex)
#                   /\
#                  /  \
#                 / t  \
#   __________p4 /______\ p5___________ clip_value
#  xxxxxxxxxxxxx/        \xxxxxxxxxxxxxxxxxxxxxxxx
#  xxxxxxxxxxxx/____      \xxxxxxxxxxxxxxxxxxxxxxx
#  xxxxxxxxxxp2     ---____\ p3xxxxxxxxxxxxxxxxxxx
#  xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
faces_case3 = face_verts_unclipped[case3_unclipped_idx]
# index (0, 1, or 2) of the vertex in front of the clipping plane
p1_face_ind = torch.where(~faces_clipped_verts[case3_unclipped_idx])[1]
# Solve for the points p4, p5 that intersect the clipping plane
p, p_barycentric = _find_verts_intersecting_clipping_plane(
faces_case3, p1_face_ind, z_clip_value, perspective_correct
)
p1, _, _, p4, p5 = p
p1_barycentric, _, _, p4_barycentric, p5_barycentric = p_barycentric
# Store clipped triangle
case3_clipped_idx = faces_unclipped_to_clipped_idx[case3_unclipped_idx]
t_barycentric = torch.stack((p4_barycentric, p5_barycentric, p1_barycentric), 2)
face_verts_clipped[case3_clipped_idx] = torch.stack((p4, p5, p1), 1)
faces_clipped_to_unclipped_idx[case3_clipped_idx] = case3_unclipped_idx
################# End Case 3 ##########################################
################# Start Case 4 ########################################
# Case 4: exactly one vertex is behind the clipping plane, so the triangle is clipped
# into a quadrilateral. In the diagram below, we clip the bottom part of the triangle
# and add new vertices p4 and p5 where its edges intersect the clipping plane. The
# unclipped region is a quadrilateral, which is split into two triangles:
# t1: p4, p2, p5
# t2: p5, p2, p3
#
#        p3 ____________________ p2
#           \               __--/
#            \   t2     __--   /
#             \     __--  t1  /
#  _________p5 \__--_________/ p4_________ clip_value
#  xxxxxxxxxxxxx\           /xxxxxxxxxxxxxxxxxx
#  xxxxxxxxxxxxxx\         /xxxxxxxxxxxxxxxxxxx
#  xxxxxxxxxxxxxxx\       /xxxxxxxxxxxxxxxxxxxx
#  xxxxxxxxxxxxxxxx\     /xxxxxxxxxxxxxxxxxxxxx
#  xxxxxxxxxxxxxxxxx\   /xxxxxxxxxxxxxxxxxxxxxx
#  xxxxxxxxxxxxxxxxxx\ /xxxxxxxxxxxxxxxxxxxxxxx
#                     p1 (clipped vertex)
faces_case4 = face_verts_unclipped[case4_unclipped_idx]
# index (0, 1, or 2) of the vertex behind the clipping plane
p1_face_ind = torch.where(faces_clipped_verts[case4_unclipped_idx])[1]
# Solve for the points p4, p5 that intersect the clipping plane
p, p_barycentric = _find_verts_intersecting_clipping_plane(
faces_case4, p1_face_ind, z_clip_value, perspective_correct
)
_, p2, p3, p4, p5 = p
_, p2_barycentric, p3_barycentric, p4_barycentric, p5_barycentric = p_barycentric
# Store clipped triangles
case4_clipped_idx = faces_unclipped_to_clipped_idx[case4_unclipped_idx]
face_verts_clipped[case4_clipped_idx] = torch.stack((p4, p2, p5), 1)
face_verts_clipped[case4_clipped_idx + 1] = torch.stack((p5, p2, p3), 1)
t1_barycentric = torch.stack((p4_barycentric, p2_barycentric, p5_barycentric), 2)
t2_barycentric = torch.stack((p5_barycentric, p2_barycentric, p3_barycentric), 2)
faces_clipped_to_unclipped_idx[case4_clipped_idx] = case4_unclipped_idx
faces_clipped_to_unclipped_idx[case4_clipped_idx + 1] = case4_unclipped_idx
##################### End Case 4 #########################
# Triangles that were clipped (case 3 & case 4) require a conversion of barycentric
# coordinates from the smaller clipped triangle to the original (big) unclipped
# triangle. If there are T clipped triangles,
# barycentric_conversion is a (T, 3, 3) tensor, where barycentric_conversion[i, :, k]
# stores the barycentric weights in terms of the world coordinates of the original
# (big) triangle for the kth vertex in the clipped (small) triangle. If our
# rasterizer then expresses some NDC coordinate in terms of barycentric
# world coordinates for the clipped (small) triangle as alpha_clipped[i,:],
# alpha_unclipped[i, :] = barycentric_conversion[i, :, :]*alpha_clipped[i, :]
barycentric_conversion = torch.cat((t_barycentric, t1_barycentric, t2_barycentric))
# faces_clipped_to_conversion_idx is an (F_clipped,) shape tensor mapping each output
# face to the applicable row of barycentric_conversion (or set to -1 if conversion is
# not needed)
faces_to_convert_idx = torch.cat(
(case3_clipped_idx, case4_clipped_idx, case4_clipped_idx + 1), 0
)
barycentric_idx = torch.arange(
barycentric_conversion.shape[0], dtype=torch.int64, device=device
)
faces_clipped_to_conversion_idx = torch.full(
[F_clipped], -1, dtype=torch.int64, device=device
)
faces_clipped_to_conversion_idx[faces_to_convert_idx] = barycentric_idx
# clipped_faces_neighbor_idx is an (F_clipped,) dim tensor.
# For case 4 clipped triangles (where a big triangle is split in two smaller triangles),
# store the index of the neighboring clipped triangle.
# This will be needed because if the soft rasterizer includes both
# triangles in the list of top K nearest triangles, we
# should only use the one with the smaller distance.
clipped_faces_neighbor_idx = torch.full(
[F_clipped], -1, dtype=torch.int64, device=device
)
clipped_faces_neighbor_idx[case4_clipped_idx] = case4_clipped_idx + 1
clipped_faces_neighbor_idx[case4_clipped_idx + 1] = case4_clipped_idx
clipped_faces = ClippedFaces(
face_verts=face_verts_clipped,
mesh_to_face_first_idx=mesh_to_face_first_idx_clipped,
num_faces_per_mesh=num_faces_per_mesh_clipped,
faces_clipped_to_unclipped_idx=faces_clipped_to_unclipped_idx,
barycentric_conversion=barycentric_conversion,
faces_clipped_to_conversion_idx=faces_clipped_to_conversion_idx,
clipped_faces_neighbor_idx=clipped_faces_neighbor_idx,
)
return clipped_faces
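
The conversion tensors above are consumed after rasterization. Below is a hedged sketch, not part of this diff, of how per-pixel results on the clipped mesh could be mapped back to the original faces; `pix_to_face` and `bary_clipped` stand in for hypothetical rasterizer outputs, and every pixel is assumed to hit a face:

```
import torch

def to_unclipped(clipped_faces, pix_to_face, bary_clipped):
    # pix_to_face: (P,) indices into clipped_faces.face_verts
    # bary_clipped: (P, 3) barycentric coords w.r.t. the clipped faces
    face_idx_unclipped = clipped_faces.faces_clipped_to_unclipped_idx[pix_to_face]
    conv_idx = clipped_faces.faces_clipped_to_conversion_idx[pix_to_face]
    needs_conv = conv_idx >= 0
    bary_unclipped = bary_clipped.clone()
    # alpha_unclipped[i] = barycentric_conversion[i] @ alpha_clipped[i]
    mats = clipped_faces.barycentric_conversion[conv_idx[needs_conv]]  # (Q, 3, 3)
    bary_unclipped[needs_conv] = torch.bmm(
        mats, bary_clipped[needs_conv].unsqueeze(2)
    ).squeeze(2)
    return face_idx_unclipped, bary_unclipped
```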


@@ -0,0 +1,352 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
Checks for mesh rasterization in the case where the camera enters the
inside of the mesh and some mesh faces are partially
behind the image plane. These faces are clipped and then rasterized.
See pytorch3d/renderer/mesh/clip.py for more details about the
clipping process.
"""
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.mesh import ClipFrustum, clip_faces
from pytorch3d.structures.meshes import Meshes
class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
@staticmethod
def clip_faces(meshes):
verts_packed = meshes.verts_packed()
faces_packed = meshes.faces_packed()
face_verts = verts_packed[faces_packed]
mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
num_faces_per_mesh = meshes.num_faces_per_mesh()
frustum = ClipFrustum(
left=-1,
right=1,
top=-1,
bottom=1,
# In the unit tests for each case below, the triangles are assumed
# to have already been projected onto the image plane.
perspective_correct=False,
z_clip_value=1e-2,
cull=True, # Cull to frustum
)
clipped_faces = clip_faces(
face_verts, mesh_to_face_first_idx, num_faces_per_mesh, frustum
)
return clipped_faces
def test_case_1(self):
"""
Case 1: Single triangle fully in front of the image plane (z=0)
Triangle is not clipped or culled. The triangle is assumed to have
already been projected onto the image plane so no perspective
correction is needed.
"""
device = "cuda:0"
verts = torch.tensor(
[[0.0, 0.0, 1.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]],
dtype=torch.float32,
device=device,
)
faces = torch.tensor(
[
[0, 1, 2],
],
dtype=torch.int64,
device=device,
)
meshes = Meshes(verts=[verts], faces=[faces])
clipped_faces = self.clip_faces(meshes)
self.assertClose(clipped_faces.face_verts, verts[faces])
self.assertEqual(clipped_faces.mesh_to_face_first_idx.item(), 0)
self.assertEqual(clipped_faces.num_faces_per_mesh.item(), 1)
self.assertIsNone(clipped_faces.faces_clipped_to_unclipped_idx)
self.assertIsNone(clipped_faces.faces_clipped_to_conversion_idx)
self.assertIsNone(clipped_faces.clipped_faces_neighbor_idx)
self.assertIsNone(clipped_faces.barycentric_conversion)
def test_case_2(self):
"""
Case 2 triangles are fully behind the image plane (z=0) so are completely culled.
Test with a single triangle behind the image plane.
"""
device = "cuda:0"
verts = torch.tensor(
[[-1.0, 0.0, -1.0], [0.0, 1.0, -1.0], [1.0, 0.0, -1.0]],
dtype=torch.float32,
device=device,
)
faces = torch.tensor(
[
[0, 1, 2],
],
dtype=torch.int64,
device=device,
)
meshes = Meshes(verts=[verts], faces=[faces])
clipped_faces = self.clip_faces(meshes)
zero_t = torch.zeros(size=(1,), dtype=torch.int64, device=device)
self.assertClose(
clipped_faces.face_verts, torch.empty(device=device, size=(0, 3, 3))
)
self.assertClose(clipped_faces.mesh_to_face_first_idx, zero_t)
self.assertClose(clipped_faces.num_faces_per_mesh, zero_t)
self.assertClose(
clipped_faces.faces_clipped_to_unclipped_idx,
torch.empty(device=device, dtype=torch.int64, size=(0,)),
)
self.assertIsNone(clipped_faces.faces_clipped_to_conversion_idx)
self.assertIsNone(clipped_faces.clipped_faces_neighbor_idx)
self.assertIsNone(clipped_faces.barycentric_conversion)
def test_case_3(self):
"""
Case 3 triangles have exactly two vertices behind the clipping plane (z=0) so are
clipped into a smaller triangle.
Test with a single triangle parallel to the z axis which intersects with
the image plane.
"""
device = "cuda:0"
verts = torch.tensor(
[[-1.0, 0.0, -1.0], [0.0, 0.0, 1.0], [1.0, 0.0, -1.0]],
dtype=torch.float32,
device=device,
)
faces = torch.tensor(
[
[0, 1, 2],
],
dtype=torch.int64,
device=device,
)
meshes = Meshes(verts=[verts], faces=[faces])
clipped_faces = self.clip_faces(meshes)
zero_t = torch.zeros(size=(1,), dtype=torch.int64, device=device)
clipped_face_verts = torch.tensor(
[
[
[0.4950, 0.0000, 0.0100],
[-0.4950, 0.0000, 0.0100],
[0.0000, 0.0000, 1.0000],
]
],
device=device,
dtype=torch.float32,
)
# barycentric_conversion[i, :, k] stores the barycentric weights
# in terms of the world coordinates of the original
# (big) triangle for the kth vertex in the clipped (small) triangle.
barycentric_conversion = torch.tensor(
[
[
[0.0000, 0.4950, 0.0000],
[0.5050, 0.5050, 1.0000],
[0.4950, 0.0000, 0.0000],
]
],
device=device,
dtype=torch.float32,
)
self.assertClose(clipped_faces.face_verts, clipped_face_verts)
self.assertEqual(clipped_faces.mesh_to_face_first_idx.item(), 0)
self.assertEqual(clipped_faces.num_faces_per_mesh.item(), 1)
self.assertClose(clipped_faces.faces_clipped_to_unclipped_idx, zero_t)
self.assertClose(clipped_faces.faces_clipped_to_conversion_idx, zero_t)
self.assertClose(
clipped_faces.clipped_faces_neighbor_idx,
zero_t - 1, # default is -1
)
self.assertClose(clipped_faces.barycentric_conversion, barycentric_conversion)
def test_case_4(self):
"""
Case 4 triangles have exactly 1 vertex behind the clipping plane (z=0) so
are clipped into a smaller quadrilateral and then divided into two triangles.
Test with a single triangle parallel to the z axis which intersects with
the image plane.
"""
device = "cuda:0"
verts = torch.tensor(
[[0.0, 0.0, -1.0], [-1.0, 0.0, 1.0], [1.0, 0.0, 1.0]],
dtype=torch.float32,
device=device,
)
faces = torch.tensor(
[
[0, 1, 2],
],
dtype=torch.int64,
device=device,
)
meshes = Meshes(verts=[verts], faces=[faces])
clipped_faces = self.clip_faces(meshes)
clipped_face_verts = torch.tensor(
[
# t1
[
[-0.5050, 0.0000, 0.0100],
[-1.0000, 0.0000, 1.0000],
[0.5050, 0.0000, 0.0100],
],
# t2
[
[0.5050, 0.0000, 0.0100],
[-1.0000, 0.0000, 1.0000],
[1.0000, 0.0000, 1.0000],
],
],
device=device,
dtype=torch.float32,
)
barycentric_conversion = torch.tensor(
[
[
[0.4950, 0.0000, 0.4950],
[0.5050, 1.0000, 0.0000],
[0.0000, 0.0000, 0.5050],
],
[
[0.4950, 0.0000, 0.0000],
[0.0000, 1.0000, 0.0000],
[0.5050, 0.0000, 1.0000],
],
],
device=device,
dtype=torch.float32,
)
self.assertClose(clipped_faces.face_verts, clipped_face_verts)
self.assertEqual(clipped_faces.mesh_to_face_first_idx.item(), 0)
self.assertEqual(
clipped_faces.num_faces_per_mesh.item(), 2
) # now two faces instead of 1
self.assertClose(
clipped_faces.faces_clipped_to_unclipped_idx,
torch.tensor([0, 0], device=device, dtype=torch.int64),
)
# Neighboring face for each of the sub triangles e.g. for t1, neighbor is t2,
# and for t2, neighbor is t1
self.assertClose(
clipped_faces.clipped_faces_neighbor_idx,
torch.tensor([1, 0], device=device, dtype=torch.int64),
)
# barycentric_conversion has one (3, 3) matrix per clipped face needing conversion (here 2)
self.assertEqual(clipped_faces.barycentric_conversion.shape[0], 2)
self.assertClose(clipped_faces.barycentric_conversion, barycentric_conversion)
# Index into barycentric_conversion for each clipped face.
self.assertClose(
clipped_faces.faces_clipped_to_conversion_idx,
torch.tensor([0, 1], device=device, dtype=torch.int64),
)
def test_mixture_of_cases(self):
"""
Test with two meshes composed of different cases to check all the
indexing is correct.
Case 4 faces are subdivided into two faces which are referred
to as t1 and t2.
"""
device = "cuda:0"
# fmt: off
verts = [
torch.tensor(
[
[-1.0, 0.0, -1.0], # noqa: E241, E201
[ 0.0, 1.0, -1.0], # noqa: E241, E201
[ 1.0, 0.0, -1.0], # noqa: E241, E201
[ 0.0, -1.0, -1.0], # noqa: E241, E201
[-1.0, 0.5, 0.5], # noqa: E241, E201
[ 1.0, 1.0, 1.0], # noqa: E241, E201
[ 0.0, -1.0, 1.0], # noqa: E241, E201
[-1.0, 0.5, -0.5], # noqa: E241, E201
[ 1.0, 1.0, -1.0], # noqa: E241, E201
[-1.0, 0.0, 1.0], # noqa: E241, E201
[ 0.0, 1.0, 1.0], # noqa: E241, E201
[ 1.0, 0.0, 1.0], # noqa: E241, E201
],
dtype=torch.float32,
device=device,
),
torch.tensor(
[
[ 0.0, -1.0, -1.0], # noqa: E241, E201
[-1.0, 0.5, 0.5], # noqa: E241, E201
[ 1.0, 1.0, 1.0], # noqa: E241, E201
],
dtype=torch.float32,
device=device
)
]
faces = [
torch.tensor(
[
[0, 1, 2], # noqa: E241, E201 Case 2 fully clipped
[3, 4, 5], # noqa: E241, E201 Case 4 clipped and subdivided
[5, 4, 3], # noqa: E241, E201 Repeat of Case 4
[6, 7, 8], # noqa: E241, E201 Case 3 clipped
[9, 10, 11], # noqa: E241, E201 Case 1 untouched
],
dtype=torch.int64,
device=device,
),
torch.tensor(
[
[0, 1, 2], # noqa: E241, E201 Case 4
],
dtype=torch.int64,
device=device,
),
]
# fmt: on
meshes = Meshes(verts=verts, faces=faces)
# Clip meshes
clipped_faces = self.clip_faces(meshes)
# mesh 1: 4 faces (from the two Case 4 faces) + 1 (from Case 3) + 1 (from Case 1)
# mesh 2: 2 faces (from the one Case 4 face)
self.assertEqual(clipped_faces.face_verts.shape[0], 6 + 2)
# dummy idx type tensor to avoid having to initialize the dtype/device each time
idx = torch.empty(size=(1,), dtype=torch.int64, device=device)
unclipped_idx = idx.new_tensor([1, 1, 2, 2, 3, 4, 5, 5])
neighbors = idx.new_tensor([1, 0, 3, 2, -1, -1, 7, 6])
first_idx = idx.new_tensor([0, 6])
num_faces = idx.new_tensor([6, 2])
self.assertClose(clipped_faces.clipped_faces_neighbor_idx, neighbors)
self.assertClose(clipped_faces.faces_clipped_to_unclipped_idx, unclipped_idx)
self.assertClose(clipped_faces.mesh_to_face_first_idx, first_idx)
self.assertClose(clipped_faces.num_faces_per_mesh, num_faces)
# faces_clipped_to_conversion_idx maps each output face to the
# corresponding row of the barycentric_conversion matrix.
# The barycentric_conversion matrix is composed by
# finding the barycentric conversion weights for case 3 faces
# case 4 (t1) faces and case 4 (t2) faces. These are then
# concatenated. Therefore case 3 faces will be the first rows of
# the barycentric_conversion matrix followed by t1 and then t2.
# Case type of all faces: [4 (t1), 4 (t2), 4 (t1), 4 (t2), 3, 1, 4 (t1), 4 (t2)]
# Based on this information we can calculate the indices into the
# barycentric conversion matrix.
bary_idx = idx.new_tensor([1, 4, 2, 5, 0, -1, 3, 6])
self.assertClose(clipped_faces.faces_clipped_to_conversion_idx, bary_idx)