back face culling in rasterization

Summary:
Added backface culling as an option to the `raster_settings`. This is needed for the full forward rendering of shapenet meshes with texture (some meshes contain
multiple overlapping segments which have different textures).

For a triangle (v0, v1, v2) define the vectors A = (v1 - v0) and B = (v2 − v0) and use this to calculate the area of the triangle as:
```
area = 0.5 * A  x B
area = 0.5 * ((x1 − x0)(y2 − y0) − (x2 − x0)(y1 − y0))
```
The area will be positive if (v0, v1, v2) are oriented counterclockwise (a front face), and negative if (v0, v1, v2) are oriented clockwise (a back face).

We can reuse the `edge_function` as it already calculates the triangle area.

Reviewed By: jcjohnson

Differential Revision: D20960115

fbshipit-source-id: 2d8a4b9ccfb653df18e79aed8d05c7ec0f057ab1
This commit is contained in:
Nikhila Ravi 2020-04-22 08:20:16 -07:00 committed by Facebook GitHub Bot
parent 3c6f9220fc
commit 4bf30593ff
7 changed files with 187 additions and 30 deletions

View File

@ -111,7 +111,8 @@ __device__ void CheckPixelInsideFace(
const float blur_radius,
const float2 pxy, // Coordinates of the pixel
const int K,
const bool perspective_correct) {
const bool perspective_correct,
const bool cull_backfaces) {
const auto v012 = GetSingleFaceVerts(face_verts, face_idx);
const float3 v0 = thrust::get<0>(v012);
const float3 v1 = thrust::get<1>(v012);
@ -124,16 +125,20 @@ __device__ void CheckPixelInsideFace(
// Perform checks and skip if:
// 1. the face is behind the camera
// 2. the face has very small face area
// 3. the pixel is outside the face bbox
// 2. the face is facing away from the camera
// 3. the face has very small face area
// 4. the pixel is outside the face bbox
const float zmax = FloatMax3(v0.z, v1.z, v2.z);
const bool outside_bbox = CheckPointOutsideBoundingBox(
v0, v1, v2, sqrt(blur_radius), pxy); // use sqrt of blur for bbox
const float face_area = EdgeFunctionForward(v0xy, v1xy, v2xy);
// Check if the face is visible to the camera.
const bool back_face = face_area < 0.0;
const bool zero_face_area =
(face_area <= kEpsilon && face_area >= -1.0f * kEpsilon);
if (zmax < 0 || outside_bbox || zero_face_area) {
if (zmax < 0 || cull_backfaces && back_face || outside_bbox ||
zero_face_area) {
return;
}
@ -191,6 +196,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
const int64_t* num_faces_per_mesh,
const float blur_radius,
const bool perspective_correct,
const bool cull_backfaces,
const int N,
const int H,
const int W,
@ -251,7 +257,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
blur_radius,
pxy,
K,
perspective_correct);
perspective_correct,
cull_backfaces);
}
// TODO: make sorting an option as only top k is needed, not sorted values.
@ -276,7 +283,8 @@ RasterizeMeshesNaiveCuda(
const int image_size,
const float blur_radius,
const int num_closest,
const bool perspective_correct) {
const bool perspective_correct,
const bool cull_backfaces) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
@ -314,6 +322,7 @@ RasterizeMeshesNaiveCuda(
num_faces_per_mesh.contiguous().data_ptr<int64_t>(),
blur_radius,
perspective_correct,
cull_backfaces,
N,
H,
W,
@ -667,6 +676,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
const float blur_radius,
const int bin_size,
const bool perspective_correct,
const bool cull_backfaces,
const int N,
const int B,
const int M,
@ -730,7 +740,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
blur_radius,
pxy,
K,
perspective_correct);
perspective_correct,
cull_backfaces);
}
// Now we've looked at all the faces for this bin, so we can write
@ -762,7 +773,8 @@ RasterizeMeshesFineCuda(
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
const bool perspective_correct) {
const bool perspective_correct,
const bool cull_backfaces) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
@ -797,6 +809,7 @@ RasterizeMeshesFineCuda(
blur_radius,
bin_size,
perspective_correct,
cull_backfaces,
N,
B,
M,

View File

@ -17,7 +17,8 @@ RasterizeMeshesNaiveCpu(
const int image_size,
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct);
const bool perspective_correct,
const bool cull_backfaces);
#ifdef WITH_CUDA
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@ -28,7 +29,8 @@ RasterizeMeshesNaiveCuda(
const int image_size,
const float blur_radius,
const int num_closest,
const bool perspective_correct);
const bool perspective_correct,
const bool cull_backfaces);
#endif
// Forward pass for rasterizing a batch of meshes.
//
@ -55,6 +57,14 @@ RasterizeMeshesNaiveCuda(
// coordinates for each pixel; if this is False then
// this function instead returns screen-space
// barycentric coordinates for each pixel.
// cull_backfaces: Bool, Whether to only rasterize mesh faces which are
// visible to the camera. This assumes that vertices of
// front-facing triangles are ordered in an anti-clockwise
// fashion, and triangles that face away from the camera are
// in a clockwise order relative to the current view
// direction. NOTE: This will only work if the mesh faces are
// consistently defined with counter-clockwise ordering when
// viewed from the outside.
//
// Returns:
// A 4 element tuple of:
@ -80,7 +90,8 @@ RasterizeMeshesNaive(
const int image_size,
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct) {
const bool perspective_correct,
const bool cull_backfaces) {
// TODO: Better type checking.
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
@ -91,7 +102,8 @@ RasterizeMeshesNaive(
image_size,
blur_radius,
faces_per_pixel,
perspective_correct);
perspective_correct,
cull_backfaces);
#else
AT_ERROR("Not compiled with GPU support");
#endif
@ -103,7 +115,8 @@ RasterizeMeshesNaive(
image_size,
blur_radius,
faces_per_pixel,
perspective_correct);
perspective_correct,
cull_backfaces);
}
}
@ -274,7 +287,8 @@ RasterizeMeshesFineCuda(
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
const bool perspective_correct);
const bool perspective_correct,
const bool cull_backfaces);
#endif
// Args:
// face_verts: Tensor of shape (F, 3, 3) giving (packed) vertex positions for
@ -296,6 +310,14 @@ RasterizeMeshesFineCuda(
// coordinates for each pixel; if this is False then
// this function instead returns screen-space
// barycentric coordinates for each pixel.
// cull_backfaces: Bool, Whether to only rasterize mesh faces which are
// visible to the camera. This assumes that vertices of
// front-facing triangles are ordered in an anti-clockwise
// fashion, and triangles that face away from the camera are
// in a clockwise order relative to the current view
// direction. NOTE: This will only work if the mesh faces are
// consistently defined with counter-clockwise ordering when
// viewed from the outside.
//
// Returns (same as rasterize_meshes):
// A 4 element tuple of:
@ -321,7 +343,8 @@ RasterizeMeshesFine(
const float blur_radius,
const int bin_size,
const int faces_per_pixel,
const bool perspective_correct) {
const bool perspective_correct,
const bool cull_backfaces) {
if (face_verts.is_cuda()) {
#ifdef WITH_CUDA
return RasterizeMeshesFineCuda(
@ -331,7 +354,8 @@ RasterizeMeshesFine(
blur_radius,
bin_size,
faces_per_pixel,
perspective_correct);
perspective_correct,
cull_backfaces);
#else
AT_ERROR("Not compiled with GPU support");
#endif
@ -372,7 +396,14 @@ RasterizeMeshesFine(
// coordinates for each pixel; if this is False then
// this function instead returns screen-space
// barycentric coordinates for each pixel.
//
// cull_backfaces: Bool, Whether to only rasterize mesh faces which are
// visible to the camera. This assumes that vertices of
// front-facing triangles are ordered in an anti-clockwise
// fashion, and triangles that face away from the camera are
// in a clockwise order relative to the current view
// direction. NOTE: This will only work if the mesh faces are
// consistently defined with counter-clockwise ordering when
// viewed from the outside.
//
// Returns:
// A 4 element tuple of:
@ -400,7 +431,8 @@ RasterizeMeshes(
const int faces_per_pixel,
const int bin_size,
const int max_faces_per_bin,
const bool perspective_correct) {
const bool perspective_correct,
const bool cull_backfaces) {
if (bin_size > 0 && max_faces_per_bin > 0) {
// Use coarse-to-fine rasterization
auto bin_faces = RasterizeMeshesCoarse(
@ -418,7 +450,8 @@ RasterizeMeshes(
blur_radius,
bin_size,
faces_per_pixel,
perspective_correct);
perspective_correct,
cull_backfaces);
} else {
// Use the naive per-pixel implementation
return RasterizeMeshesNaive(
@ -428,6 +461,7 @@ RasterizeMeshes(
image_size,
blur_radius,
faces_per_pixel,
perspective_correct);
perspective_correct,
cull_backfaces);
}
}

View File

@ -107,7 +107,8 @@ RasterizeMeshesNaiveCpu(
int image_size,
const float blur_radius,
const int faces_per_pixel,
const bool perspective_correct) {
const bool perspective_correct,
const bool cull_backfaces) {
if (face_verts.ndimension() != 3 || face_verts.size(1) != 3 ||
face_verts.size(2) != 3) {
AT_ERROR("face_verts must have dimensions (num_faces, 3, 3)");
@ -184,8 +185,13 @@ RasterizeMeshesNaiveCpu(
const vec2<float> v1(x1, y1);
const vec2<float> v2(x2, y2);
// Skip faces with zero area.
const float face_area = face_areas_a[f];
const bool back_face = face_area < 0.0;
// Check if the face is visible to the camera.
if (cull_backfaces && back_face) {
continue;
}
// Skip faces with zero area.
if (face_area <= kEpsilon && face_area >= -1.0f * kEpsilon) {
continue;
}

View File

@ -140,16 +140,16 @@ def load_obj(f_obj, load_textures=True):
If there are faces with more than 3 vertices
they are subdivided into triangles. Polygonal faces are assummed to have
vertices ordered counter-clockwise so the (right-handed) normal points
into the screen e.g. a proper rectangular face would be specified like this:
out of the screen e.g. a proper rectangular face would be specified like this:
::
0_________1
| |
| |
3 ________2
The face would be split into two triangles: (0, 1, 2) and (0, 2, 3),
both of which are also oriented clockwise and have normals
pointing into the screen.
The face would be split into two triangles: (0, 2, 1) and (0, 3, 2),
both of which are also oriented counter-clockwise and have normals
pointing out of the screen.
Args:
f: A file-like object (with methods read, readline, tell, and seek),

View File

@ -20,6 +20,7 @@ def rasterize_meshes(
bin_size: Optional[int] = None,
max_faces_per_bin: Optional[int] = None,
perspective_correct: bool = False,
cull_backfaces: bool = False,
):
"""
Rasterize a batch of meshes given the shape of the desired output image.
@ -45,8 +46,16 @@ def rasterize_meshes(
bin. If more than this many faces actually fall into a bin, an error
will be raised. This should not affect the output values, but can affect
the memory usage in the forward pass.
perspective_correct: Whether to apply perspective correction when computing
perspective_correct: Bool, Whether to apply perspective correction when computing
barycentric coordinates for pixels.
cull_backfaces: Bool, Whether to only rasterize mesh faces which are
visible to the camera. This assumes that vertices of
front-facing triangles are ordered in an anti-clockwise
fashion, and triangles that face away from the camera are
in a clockwise order relative to the current view
direction. NOTE: This will only work if the mesh faces are
consistently defined with counter-clockwise ordering when
viewed from the outside.
Returns:
4-element tuple containing
@ -118,6 +127,7 @@ def rasterize_meshes(
bin_size,
max_faces_per_bin,
perspective_correct,
cull_backfaces,
)
@ -139,6 +149,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
for each mesh in the batch.
image_size, blur_radius, faces_per_pixel: same as rasterize_meshes.
perspective_correct: same as rasterize_meshes.
cull_backfaces: same as rasterize_meshes.
Returns:
same as rasterize_meshes function.
@ -156,6 +167,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
bin_size: int = 0,
max_faces_per_bin: int = 0,
perspective_correct: bool = False,
cull_backfaces: bool = False,
):
pix_to_face, zbuf, barycentric_coords, dists = _C.rasterize_meshes(
face_verts,
@ -167,6 +179,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
bin_size,
max_faces_per_bin,
perspective_correct,
cull_backfaces,
)
ctx.save_for_backward(face_verts, pix_to_face)
ctx.perspective_correct = perspective_correct
@ -183,6 +196,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
grad_bin_size = None
grad_max_faces_per_bin = None
grad_perspective_correct = None
grad_cull_backfaces = None
face_verts, pix_to_face = ctx.saved_tensors
grad_face_verts = _C.rasterize_meshes_backward(
face_verts,
@ -202,6 +216,7 @@ class _RasterizeFaceVerts(torch.autograd.Function):
grad_bin_size,
grad_max_faces_per_bin,
grad_perspective_correct,
grad_cull_backfaces,
)
return grads
@ -217,6 +232,7 @@ def rasterize_meshes_python(
blur_radius: float = 0.0,
faces_per_pixel: int = 8,
perspective_correct: bool = False,
cull_backfaces: bool = False,
):
"""
Naive PyTorch implementation of mesh rasterization with the same inputs and
@ -287,7 +303,12 @@ def rasterize_meshes_python(
face = faces_verts[f].squeeze()
v0, v1, v2 = face.unbind(0)
face_area = edge_function(v2, v0, v1)
face_area = edge_function(v0, v1, v2)
# Ignore triangles facing away from the camera.
back_face = face_area < 0
if cull_backfaces and back_face:
continue
# Ignore faces which have zero area.
if face_area == 0.0:
@ -365,8 +386,8 @@ def edge_function(p, v0, v1):
.. code-block:: python
A = p - v0
B = v1 - v0
B = p - v0
A = v1 - v0
v1 ________
/\ /

View File

@ -26,6 +26,7 @@ class RasterizationSettings:
"bin_size",
"max_faces_per_bin",
"perspective_correct",
"cull_backfaces",
]
def __init__(
@ -36,6 +37,7 @@ class RasterizationSettings:
bin_size: Optional[int] = None,
max_faces_per_bin: Optional[int] = None,
perspective_correct: bool = False,
cull_backfaces: bool = False,
):
self.image_size = image_size
self.blur_radius = blur_radius
@ -43,6 +45,7 @@ class RasterizationSettings:
self.bin_size = bin_size
self.max_faces_per_bin = max_faces_per_bin
self.perspective_correct = perspective_correct
self.cull_backfaces = cull_backfaces
class MeshRasterizer(nn.Module):
@ -122,6 +125,7 @@ class MeshRasterizer(nn.Module):
bin_size=raster_settings.bin_size,
max_faces_per_bin=raster_settings.max_faces_per_bin,
perspective_correct=raster_settings.perspective_correct,
cull_backfaces=raster_settings.cull_backfaces,
)
return Fragments(
pix_to_face=pix_to_face, zbuf=zbuf, bary_coords=bary_coords, dists=dists

View File

@ -21,6 +21,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self._simple_blurry_raster(rasterize_meshes_python, device, bin_size=-1)
self._test_behind_camera(rasterize_meshes_python, device, bin_size=-1)
self._test_perspective_correct(rasterize_meshes_python, device, bin_size=-1)
self._test_back_face_culling(rasterize_meshes_python, device, bin_size=-1)
def test_simple_cpu_naive(self):
device = torch.device("cpu")
@ -28,6 +29,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
self._test_behind_camera(rasterize_meshes, device, bin_size=0)
self._test_perspective_correct(rasterize_meshes, device, bin_size=0)
self._test_back_face_culling(rasterize_meshes, device, bin_size=0)
def test_simple_cuda_naive(self):
device = torch.device("cuda:0")
@ -35,6 +37,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self._simple_blurry_raster(rasterize_meshes, device, bin_size=0)
self._test_behind_camera(rasterize_meshes, device, bin_size=0)
self._test_perspective_correct(rasterize_meshes, device, bin_size=0)
self._test_back_face_culling(rasterize_meshes, device, bin_size=0)
def test_simple_cuda_binned(self):
device = torch.device("cuda:0")
@ -42,6 +45,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
self._simple_blurry_raster(rasterize_meshes, device, bin_size=5)
self._test_behind_camera(rasterize_meshes, device, bin_size=5)
self._test_perspective_correct(rasterize_meshes, device, bin_size=5)
self._test_back_face_culling(rasterize_meshes, device, bin_size=5)
def test_python_vs_cpu_vs_cuda(self):
torch.manual_seed(231)
@ -377,6 +381,81 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
args = ()
self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)
def _test_back_face_culling(self, rasterize_meshes_fn, device, bin_size):
# Square based pyramid mesh.
# fmt: off
verts = torch.tensor([
[-0.5, 0.0, 0.5], # noqa: E241 E201 Front right
[ 0.5, 0.0, 0.5], # noqa: E241 E201 Front left
[ 0.5, 0.0, 1.5], # noqa: E241 E201 Back left
[-0.5, 0.0, 1.5], # noqa: E241 E201 Back right
[ 0.0, 1.0, 1.0] # noqa: E241 E201 Top point of pyramid
], dtype=torch.float32, device=device)
faces = torch.tensor([
[2, 1, 0], # noqa: E241 E201 Square base
[3, 2, 0], # noqa: E241 E201 Square base
[1, 0, 4], # noqa: E241 E201 Triangle on front
[2, 4, 3], # noqa: E241 E201 Triangle on back
[3, 4, 0], # noqa: E241 E201 Triangle on left side
[1, 4, 2] # noqa: E241 E201 Triangle on right side
], dtype=torch.int64, device=device)
# fmt: on
mesh = Meshes(verts=[verts], faces=[faces])
kwargs = {
"meshes": mesh,
"image_size": 10,
"faces_per_pixel": 2,
"blur_radius": 0.0,
"perspective_correct": False,
"cull_backfaces": False,
}
if bin_size != -1:
kwargs["bin_size"] = bin_size
# fmt: off
pix_to_face_frontface = torch.tensor([
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, 2, 2, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, 2, 2, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, 2, 2, 2, 2, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, 2, 2, 2, 2, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241 E201
], dtype=torch.int64, device=device)
pix_to_face_backface = torch.tensor([
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, 3, 3, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, 3, 3, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, 3, 3, 3, 3, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, 3, 3, 3, 3, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1], # noqa: E241 E201
[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1] # noqa: E241 E201
], dtype=torch.int64, device=device)
# fmt: on
pix_to_face_padded = -torch.ones_like(pix_to_face_frontface)
# Run with and without culling
# Without culling, for k=0, the front face (i.e. face 2) is
# rasterized and for k=1, the back face (i.e. face 3) is
# rasterized.
idx_f, zbuf_f, bary_f, dists_f = rasterize_meshes_fn(**kwargs)
self.assertTrue(torch.all(idx_f[..., 0].squeeze() == pix_to_face_frontface))
self.assertTrue(torch.all(idx_f[..., 1].squeeze() == pix_to_face_backface))
# With culling, for k=0, the front face (i.e. face 2) is
# rasterized and for k=1, there are no faces rasterized
kwargs["cull_backfaces"] = True
idx_t, zbuf_t, bary_t, dists_t = rasterize_meshes_fn(**kwargs)
self.assertTrue(torch.all(idx_t[..., 0].squeeze() == pix_to_face_frontface))
self.assertTrue(torch.all(idx_t[..., 1].squeeze() == pix_to_face_padded))
def _compare_impls(
self,
fn1,