mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Utils for clipping mesh faces partially behind the image plane
Summary: Instead of culling faces behind the camera, partially clip them if they intersect with the image plane. This diff implements the utils functions for clipping. There are 4 cases for the mesh faces which are all handled: ``` Case 1: the triangle is completely in front of the clipping plane (it is left unchanged) Case 2: the triangle is completely behind the clipping plane (it is culled) Case 3: the triangle has exactly two vertices behind the clipping plane (it is clipped into a smaller triangle) Case 4: the triangle has exactly one vertex behind the clipping plane (it is clipped into a smaller quadrilateral and divided into two triangular faces) ``` Reviewed By: jcjohnson Differential Revision: D23108673 fbshipit-source-id: 550a8b6a982d06065dff10aba10d47e8b144ae52
This commit is contained in:
parent
db6fbfad90
commit
23279c5f1d
@ -1,6 +1,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from .clip import ClipFrustum, ClippedFaces, clip_faces
|
||||
from .rasterize_meshes import rasterize_meshes
|
||||
from .rasterizer import MeshRasterizer, RasterizationSettings
|
||||
from .renderer import MeshRenderer
|
||||
|
600
pytorch3d/renderer/mesh/clip.py
Normal file
600
pytorch3d/renderer/mesh/clip.py
Normal file
@ -0,0 +1,600 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
"""
|
||||
Mesh clipping is done before rasterization and is implemented using 4 cases
|
||||
(these will be referred to throughout the functions below)
|
||||
|
||||
Case 1: the triangle is completely in front of the clipping plane (it is left
|
||||
unchanged)
|
||||
Case 2: the triangle is completely behind the clipping plane (it is culled)
|
||||
Case 3: the triangle has exactly two vertices behind the clipping plane (it is
|
||||
clipped into a smaller triangle)
|
||||
Case 4: the triangle has exactly one vertex behind the clipping plane (it is clipped
|
||||
into a smaller quadrilateral and divided into two triangular faces)
|
||||
|
||||
After rasterization, the Fragments from the clipped/modified triangles
|
||||
are mapped back to the triangles in the original mesh. The indices,
|
||||
barycentric coordinates and distances are all relative to original mesh triangles.
|
||||
|
||||
NOTE: It is assumed that all z-coordinates are in world coordinates (not NDC
|
||||
coordinates), while x and y coordinates may be in NDC/screen coordinates
|
||||
(i.e after applying a projective transform e.g. cameras.transform_points(points)).
|
||||
"""
|
||||
|
||||
|
||||
class ClippedFaces:
|
||||
"""
|
||||
Helper class to store the data for the clipped version of a Meshes object
|
||||
(face_verts, mesh_to_face_first_idx, num_faces_per_mesh) along with
|
||||
conversion information (faces_clipped_to_unclipped_idx, barycentric_conversion,
|
||||
faces_clipped_to_conversion_idx, clipped_faces_neighbor_idx) required to convert
|
||||
barycentric coordinates from rasterization of the clipped Meshes to barycentric
|
||||
coordinates in terms of the unclipped Meshes.
|
||||
|
||||
Args:
|
||||
face_verts: FloatTensor of shape (F_clipped, 3, 3) giving the verts of
|
||||
each of the clipped faces
|
||||
mesh_to_face_first_idx: an tensor of shape (N,), where N is the number of meshes
|
||||
in the batch. The ith element stores the index into face_verts
|
||||
of the first face of the ith mesh.
|
||||
num_faces_per_mesh: a tensor of shape (N,) storing the number of faces in each mesh.
|
||||
faces_clipped_to_unclipped_idx: (F_clipped,) shaped LongTensor mapping each clipped
|
||||
face back to the face in faces_unclipped (i.e. the faces in the original meshes
|
||||
obtained using meshes.faces_packed())
|
||||
barycentric_conversion: (T, 3, 3) FloatTensor, where barycentric_conversion[i, :, k]
|
||||
stores the barycentric weights in terms of the world coordinates of the original
|
||||
(big) unclipped triangle for the kth vertex in the clipped (small) triangle.
|
||||
If the rasterizer then expresses some NDC coordinate in terms of barycentric
|
||||
world coordinates for the clipped (small) triangle as alpha_clipped[i,:],
|
||||
alpha_unclipped[i, :] = barycentric_conversion[i, :, :]*alpha_clipped[i, :]
|
||||
faces_clipped_to_conversion_idx: (F_clipped,) shaped LongTensor mapping each clipped
|
||||
face to the applicable row of barycentric_conversion (or set to -1 if conversion is
|
||||
not needed).
|
||||
clipped_faces_neighbor_idx: LongTensor of shape (F_clipped,) giving the index of the
|
||||
neighboring face for each case 4 triangle. e.g. for a case 4 face with f split
|
||||
into two triangles (t1, t2): clipped_faces_neighbor_idx[t1_idx] = t2_idx.
|
||||
Faces which are not clipped and subdivided are set to -1 (i.e cases 1/2/3).
|
||||
"""
|
||||
|
||||
__slots__ = [
|
||||
"face_verts",
|
||||
"mesh_to_face_first_idx",
|
||||
"num_faces_per_mesh",
|
||||
"faces_clipped_to_unclipped_idx",
|
||||
"barycentric_conversion",
|
||||
"faces_clipped_to_conversion_idx",
|
||||
"clipped_faces_neighbor_idx",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
face_verts: torch.Tensor,
|
||||
mesh_to_face_first_idx: torch.Tensor,
|
||||
num_faces_per_mesh: torch.Tensor,
|
||||
faces_clipped_to_unclipped_idx: Optional[torch.Tensor] = None,
|
||||
barycentric_conversion: Optional[torch.Tensor] = None,
|
||||
faces_clipped_to_conversion_idx: Optional[torch.Tensor] = None,
|
||||
clipped_faces_neighbor_idx: Optional[torch.Tensor] = None,
|
||||
):
|
||||
self.face_verts = face_verts
|
||||
self.mesh_to_face_first_idx = mesh_to_face_first_idx
|
||||
self.num_faces_per_mesh = num_faces_per_mesh
|
||||
self.faces_clipped_to_unclipped_idx = faces_clipped_to_unclipped_idx
|
||||
self.barycentric_conversion = barycentric_conversion
|
||||
self.faces_clipped_to_conversion_idx = faces_clipped_to_conversion_idx
|
||||
self.clipped_faces_neighbor_idx = clipped_faces_neighbor_idx
|
||||
|
||||
|
||||
class ClipFrustum:
|
||||
"""
|
||||
Helper class to store the information needed to represent a view frustum
|
||||
(left, right, top, bottom, znear, zfar), which is used to clip or cull triangles.
|
||||
Values left as None mean that culling should not be performed for that axis.
|
||||
The parameters perspective_correct, cull, and z_clip_value are used to define
|
||||
behavior for clipping triangles to the frustum.
|
||||
|
||||
Args:
|
||||
left: NDC coordinate of the left clipping plane (along x axis)
|
||||
right: NDC coordinate of the right clipping plane (along x axis)
|
||||
top: NDC coordinate of the top clipping plane (along y axis)
|
||||
bottom: NDC coordinate of the bottom clipping plane (along y axis)
|
||||
znear: world space z coordinate of the near clipping plane
|
||||
zfar: world space z coordinate of the far clipping plane
|
||||
perspective_correct: should be set to True for a perspective camera
|
||||
cull: if True, triangles outside the frustum should be culled
|
||||
z_clip_value: if not None, then triangles should be clipped (possibly into
|
||||
smaller triangles) such that z >= z_clip_value. This avoids projections
|
||||
that go to infinity as z->0
|
||||
"""
|
||||
|
||||
__slots__ = [
|
||||
"left",
|
||||
"right",
|
||||
"top",
|
||||
"bottom",
|
||||
"znear",
|
||||
"zfar",
|
||||
"perspective_correct",
|
||||
"cull",
|
||||
"z_clip_value",
|
||||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
left: Optional[float] = None,
|
||||
right: Optional[float] = None,
|
||||
top: Optional[float] = None,
|
||||
bottom: Optional[float] = None,
|
||||
znear: Optional[float] = None,
|
||||
zfar: Optional[float] = None,
|
||||
perspective_correct: bool = False,
|
||||
cull: bool = True,
|
||||
z_clip_value: Optional[float] = None,
|
||||
):
|
||||
self.left = left
|
||||
self.right = right
|
||||
self.top = top
|
||||
self.bottom = bottom
|
||||
self.znear = znear
|
||||
self.zfar = zfar
|
||||
self.perspective_correct = perspective_correct
|
||||
self.cull = cull
|
||||
self.z_clip_value = z_clip_value
|
||||
|
||||
|
||||
def _get_culled_faces(face_verts: torch.Tensor, frustum: ClipFrustum) -> torch.Tensor:
|
||||
"""
|
||||
Helper function used to find all the faces in Meshes which are
|
||||
fully outside the view frustum. A face is culled if all 3 vertices are outside
|
||||
the same axis of the view frustum.
|
||||
|
||||
Args:
|
||||
face_verts: An (F,3,3) tensor, where F is the number of faces in
|
||||
the packed representation of Meshes. The 2nd dimension represents the 3 vertices
|
||||
of a triangle, and the 3rd dimension stores the xyz locations of each
|
||||
vertex.
|
||||
frustum: An instance of the ClipFrustum class with the information on the
|
||||
position of the clipping planes.
|
||||
|
||||
Returns:
|
||||
faces_culled: An boolean tensor of size F specifying whether or not each face should be
|
||||
culled.
|
||||
"""
|
||||
clipping_planes = (
|
||||
(frustum.left, 0, "<"),
|
||||
(frustum.right, 0, ">"),
|
||||
(frustum.top, 1, "<"),
|
||||
(frustum.bottom, 1, ">"),
|
||||
(frustum.znear, 2, "<"),
|
||||
(frustum.zfar, 2, ">"),
|
||||
)
|
||||
faces_culled = torch.zeros(
|
||||
[face_verts.shape[0]], dtype=torch.bool, device=face_verts.device
|
||||
)
|
||||
for plane in clipping_planes:
|
||||
clip_value, axis, op = plane
|
||||
# If clip_value is None then don't clip along that plane
|
||||
if frustum.cull and clip_value is not None:
|
||||
if op == "<":
|
||||
verts_clipped = face_verts[:, axis] < clip_value
|
||||
else:
|
||||
verts_clipped = face_verts[:, axis] > clip_value
|
||||
|
||||
# If all verts are clipped then face is outside the frustum
|
||||
faces_culled |= verts_clipped.sum(1) == 3
|
||||
|
||||
return faces_culled
|
||||
|
||||
|
||||
def _find_verts_intersecting_clipping_plane(
|
||||
face_verts: torch.Tensor,
|
||||
p1_face_ind: torch.Tensor,
|
||||
clip_value: float,
|
||||
perspective_correct: bool,
|
||||
) -> Tuple[Tuple[Any, Any, Any, Any, Any], List[Any]]:
|
||||
r"""
|
||||
Helper function to find the vertices used to form a new triangle for case 3/case 4 faces.
|
||||
|
||||
Given a list of triangles that are already known to intersect the clipping plane,
|
||||
solve for the two vertices p4 and p5 where the edges of the triangle intersects the
|
||||
clipping plane.
|
||||
|
||||
p1
|
||||
/\
|
||||
/ \
|
||||
/ t \
|
||||
_____________p4/______\p5__________ clip_value
|
||||
/ \
|
||||
/____ \
|
||||
p2 ---____\p3
|
||||
|
||||
Args:
|
||||
face_verts: An (F,3,3) tensor, where F is the number of faces in
|
||||
the packed representation of the Meshes, the 2nd dimension represents
|
||||
the 3 vertices of the face, and the 3rd dimension stores the xyz locations of each
|
||||
vertex. The z-coordinates must be represented in world coordinates, while
|
||||
the xy-coordinates may be in NDC/screen coordinates (i.e. after projection).
|
||||
p1_face_ind: A tensor of shape (N,) with values in the range of 0 to 2. In each
|
||||
case 3/case 4 triangle, two vertices are on the same side of the
|
||||
clipping plane and the 3rd is on the other side. p1_face_ind stores the index of
|
||||
the vertex that is not on the same side as any other vertex in the triangle.
|
||||
clip_value: Float, the z-value defining where to clip the triangle.
|
||||
perspective_correct: Bool, Should be set to true if a perspective camera was
|
||||
used and xy-coordinates of face_verts_unclipped are in NDC/screen coordinates.
|
||||
|
||||
Returns:
|
||||
A 2-tuple
|
||||
p: (p1, p2, p3, p4, p5))
|
||||
p_barycentric (p1_bary, p2_bary, p3_bary, p4_bary, p5_bary)
|
||||
|
||||
Each of p1...p5 is an (F,3) tensor of the xyz locations of the 5 points in the
|
||||
diagram above for case 3/case 4 faces. Each p1_bary...p5_bary is an (F, 3) tensor
|
||||
storing the barycentric weights used to encode p1...p5 in terms of the the original
|
||||
unclipped triangle.
|
||||
"""
|
||||
|
||||
# Let T be number of triangles in face_verts (note that these correspond to the subset
|
||||
# of case 1 or case 2 triangles). p1_face_ind, p2_face_ind, and p3_face_ind are (T)
|
||||
# tensors with values in the range of 0 to 2. p1_face_ind stores the index of the
|
||||
# vertex that is not on the same side as any other vertex in the triangle, and
|
||||
# p2_face_ind and p3_face_ind are the indices of the other two vertices preserving
|
||||
# the same counterclockwise or clockwise ordering
|
||||
T = face_verts.shape[0]
|
||||
p2_face_ind = torch.remainder(p1_face_ind + 1, 3)
|
||||
p3_face_ind = torch.remainder(p1_face_ind + 2, 3)
|
||||
|
||||
# p1, p2, p3 are (T, 3) tensors storing the corresponding (x, y, z) coordinates
|
||||
# of p1_face_ind, p2_face_ind, p3_face_ind
|
||||
# pyre-ignore[16]
|
||||
p1 = face_verts.gather(1, p1_face_ind[:, None, None].expand(-1, -1, 3)).squeeze(1)
|
||||
p2 = face_verts.gather(1, p2_face_ind[:, None, None].expand(-1, -1, 3)).squeeze(1)
|
||||
p3 = face_verts.gather(1, p3_face_ind[:, None, None].expand(-1, -1, 3)).squeeze(1)
|
||||
|
||||
##################################
|
||||
# Solve for intersection point p4
|
||||
##################################
|
||||
|
||||
# p4 is a (T, 3) tensor is the point on the segment between p1 and p2 that
|
||||
# intersects the clipping plane.
|
||||
# Solve for the weight w2 such that p1.z*(1-w2) + p2.z*w2 = clip_value.
|
||||
# Then interpolate p4 = p1*(1-w2) + p2*w2 where it is assumed that z-coordinates
|
||||
# are expressed in world coordinates (since we want to clip z in world coordinates).
|
||||
w2 = (p1[:, 2] - clip_value) / (p1[:, 2] - p2[:, 2])
|
||||
p4 = p1 * (1 - w2[:, None]) + p2 * w2[:, None]
|
||||
if perspective_correct:
|
||||
# It is assumed that all z-coordinates are in world coordinates (not NDC
|
||||
# coordinates), while x and y coordinates may be in NDC/screen coordinates.
|
||||
# If x and y are in NDC/screen coordinates and a projective transform was used
|
||||
# in a perspective camera, then we effectively want to:
|
||||
# 1. Convert back to world coordinates (by multiplying by z)
|
||||
# 2. Interpolate using w2
|
||||
# 3. Convert back to NDC/screen coordinates (by dividing by the new z=clip_value)
|
||||
p1_world = p1[:, :2] * p1[:, 2:3]
|
||||
p2_world = p2[:, :2] * p2[:, 2:3]
|
||||
p4[:, :2] = (p1_world * (1 - w2[:, None]) + p2_world * w2[:, None]) / clip_value
|
||||
|
||||
##################################
|
||||
# Solve for intersection point p5
|
||||
##################################
|
||||
|
||||
# p5 is a (T, 3) tensor representing the point on the segment between p1 and p3 that
|
||||
# intersects the clipping plane.
|
||||
# Solve for the weight w3 such that p1.z * (1-w3) + p2.z * w3 = clip_value,
|
||||
# and then interpolate p5 = p1 * (1-w3) + p3 * w3
|
||||
w3 = (p1[:, 2] - clip_value) / (p1[:, 2] - p3[:, 2])
|
||||
w3 = w3.detach()
|
||||
p5 = p1 * (1 - w3[:, None]) + p3 * w3[:, None]
|
||||
if perspective_correct:
|
||||
# Again if using a perspective camera, convert back to world coordinates
|
||||
# interpolate and convert back
|
||||
p1_world = p1[:, :2] * p1[:, 2:3]
|
||||
p3_world = p3[:, :2] * p3[:, 2:3]
|
||||
p5[:, :2] = (p1_world * (1 - w3[:, None]) + p3_world * w3[:, None]) / clip_value
|
||||
|
||||
# Set the barycentric coordinates of p1,p2,p3,p4,p5 in terms of the original
|
||||
# unclipped triangle in face_verts.
|
||||
T_idx = torch.arange(T, device=face_verts.device)
|
||||
p_barycentric = [torch.zeros((T, 3), device=face_verts.device) for i in range(5)]
|
||||
p_barycentric[0][(T_idx, p1_face_ind)] = 1
|
||||
p_barycentric[1][(T_idx, p2_face_ind)] = 1
|
||||
p_barycentric[2][(T_idx, p3_face_ind)] = 1
|
||||
p_barycentric[3][(T_idx, p1_face_ind)] = 1 - w2
|
||||
p_barycentric[3][(T_idx, p2_face_ind)] = w2
|
||||
p_barycentric[4][(T_idx, p1_face_ind)] = 1 - w3
|
||||
p_barycentric[4][(T_idx, p3_face_ind)] = w3
|
||||
|
||||
p = (p1, p2, p3, p4, p5)
|
||||
|
||||
return p, p_barycentric
|
||||
|
||||
|
||||
###################
|
||||
# Main Entry point
|
||||
###################
|
||||
def clip_faces(
|
||||
face_verts_unclipped: torch.Tensor,
|
||||
mesh_to_face_first_idx: torch.Tensor,
|
||||
num_faces_per_mesh: torch.Tensor,
|
||||
frustum: ClipFrustum,
|
||||
) -> ClippedFaces:
|
||||
"""
|
||||
Clip a mesh to the portion contained within a view frustum and with z > z_clip_value.
|
||||
|
||||
There are two types of clipping:
|
||||
1) Cull triangles that are completely outside the view frustum. This is purely
|
||||
to save computation by reducing the number of triangles that need to be
|
||||
rasterized.
|
||||
2) Clip triangles into the portion of the triangle where z > z_clip_value. The
|
||||
clipped region may be a quadrilateral, which results in splitting a triangle
|
||||
into two triangles. This does not save computation, but is necessary to
|
||||
correctly rasterize using perspective cameras for triangles that pass through
|
||||
z <= 0, because NDC/screen coordinates go to infinity at z=0.
|
||||
|
||||
Args:
|
||||
face_verts_unclipped: An (F, 3, 3) tensor, where F is the number of faces in
|
||||
the packed representation of Meshes, the 2nd dimension represents the 3 vertices
|
||||
of the triangle, and the 3rd dimension stores the xyz locations of each
|
||||
vertex. The z-coordinates must be represented in world coordinates, while
|
||||
the xy-coordinates may be in NDC/screen coordinates
|
||||
mesh_to_face_first_idx: an tensor of shape (N,), where N is the number of meshes
|
||||
in the batch. The ith element stores the index into face_verts_unclipped
|
||||
of the first face of the ith mesh.
|
||||
num_faces_per_mesh: a tensor of shape (N,) storing the number of faces in each mesh.
|
||||
frustum: a ClipFrustum object defining the frustum used to cull faces.
|
||||
|
||||
Returns:
|
||||
clipped_faces: ClippedFaces object storing a clipped version of the Meshes
|
||||
along with tensors that can be used to convert barycentric coordinates
|
||||
returned by rasterization of the clipped meshes into a barycentric
|
||||
coordinates for the unclipped meshes.
|
||||
"""
|
||||
F = face_verts_unclipped.shape[0]
|
||||
device = face_verts_unclipped.device
|
||||
|
||||
# Triangles completely outside the view frustum will be culled
|
||||
# faces_culled is of shape (F, )
|
||||
faces_culled = _get_culled_faces(face_verts_unclipped, frustum)
|
||||
|
||||
# Triangles that are partially behind the z clipping plane will be clipped to
|
||||
# smaller triangles
|
||||
z_clip_value = frustum.z_clip_value
|
||||
perspective_correct = frustum.perspective_correct
|
||||
if z_clip_value is not None:
|
||||
# (F, 3) tensor (where F is the number of triangles) marking whether each vertex
|
||||
# in a triangle is behind the clipping plane
|
||||
faces_clipped_verts = face_verts_unclipped[:, :, 2] < z_clip_value
|
||||
|
||||
# (F) dim tensor containing the number of clipped vertices in each triangle
|
||||
faces_num_clipped_verts = faces_clipped_verts.sum(1)
|
||||
else:
|
||||
faces_num_clipped_verts = torch.zeros([F, 3], device=device)
|
||||
|
||||
# If no triangles need to be clipped or culled, avoid unnecessary computation
|
||||
# and return early
|
||||
if faces_num_clipped_verts.sum().item() == 0 and faces_culled.sum().item() == 0:
|
||||
return ClippedFaces(
|
||||
face_verts=face_verts_unclipped,
|
||||
mesh_to_face_first_idx=mesh_to_face_first_idx,
|
||||
num_faces_per_mesh=num_faces_per_mesh,
|
||||
)
|
||||
|
||||
#####################################################################################
|
||||
# Classify faces into the 4 relevant cases:
|
||||
# 1) The triangle is completely in front of the clipping plane (it is left
|
||||
# unchanged)
|
||||
# 2) The triangle is completely behind the clipping plane (it is culled)
|
||||
# 3) The triangle has exactly two vertices behind the clipping plane (it is
|
||||
# clipped into a smaller triangle)
|
||||
# 4) The triangle has exactly one vertex behind the clipping plane (it is clipped
|
||||
# into a smaller quadrilateral and split into two triangles)
|
||||
#####################################################################################
|
||||
|
||||
# pyre-ignore[16]:
|
||||
faces_unculled = ~faces_culled
|
||||
# Case 1: no clipped verts or culled faces
|
||||
cases1_unclipped = torch.logical_and(faces_num_clipped_verts == 0, faces_unculled)
|
||||
case1_unclipped_idx = cases1_unclipped.nonzero(as_tuple=True)[0]
|
||||
# Case 2: all verts clipped
|
||||
case2_unclipped = torch.logical_or(faces_num_clipped_verts == 3, faces_culled)
|
||||
# Case 3: two verts clipped
|
||||
case3_unclipped = torch.logical_and(faces_num_clipped_verts == 2, faces_unculled)
|
||||
case3_unclipped_idx = case3_unclipped.nonzero(as_tuple=True)[0]
|
||||
# Case 4: one vert clipped
|
||||
case4_unclipped = torch.logical_and(faces_num_clipped_verts == 1, faces_unculled)
|
||||
case4_unclipped_idx = case4_unclipped.nonzero(as_tuple=True)[0]
|
||||
|
||||
# faces_unclipped_to_clipped_idx is an (F) dim tensor storing the index of each
|
||||
# face to the corresponding face in face_verts_clipped.
|
||||
# Each case 2 triangle will be culled (deleted from face_verts_clipped),
|
||||
# while each case 4 triangle will be split into two smaller triangles
|
||||
# (replaced by two consecutive triangles in face_verts_clipped)
|
||||
|
||||
# case2_unclipped is an (F,) dim 0/1 tensor of all the case2 faces
|
||||
# case4_unclipped is an (F,) dim 0/1 tensor of all the case4 faces
|
||||
faces_delta = case4_unclipped.int() - case2_unclipped.int()
|
||||
# faces_delta_cum gives the per face change in index. Faces which are
|
||||
# clipped in the original mesh are mapped to the closest non clipped face
|
||||
# in face_verts_clipped (this doesn't matter as they are not used
|
||||
# during rasterization anyway).
|
||||
faces_delta_cum = faces_delta.cumsum(0) - faces_delta
|
||||
delta = 1 + case4_unclipped.int() - case2_unclipped.int()
|
||||
# pyre-ignore[16]
|
||||
faces_unclipped_to_clipped_idx = delta.cumsum(0) - delta
|
||||
|
||||
###########################################
|
||||
# Allocate tensors for the output Meshes.
|
||||
# These will then be filled in for each case.
|
||||
###########################################
|
||||
F_clipped = (
|
||||
F + faces_delta_cum[-1].item() + faces_delta[-1].item()
|
||||
) # Total number of faces in the new Meshes
|
||||
face_verts_clipped = torch.zeros(
|
||||
(F_clipped, 3, 3), dtype=face_verts_unclipped.dtype, device=device
|
||||
)
|
||||
faces_clipped_to_unclipped_idx = torch.zeros(
|
||||
[F_clipped], dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
# Update version of mesh_to_face_first_idx and num_faces_per_mesh applicable to
|
||||
# face_verts_clipped
|
||||
mesh_to_face_first_idx_clipped = faces_unclipped_to_clipped_idx[
|
||||
mesh_to_face_first_idx
|
||||
]
|
||||
F_clipped_t = torch.full([1], F_clipped, dtype=torch.int64, device=device)
|
||||
num_faces_next = torch.cat((mesh_to_face_first_idx_clipped[1:], F_clipped_t))
|
||||
num_faces_per_mesh_clipped = num_faces_next - mesh_to_face_first_idx_clipped
|
||||
|
||||
################# Start Case 1 ########################################
|
||||
|
||||
# Case 1: Triangles are fully visible, copy unchanged triangles into the
|
||||
# appropriate position in the new list of faces
|
||||
case1_clipped_idx = faces_unclipped_to_clipped_idx[case1_unclipped_idx]
|
||||
face_verts_clipped[case1_clipped_idx] = face_verts_unclipped[case1_unclipped_idx]
|
||||
faces_clipped_to_unclipped_idx[case1_clipped_idx] = case1_unclipped_idx
|
||||
|
||||
# If no triangles need to be clipped but some triangles were culled, avoid
|
||||
# unnecessary clipping computation
|
||||
if case3_unclipped_idx.shape[0] + case4_unclipped_idx.shape[0] == 0:
|
||||
return ClippedFaces(
|
||||
face_verts=face_verts_clipped,
|
||||
mesh_to_face_first_idx=mesh_to_face_first_idx_clipped,
|
||||
num_faces_per_mesh=num_faces_per_mesh_clipped,
|
||||
faces_clipped_to_unclipped_idx=faces_clipped_to_unclipped_idx,
|
||||
)
|
||||
|
||||
################# End Case 1 ##########################################
|
||||
|
||||
################# Start Case 3 ########################################
|
||||
|
||||
# Case 3: exactly two vertices are behind the camera, clipping the triangle into a
|
||||
# triangle. In the diagram below, we clip the bottom part of the triangle, and add
|
||||
# new vertices p4 and p5 by intersecting with the clipping plane. The updated
|
||||
# triangle is the triangle between p4, p1, p5
|
||||
#
|
||||
# p1 (unclipped vertex)
|
||||
# /\
|
||||
# / \
|
||||
# / t \
|
||||
# _____________p4/______\p5__________ clip_value
|
||||
# xxxxxxxxxxxxxx/ \xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
# xxxxxxxxxxxxx/____ \xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
# xxxxxxxxxx p2 xxxx---____\p3 xxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
# xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
|
||||
faces_case3 = face_verts_unclipped[case3_unclipped_idx]
|
||||
|
||||
# index (0, 1, or 2) of the vertex in front of the clipping plane
|
||||
p1_face_ind = torch.where(~faces_clipped_verts[case3_unclipped_idx])[1]
|
||||
|
||||
# Solve for the points p4, p5 that intersect the clipping plane
|
||||
p, p_barycentric = _find_verts_intersecting_clipping_plane(
|
||||
faces_case3, p1_face_ind, z_clip_value, perspective_correct
|
||||
)
|
||||
|
||||
p1, _, _, p4, p5 = p
|
||||
p1_barycentric, _, _, p4_barycentric, p5_barycentric = p_barycentric
|
||||
|
||||
# Store clipped triangle
|
||||
case3_clipped_idx = faces_unclipped_to_clipped_idx[case3_unclipped_idx]
|
||||
t_barycentric = torch.stack((p4_barycentric, p5_barycentric, p1_barycentric), 2)
|
||||
face_verts_clipped[case3_clipped_idx] = torch.stack((p4, p5, p1), 1)
|
||||
faces_clipped_to_unclipped_idx[case3_clipped_idx] = case3_unclipped_idx
|
||||
|
||||
################# End Case 3 ##########################################
|
||||
|
||||
################# Start Case 4 ########################################
|
||||
|
||||
# Case 4: exactly one vertex is behind the camera, clip the triangle into a
|
||||
# quadrilateral. In the diagram below, we clip the bottom part of the triangle,
|
||||
# and add new vertices p4 and p5 by intersecting with the cliiping plane. The
|
||||
# unclipped region is a quadrilateral, which is split into two triangles:
|
||||
# t1: p4, p2, p5
|
||||
# t2: p5, p2, p3
|
||||
#
|
||||
# p3_____________________p2
|
||||
# \ __--/
|
||||
# \ t2 __-- /
|
||||
# \ __-- t1 /
|
||||
# ______________p5\__--_________/p4_________clip_value
|
||||
# xxxxxxxxxxxxxxxxx\ /xxxxxxxxxxxxxxxxxx
|
||||
# xxxxxxxxxxxxxxxxxx\ /xxxxxxxxxxxxxxxxxxx
|
||||
# xxxxxxxxxxxxxxxxxxx\ /xxxxxxxxxxxxxxxxxxxx
|
||||
# xxxxxxxxxxxxxxxxxxxx\ /xxxxxxxxxxxxxxxxxxxxx
|
||||
# xxxxxxxxxxxxxxxxxxxxx\ /xxxxxxxxxxxxxxxxxxxxx
|
||||
# xxxxxxxxxxxxxxxxxxxxxx\ /xxxxxxxxxxxxxxxxxxxxx
|
||||
# p1 (clipped vertex)
|
||||
|
||||
faces_case4 = face_verts_unclipped[case4_unclipped_idx]
|
||||
|
||||
# index (0, 1, or 2) of the vertex behind the clipping plane
|
||||
p1_face_ind = torch.where(faces_clipped_verts[case4_unclipped_idx])[1]
|
||||
|
||||
# Solve for the points p4, p5 that intersect the clipping plane
|
||||
p, p_barycentric = _find_verts_intersecting_clipping_plane(
|
||||
faces_case4, p1_face_ind, z_clip_value, perspective_correct
|
||||
)
|
||||
_, p2, p3, p4, p5 = p
|
||||
_, p2_barycentric, p3_barycentric, p4_barycentric, p5_barycentric = p_barycentric
|
||||
|
||||
# Store clipped triangles
|
||||
case4_clipped_idx = faces_unclipped_to_clipped_idx[case4_unclipped_idx]
|
||||
face_verts_clipped[case4_clipped_idx] = torch.stack((p4, p2, p5), 1)
|
||||
face_verts_clipped[case4_clipped_idx + 1] = torch.stack((p5, p2, p3), 1)
|
||||
t1_barycentric = torch.stack((p4_barycentric, p2_barycentric, p5_barycentric), 2)
|
||||
t2_barycentric = torch.stack((p5_barycentric, p2_barycentric, p3_barycentric), 2)
|
||||
faces_clipped_to_unclipped_idx[case4_clipped_idx] = case4_unclipped_idx
|
||||
faces_clipped_to_unclipped_idx[case4_clipped_idx + 1] = case4_unclipped_idx
|
||||
|
||||
##################### End Case 4 #########################
|
||||
|
||||
# Triangles that were clipped (case 3 & case 4) will require conversion of
|
||||
# barycentric coordinates from being in terms of the smaller clipped triangle to in terms
|
||||
# of the original big triangle. If there are T clipped triangles,
|
||||
# barycentric_conversion is a (T, 3, 3) tensor, where barycentric_conversion[i, :, k]
|
||||
# stores the barycentric weights in terms of the world coordinates of the original
|
||||
# (big) triangle for the kth vertex in the clipped (small) triangle. If our
|
||||
# rasterizer then expresses some NDC coordinate in terms of barycentric
|
||||
# world coordinates for the clipped (small) triangle as alpha_clipped[i,:],
|
||||
# alpha_unclipped[i, :] = barycentric_conversion[i, :, :]*alpha_clipped[i, :]
|
||||
barycentric_conversion = torch.cat((t_barycentric, t1_barycentric, t2_barycentric))
|
||||
|
||||
# faces_clipped_to_conversion_idx is an (F_clipped,) shape tensor mapping each output
|
||||
# face to the applicable row of barycentric_conversion (or set to -1 if conversion is
|
||||
# not needed)
|
||||
faces_to_convert_idx = torch.cat(
|
||||
(case3_clipped_idx, case4_clipped_idx, case4_clipped_idx + 1), 0
|
||||
)
|
||||
barycentric_idx = torch.arange(
|
||||
barycentric_conversion.shape[0], dtype=torch.int64, device=device
|
||||
)
|
||||
faces_clipped_to_conversion_idx = torch.full(
|
||||
[F_clipped], -1, dtype=torch.int64, device=device
|
||||
)
|
||||
faces_clipped_to_conversion_idx[faces_to_convert_idx] = barycentric_idx
|
||||
|
||||
# clipped_faces_quadrilateral_ind is an (F_clipped) dim tensor
|
||||
# For case 4 clipped triangles (where a big triangle is split in two smaller triangles),
|
||||
# store the index of the neighboring clipped triangle.
|
||||
# This will be needed because if the soft rasterizer includes both
|
||||
# triangles in the list of top K nearest triangles, we
|
||||
# should only use the one with the smaller distance.
|
||||
clipped_faces_neighbor_idx = torch.full(
|
||||
[F_clipped], -1, dtype=torch.int64, device=device
|
||||
)
|
||||
clipped_faces_neighbor_idx[case4_clipped_idx] = case4_clipped_idx + 1
|
||||
clipped_faces_neighbor_idx[case4_clipped_idx + 1] = case4_clipped_idx
|
||||
|
||||
clipped_faces = ClippedFaces(
|
||||
face_verts=face_verts_clipped,
|
||||
mesh_to_face_first_idx=mesh_to_face_first_idx_clipped,
|
||||
num_faces_per_mesh=num_faces_per_mesh_clipped,
|
||||
faces_clipped_to_unclipped_idx=faces_clipped_to_unclipped_idx,
|
||||
barycentric_conversion=barycentric_conversion,
|
||||
faces_clipped_to_conversion_idx=faces_clipped_to_conversion_idx,
|
||||
clipped_faces_neighbor_idx=clipped_faces_neighbor_idx,
|
||||
)
|
||||
return clipped_faces
|
352
tests/test_render_meshes_clipped.py
Normal file
352
tests/test_render_meshes_clipped.py
Normal file
@ -0,0 +1,352 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
"""
|
||||
Checks for mesh rasterization in the case where the camera enters the
|
||||
inside of the mesh and some mesh faces are partially
|
||||
behind the image plane. These faces are clipped and then rasterized.
|
||||
See pytorch3d/renderer/mesh/clip.py for more details about the
|
||||
clipping process.
|
||||
"""
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.renderer.mesh import ClipFrustum, clip_faces
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
|
||||
|
||||
class TestRenderMeshesClipping(TestCaseMixin, unittest.TestCase):
|
||||
@staticmethod
|
||||
def clip_faces(meshes):
|
||||
verts_packed = meshes.verts_packed()
|
||||
faces_packed = meshes.faces_packed()
|
||||
face_verts = verts_packed[faces_packed]
|
||||
mesh_to_face_first_idx = meshes.mesh_to_faces_packed_first_idx()
|
||||
num_faces_per_mesh = meshes.num_faces_per_mesh()
|
||||
|
||||
frustum = ClipFrustum(
|
||||
left=-1,
|
||||
right=1,
|
||||
top=-1,
|
||||
bottom=1,
|
||||
# In the unit tests for each case below the triangles are asummed
|
||||
# to have already been projected onto the image plane.
|
||||
perspective_correct=False,
|
||||
z_clip_value=1e-2,
|
||||
cull=True, # Cull to frustrum
|
||||
)
|
||||
|
||||
clipped_faces = clip_faces(
|
||||
face_verts, mesh_to_face_first_idx, num_faces_per_mesh, frustum
|
||||
)
|
||||
return clipped_faces
|
||||
|
||||
def test_case_1(self):
|
||||
"""
|
||||
Case 1: Single triangle fully in front of the image plane (z=0)
|
||||
Triangle is not clipped or culled. The triangle is asummed to have
|
||||
already been projected onto the image plane so no perspective
|
||||
correction is needed.
|
||||
"""
|
||||
device = "cuda:0"
|
||||
verts = torch.tensor(
|
||||
[[0.0, 0.0, 1.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
faces = torch.tensor(
|
||||
[
|
||||
[0, 1, 2],
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
meshes = Meshes(verts=[verts], faces=[faces])
|
||||
clipped_faces = self.clip_faces(meshes)
|
||||
|
||||
self.assertClose(clipped_faces.face_verts, verts[faces])
|
||||
self.assertEqual(clipped_faces.mesh_to_face_first_idx.item(), 0)
|
||||
self.assertEqual(clipped_faces.num_faces_per_mesh.item(), 1)
|
||||
self.assertIsNone(clipped_faces.faces_clipped_to_unclipped_idx)
|
||||
self.assertIsNone(clipped_faces.faces_clipped_to_conversion_idx)
|
||||
self.assertIsNone(clipped_faces.clipped_faces_neighbor_idx)
|
||||
self.assertIsNone(clipped_faces.barycentric_conversion)
|
||||
|
||||
def test_case_2(self):
|
||||
"""
|
||||
Case 2 triangles are fully behind the image plane (z=0) so are completely culled.
|
||||
Test with a single triangle behind the image plane.
|
||||
"""
|
||||
|
||||
device = "cuda:0"
|
||||
verts = torch.tensor(
|
||||
[[-1.0, 0.0, -1.0], [0.0, 1.0, -1.0], [1.0, 0.0, -1.0]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
faces = torch.tensor(
|
||||
[
|
||||
[0, 1, 2],
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
meshes = Meshes(verts=[verts], faces=[faces])
|
||||
clipped_faces = self.clip_faces(meshes)
|
||||
|
||||
zero_t = torch.zeros(size=(1,), dtype=torch.int64, device=device)
|
||||
self.assertClose(
|
||||
clipped_faces.face_verts, torch.empty(device=device, size=(0, 3, 3))
|
||||
)
|
||||
self.assertClose(clipped_faces.mesh_to_face_first_idx, zero_t)
|
||||
self.assertClose(clipped_faces.num_faces_per_mesh, zero_t)
|
||||
self.assertClose(
|
||||
clipped_faces.faces_clipped_to_unclipped_idx,
|
||||
torch.empty(device=device, dtype=torch.int64, size=(0,)),
|
||||
)
|
||||
self.assertIsNone(clipped_faces.faces_clipped_to_conversion_idx)
|
||||
self.assertIsNone(clipped_faces.clipped_faces_neighbor_idx)
|
||||
self.assertIsNone(clipped_faces.barycentric_conversion)
|
||||
|
||||
def test_case_3(self):
|
||||
"""
|
||||
Case 3 triangles have exactly two vertices behind the clipping plane (z=0) so are
|
||||
clipped into a smaller triangle.
|
||||
|
||||
Test with a single triangle parallel to the z axis which intersects with
|
||||
the image plane.
|
||||
"""
|
||||
|
||||
device = "cuda:0"
|
||||
verts = torch.tensor(
|
||||
[[-1.0, 0.0, -1.0], [0.0, 0.0, 1.0], [1.0, 0.0, -1.0]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
faces = torch.tensor(
|
||||
[
|
||||
[0, 1, 2],
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
meshes = Meshes(verts=[verts], faces=[faces])
|
||||
clipped_faces = self.clip_faces(meshes)
|
||||
|
||||
zero_t = torch.zeros(size=(1,), dtype=torch.int64, device=device)
|
||||
clipped_face_verts = torch.tensor(
|
||||
[
|
||||
[
|
||||
[0.4950, 0.0000, 0.0100],
|
||||
[-0.4950, 0.0000, 0.0100],
|
||||
[0.0000, 0.0000, 1.0000],
|
||||
]
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
# barycentric_conversion[i, :, k] stores the barycentric weights
|
||||
# in terms of the world coordinates of the original
|
||||
# (big) triangle for the kth vertex in the clipped (small) triangle.
|
||||
barycentric_conversion = torch.tensor(
|
||||
[
|
||||
[
|
||||
[0.0000, 0.4950, 0.0000],
|
||||
[0.5050, 0.5050, 1.0000],
|
||||
[0.4950, 0.0000, 0.0000],
|
||||
]
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
self.assertClose(clipped_faces.face_verts, clipped_face_verts)
|
||||
self.assertEqual(clipped_faces.mesh_to_face_first_idx.item(), 0)
|
||||
self.assertEqual(clipped_faces.num_faces_per_mesh.item(), 1)
|
||||
self.assertClose(clipped_faces.faces_clipped_to_unclipped_idx, zero_t)
|
||||
self.assertClose(clipped_faces.faces_clipped_to_conversion_idx, zero_t)
|
||||
self.assertClose(
|
||||
clipped_faces.clipped_faces_neighbor_idx,
|
||||
zero_t - 1, # default is -1
|
||||
)
|
||||
self.assertClose(clipped_faces.barycentric_conversion, barycentric_conversion)
|
||||
|
||||
def test_case_4(self):
|
||||
"""
|
||||
Case 4 triangles have exactly 1 vertex behind the clipping plane (z=0) so
|
||||
are clipped into a smaller quadrilateral and then divided into two triangles.
|
||||
|
||||
Test with a single triangle parallel to the z axis which intersects with
|
||||
the image plane.
|
||||
"""
|
||||
|
||||
device = "cuda:0"
|
||||
verts = torch.tensor(
|
||||
[[0.0, 0.0, -1.0], [-1.0, 0.0, 1.0], [1.0, 0.0, 1.0]],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
)
|
||||
faces = torch.tensor(
|
||||
[
|
||||
[0, 1, 2],
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
meshes = Meshes(verts=[verts], faces=[faces])
|
||||
clipped_faces = self.clip_faces(meshes)
|
||||
|
||||
clipped_face_verts = torch.tensor(
|
||||
[
|
||||
# t1
|
||||
[
|
||||
[-0.5050, 0.0000, 0.0100],
|
||||
[-1.0000, 0.0000, 1.0000],
|
||||
[0.5050, 0.0000, 0.0100],
|
||||
],
|
||||
# t2
|
||||
[
|
||||
[0.5050, 0.0000, 0.0100],
|
||||
[-1.0000, 0.0000, 1.0000],
|
||||
[1.0000, 0.0000, 1.0000],
|
||||
],
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
barycentric_conversion = torch.tensor(
|
||||
[
|
||||
[
|
||||
[0.4950, 0.0000, 0.4950],
|
||||
[0.5050, 1.0000, 0.0000],
|
||||
[0.0000, 0.0000, 0.5050],
|
||||
],
|
||||
[
|
||||
[0.4950, 0.0000, 0.0000],
|
||||
[0.0000, 1.0000, 0.0000],
|
||||
[0.5050, 0.0000, 1.0000],
|
||||
],
|
||||
],
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
self.assertClose(clipped_faces.face_verts, clipped_face_verts)
|
||||
self.assertEqual(clipped_faces.mesh_to_face_first_idx.item(), 0)
|
||||
self.assertEqual(
|
||||
clipped_faces.num_faces_per_mesh.item(), 2
|
||||
) # now two faces instead of 1
|
||||
self.assertClose(
|
||||
clipped_faces.faces_clipped_to_unclipped_idx,
|
||||
torch.tensor([0, 0], device=device, dtype=torch.int64),
|
||||
)
|
||||
# Neighboring face for each of the sub triangles e.g. for t1, neighbor is t2,
|
||||
# and for t2, neighbor is t1
|
||||
self.assertClose(
|
||||
clipped_faces.clipped_faces_neighbor_idx,
|
||||
torch.tensor([1, 0], device=device, dtype=torch.int64),
|
||||
)
|
||||
# barycentric_conversion is of shape (F_clipped)
|
||||
self.assertEqual(clipped_faces.barycentric_conversion.shape[0], 2)
|
||||
self.assertClose(clipped_faces.barycentric_conversion, barycentric_conversion)
|
||||
# Index into barycentric_conversion for each clipped face.
|
||||
self.assertClose(
|
||||
clipped_faces.faces_clipped_to_conversion_idx,
|
||||
torch.tensor([0, 1], device=device, dtype=torch.int64),
|
||||
)
|
||||
|
||||
def test_mixture_of_cases(self):
|
||||
"""
|
||||
Test with two meshes composed of different cases to check all the
|
||||
indexing is correct.
|
||||
Case 4 faces are subdivided into two faces which are referred
|
||||
to as t1 and t2.
|
||||
"""
|
||||
device = "cuda:0"
|
||||
# fmt: off
|
||||
verts = [
|
||||
torch.tensor(
|
||||
[
|
||||
[-1.0, 0.0, -1.0], # noqa: E241, E201
|
||||
[ 0.0, 1.0, -1.0], # noqa: E241, E201
|
||||
[ 1.0, 0.0, -1.0], # noqa: E241, E201
|
||||
[ 0.0, -1.0, -1.0], # noqa: E241, E201
|
||||
[-1.0, 0.5, 0.5], # noqa: E241, E201
|
||||
[ 1.0, 1.0, 1.0], # noqa: E241, E201
|
||||
[ 0.0, -1.0, 1.0], # noqa: E241, E201
|
||||
[-1.0, 0.5, -0.5], # noqa: E241, E201
|
||||
[ 1.0, 1.0, -1.0], # noqa: E241, E201
|
||||
[-1.0, 0.0, 1.0], # noqa: E241, E201
|
||||
[ 0.0, 1.0, 1.0], # noqa: E241, E201
|
||||
[ 1.0, 0.0, 1.0], # noqa: E241, E201
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
),
|
||||
torch.tensor(
|
||||
[
|
||||
[ 0.0, -1.0, -1.0], # noqa: E241, E201
|
||||
[-1.0, 0.5, 0.5], # noqa: E241, E201
|
||||
[ 1.0, 1.0, 1.0], # noqa: E241, E201
|
||||
],
|
||||
dtype=torch.float32,
|
||||
device=device
|
||||
)
|
||||
]
|
||||
faces = [
|
||||
torch.tensor(
|
||||
[
|
||||
[0, 1, 2], # noqa: E241, E201 Case 2 fully clipped
|
||||
[3, 4, 5], # noqa: E241, E201 Case 4 clipped and subdivided
|
||||
[5, 4, 3], # noqa: E241, E201 Repeat of Case 4
|
||||
[6, 7, 8], # noqa: E241, E201 Case 3 clipped
|
||||
[9, 10, 11], # noqa: E241, E201 Case 1 untouched
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
),
|
||||
torch.tensor(
|
||||
[
|
||||
[0, 1, 2], # noqa: E241, E201 Case 4
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
),
|
||||
]
|
||||
# fmt: on
|
||||
meshes = Meshes(verts=verts, faces=faces)
|
||||
|
||||
# Clip meshes
|
||||
clipped_faces = self.clip_faces(meshes)
|
||||
|
||||
# mesh 1: 4x faces (from Case 4) + 1 (from Case 3) + 1 (from Case 1)
|
||||
# mesh 2: 2x faces (from Case 4)
|
||||
self.assertEqual(clipped_faces.face_verts.shape[0], 6 + 2)
|
||||
|
||||
# dummy idx type tensor to avoid having to initialize the dype/device each time
|
||||
idx = torch.empty(size=(1,), dtype=torch.int64, device=device)
|
||||
unclipped_idx = idx.new_tensor([1, 1, 2, 2, 3, 4, 5, 5])
|
||||
neighbors = idx.new_tensor([1, 0, 3, 2, -1, -1, 7, 6])
|
||||
first_idx = idx.new_tensor([0, 6])
|
||||
num_faces = idx.new_tensor([6, 2])
|
||||
|
||||
self.assertClose(clipped_faces.clipped_faces_neighbor_idx, neighbors)
|
||||
self.assertClose(clipped_faces.faces_clipped_to_unclipped_idx, unclipped_idx)
|
||||
self.assertClose(clipped_faces.mesh_to_face_first_idx, first_idx)
|
||||
self.assertClose(clipped_faces.num_faces_per_mesh, num_faces)
|
||||
|
||||
# faces_clipped_to_conversion_idx maps each output face to the
|
||||
# corresponding row of the barycentric_conversion matrix.
|
||||
# The barycentric_conversion matrix is composed by
|
||||
# finding the barycentric conversion weights for case 3 faces
|
||||
# case 4 (t1) faces and case 4 (t2) faces. These are then
|
||||
# concatenated. Therefore case 3 faces will be the first rows of
|
||||
# the barycentric_conversion matrix followed by t1 and then t2.
|
||||
# Case type of all faces: [4 (t1), 4 (t2), 4 (t1), 4 (t2), 3, 1, 4 (t1), 4 (t2)]
|
||||
# Based on this information we can calculate the indices into the
|
||||
# barycentric conversion matrix.
|
||||
bary_idx = idx.new_tensor([1, 4, 2, 5, 0, -1, 3, 6])
|
||||
self.assertClose(clipped_faces.faces_clipped_to_conversion_idx, bary_idx)
|
Loading…
x
Reference in New Issue
Block a user