coarse rasterization bug fix

Summary:
Fix a bug which resulted in rendering artifacts if the image size was not a multiple of 16.
Fix: Revert coarse rasterization to the original implementation and only update fine rasterization to reverse the ordering of the Y and X axes. This is much simpler than the previous approach!

Additional changes:
- updated mesh rendering end-to-end tests to check outputs from both naive and coarse-to-fine rasterization.
- added pointcloud rendering end-to-end tests
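
The gist of the fix, as a minimal Python sketch (hypothetical helper names; a square S x S image with K faces per pixel is assumed): the axis flip moves out of the NDC coordinate lookup and into the output write, so coarse binning can stay in plain ascending pixel order.

```python
def pix_to_ndc(i: int, s: int) -> float:
    # Center of pixel i in NDC, in [-1, 1]; assumed to match the kernels'
    # PixToNdc helper.
    return -1.0 + (2.0 * i + 1.0) / s

def fine_rasterize_pixel(yi: int, xi: int, S: int, K: int, n: int = 0):
    # Before the fix, the lookup itself was flipped (yidx = S - 1 - yi),
    # which had to agree exactly with the flipped binning in the coarse
    # kernel. Now the NDC sample uses the unflipped pixel indices ...
    xf = pix_to_ndc(xi, S)
    yf = pix_to_ndc(yi, S)
    # ... and the flip is applied only when computing the output offset,
    # so that +Y is up and +X is left in the final image.
    yidx = S - 1 - yi
    xidx = S - 1 - xi
    pix_idx = n * S * S * K + yidx * S * K + xidx * K
    return (xf, yf), pix_idx
```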

Reviewed By: gkioxari

Differential Revision: D21102725

fbshipit-source-id: 2e7e1b013dd6dd12b3a00b79eb8167deddb2e89a
Author: Nikhila Ravi, 2020-04-20 14:51:19 -07:00 (committed by Facebook GitHub Bot)
parent 1e4749602d
commit 9ef1ee8455
15 changed files with 381 additions and 173 deletions

(Binary reference image updated: 62 KiB before, 64 KiB after.)

File diff suppressed because one or more lines are too long


@ -556,18 +556,16 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
// PixToNdc gives the location of the center of each pixel, so we
// need to add/subtract a half pixel to get the true extent of the bin.
// Reverse ordering of Y axis so that +Y is upwards in the image.
const int yidx = num_bins - by;
const float bin_y_max = PixToNdc(yidx * bin_size - 1, H) + half_pix;
const float bin_y_min = PixToNdc((yidx - 1) * bin_size, H) - half_pix;
const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix;
const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix;
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
for (int bx = 0; bx < num_bins; ++bx) {
// X coordinate of the left and right of the bin.
// Reverse ordering of x axis so that +X is left.
const int xidx = num_bins - bx;
const float bin_x_max = PixToNdc(xidx * bin_size - 1, W) + half_pix;
const float bin_x_min = PixToNdc((xidx - 1) * bin_size, W) - half_pix;
const float bin_x_max =
PixToNdc((bx + 1) * bin_size - 1, W) + half_pix;
const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix;
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
if (y_overlap && x_overlap) {
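
For intuition, here is the reverted bin-extent computation transcribed into Python (a sketch: pix_to_ndc is assumed to return pixel centers as -1 + (2*i + 1)/s, which is what the half-pixel padding above implies):

```python
def bin_ndc_extent(b: int, bin_size: int, s: int):
    # NDC extent of bin b along one axis of an s-pixel image. pix_to_ndc
    # gives the center of each pixel, so add/subtract half a pixel to get
    # the true extent of the bin.
    pix_to_ndc = lambda i: -1.0 + (2.0 * i + 1.0) / s
    half_pix = 1.0 / s
    lo = pix_to_ndc(b * bin_size) - half_pix
    hi = pix_to_ndc((b + 1) * bin_size - 1) + half_pix
    return lo, hi

def y_overlaps(ymin: float, ymax: float, by: int, bin_size: int, s: int):
    # A face whose NDC y-range is [ymin, ymax] lands in bin `by` iff the
    # two intervals intersect, mirroring the kernel's y_overlap test.
    lo, hi = bin_ndc_extent(by, bin_size, s)
    return ymin <= hi and lo < ymax
```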
@ -629,6 +627,7 @@ torch::Tensor RasterizeMeshesCoarseCuda(
const int N = num_faces_per_mesh.size(0);
const int num_bins = 1 + (image_size - 1) / bin_size; // Divide round up.
const int M = max_faces_per_bin;
if (num_bins >= 22) {
std::stringstream ss;
ss << "Got " << num_bins << "; that's too many!";
@ -702,13 +701,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
if (yi >= H || xi >= W)
continue;
// Reverse ordering of the X and Y axis so that
// in the image +Y is pointing up and +X is pointing left.
const int yidx = H - 1 - yi;
const int xidx = W - 1 - xi;
const float xf = PixToNdc(xidx, W);
const float yf = PixToNdc(yidx, H);
const float xf = PixToNdc(xi, W);
const float yf = PixToNdc(yi, H);
const float2 pxy = make_float2(xf, yf);
// This part looks like the naive rasterization kernel, except we use
@ -743,7 +737,12 @@ __global__ void RasterizeMeshesFineCudaKernel(
// output for the current pixel.
// TODO: make sorting an option as only top k is needed, not sorted values.
BubbleSort(q, q_size);
const int pix_idx = n * H * W * K + yi * H * K + xi * K;
// Reverse ordering of the X and Y axis so that
// in the image +Y is pointing up and +X is pointing left.
const int yidx = H - 1 - yi;
const int xidx = W - 1 - xi;
const int pix_idx = n * H * W * K + yidx * H * K + xidx * K;
for (int k = 0; k < q_size; k++) {
face_idxs[pix_idx + k] = q[k].idx;
zbuf[pix_idx + k] = q[k].z;
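
The TODO above notes that only the top K faces per pixel are needed, not a fully sorted queue. As a rough Python illustration of that idea (not the actual CUDA data structure):

```python
import heapq

def k_nearest(candidates, k):
    # candidates: iterable of (z, face_idx) pairs collected for one pixel.
    # heapq.nsmallest returns the k smallest depths already in ascending
    # order, so no full sort of the per-pixel queue is required.
    return heapq.nsmallest(k, candidates, key=lambda c: c[0])
```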


@ -430,13 +430,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
const int face_stop_idx =
(face_start_idx + num_faces_per_mesh[n].item().to<int32_t>());
float bin_y_max = 1.0f;
float bin_y_min = bin_y_max - bin_width;
float bin_y_min = -1.0f;
float bin_y_max = bin_y_min + bin_width;
// Iterate through the horizontal bins from top to bottom.
for (int by = 0; by < BH; ++by) {
float bin_x_max = 1.0f;
float bin_x_min = bin_x_max - bin_width;
float bin_x_min = -1.0f;
float bin_x_max = bin_x_min + bin_width;
// Iterate through bins on this horizontal line, left to right.
for (int bx = 0; bx < BW; ++bx) {
@ -473,13 +473,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
}
}
// Shift the bin to the left for the next loop iteration.
bin_x_max = bin_x_min;
bin_x_min = bin_x_min - bin_width;
// Shift the bin to the right for the next loop iteration
bin_x_min = bin_x_max;
bin_x_max = bin_x_min + bin_width;
}
// Shift the bin down for the next loop iteration.
bin_y_max = bin_y_min;
bin_y_min = bin_y_min - bin_width;
// Shift the bin down for the next loop iteration
bin_y_min = bin_y_max;
bin_y_max = bin_y_min + bin_width;
}
}
return bin_faces;
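
A condensed Python sketch of the fixed loop structure (assuming a square bin grid; bin_width is the NDC width of one bin): the window now slides from -1 toward +1 along both axes instead of from +1 downward.

```python
def iter_bins(num_bins: int, bin_width: float):
    # Yield each bin with its NDC extents, walking both axes in ascending
    # order starting at -1, as in the reverted CPU loops.
    bin_y_min, bin_y_max = -1.0, -1.0 + bin_width
    for by in range(num_bins):
        bin_x_min, bin_x_max = -1.0, -1.0 + bin_width
        for bx in range(num_bins):
            yield by, bx, (bin_x_min, bin_x_max), (bin_y_min, bin_y_max)
            # Shift the bin to the right for the next iteration.
            bin_x_min, bin_x_max = bin_x_max, bin_x_max + bin_width
        # Advance to the next row of bins.
        bin_y_min, bin_y_max = bin_y_max, bin_y_max + bin_width
```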


@ -95,7 +95,8 @@ __global__ void RasterizePointsNaiveCudaKernel(
const int n = i / (S * S); // Batch index
const int pix_idx = i % (S * S);
// Reverse ordering of X and Y axes.
// Reverse ordering of the X and Y axis as the camera coordinates
// assume that +Y is pointing up and +X is pointing left.
const int yi = S - 1 - pix_idx / S;
const int xi = S - 1 - pix_idx % S;
@ -260,23 +261,20 @@ __global__ void RasterizePointsCoarseCudaKernel(
// Get y extent for the bin. PixToNdc gives us the location of
// the center of each pixel, so we need to add/subtract a half
// pixel to get the true extent of the bin.
// Reverse ordering of Y axis so that +Y is upwards in the image.
const int yidx = num_bins - by;
const float bin_y_max = PixToNdc(yidx * bin_size - 1, S) + half_pix;
const float bin_y_min = PixToNdc((yidx - 1) * bin_size, S) - half_pix;
const float by0 = PixToNdc(by * bin_size, S) - half_pix;
const float by1 = PixToNdc((by + 1) * bin_size - 1, S) + half_pix;
const bool y_overlap = (py0 <= by1) && (by0 <= py1);
const bool y_overlap = (py0 <= bin_y_max) && (bin_y_min <= py1);
if (!y_overlap) {
continue;
}
for (int bx = 0; bx < num_bins; ++bx) {
// Get x extent for the bin; again we need to adjust the
// output of PixToNdc by half a pixel.
// Reverse ordering of x axis so that +X is left.
const int xidx = num_bins - bx;
const float bin_x_max = PixToNdc(xidx * bin_size - 1, S) + half_pix;
const float bin_x_min = PixToNdc((xidx - 1) * bin_size, S) - half_pix;
const bool x_overlap = (px0 <= bin_x_max) && (bin_x_min <= px1);
const float bx0 = PixToNdc(bx * bin_size, S) - half_pix;
const float bx1 = PixToNdc((bx + 1) * bin_size - 1, S) + half_pix;
const bool x_overlap = (px0 <= bx1) && (bx0 <= px1);
if (x_overlap) {
binmask.set(by, bx, p);
}
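
The overlap test for points is the same interval intersection as for meshes; a small Python sketch (it assumes px0/px1 and py0/py1 are the point's NDC position padded by its radius, which is what the variable names suggest):

```python
def point_in_bin(px, py, radius, bx0, bx1, by0, by1):
    # (bx0, bx1) and (by0, by1) are the NDC extents of the bin. The point's
    # screen-space footprint is a square of half-side `radius`; it belongs
    # in the bin iff both 1-D intervals overlap.
    px0, px1 = px - radius, px + radius
    py0, py1 = py - radius, py + radius
    return (px0 <= bx1 and bx0 <= px1) and (py0 <= by1 and by0 <= py1)
```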
@ -330,6 +328,7 @@ torch::Tensor RasterizePointsCoarseCuda(
const int N = num_points_per_cloud.size(0);
const int num_bins = 1 + (image_size - 1) / bin_size; // divide round up
const int M = max_points_per_bin;
if (points.ndimension() != 2 || points.size(1) != 3) {
AT_ERROR("points must have dimensions (num_points, 3)");
}
@ -346,6 +345,7 @@ torch::Tensor RasterizePointsCoarseCuda(
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
const size_t blocks = 64;
const size_t threads = 512;
RasterizePointsCoarseCudaKernel<<<blocks, threads, shared_size>>>(
points.contiguous().data_ptr<float>(),
cloud_to_packed_first_idx.contiguous().data_ptr<int64_t>(),
@ -372,7 +372,7 @@ __global__ void RasterizePointsFineCudaKernel(
const float radius,
const int bin_size,
const int N,
const int B,
const int B, // num_bins
const int M,
const int S,
const int K,
@ -397,19 +397,15 @@ __global__ void RasterizePointsFineCudaKernel(
i %= B * bin_size * bin_size;
const int bx = i / (bin_size * bin_size);
i %= bin_size * bin_size;
const int yi = i / bin_size + by * bin_size;
const int xi = i % bin_size + bx * bin_size;
if (yi >= S || xi >= S)
continue;
// Reverse ordering of the X and Y axis so that
// in the image +Y is pointing up and +X is pointing left.
const int yidx = S - 1 - yi;
const int xidx = S - 1 - xi;
const float xf = PixToNdc(xidx, S);
const float yf = PixToNdc(yidx, S);
const float xf = PixToNdc(xi, S);
const float yf = PixToNdc(yi, S);
// This part looks like the naive rasterization kernel, except we use
// bin_points to only look at a subset of points already known to fall
@ -431,7 +427,13 @@ __global__ void RasterizePointsFineCudaKernel(
// Now we've looked at all the points for this bin, so we can write
// output for the current pixel.
BubbleSort(q, q_size);
const int pix_idx = n * S * S * K + yi * S * K + xi * K;
// Reverse ordering of the X and Y axis as the camera coordinates
// assume that +Y is pointing up and +X is pointing left.
const int yidx = S - 1 - yi;
const int xidx = S - 1 - xi;
const int pix_idx = n * S * S * K + yidx * S * K + xidx * K;
for (int k = 0; k < q_size; ++k) {
point_idxs[pix_idx + k] = q[k].idx;
zbuf[pix_idx + k] = q[k].z;
@ -448,7 +450,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFineCuda(
const int bin_size,
const int points_per_pixel) {
const int N = bin_points.size(0);
const int B = bin_points.size(1);
const int B = bin_points.size(1); // num_bins
const int M = bin_points.size(3);
const int S = image_size;
const int K = points_per_pixel;


@ -125,13 +125,13 @@ torch::Tensor RasterizePointsCoarseCpu(
const int point_stop_idx =
(point_start_idx + num_points_per_cloud[n].item().to<int32_t>());
float bin_y_max = 1.0f;
float bin_y_min = bin_y_max - bin_width;
float bin_y_min = -1.0f;
float bin_y_max = bin_y_min + bin_width;
// Iterate through the horizontal bins from top to bottom.
for (int by = 0; by < B; by++) {
float bin_x_max = 1.0f;
float bin_x_min = bin_x_max - bin_width;
float bin_x_min = -1.0f;
float bin_x_max = bin_x_min + bin_width;
// Iterate through bins on this horizontal line, left to right.
for (int bx = 0; bx < B; bx++) {
@ -166,13 +166,13 @@ torch::Tensor RasterizePointsCoarseCpu(
// Record the number of points found in this bin
points_per_bin_a[n][by][bx] = points_hit;
// Shift the bin to the left for the next loop iteration.
bin_x_max = bin_x_min;
bin_x_min = bin_x_min - bin_width;
// Shift the bin to the right for the next loop iteration
bin_x_min = bin_x_max;
bin_x_max = bin_x_min + bin_width;
}
// Shift the bin down for the next loop iteration.
bin_y_max = bin_y_min;
bin_y_min = bin_y_min - bin_width;
// Shift the bin down for the next loop iteration
bin_y_min = bin_y_max;
bin_y_max = bin_y_min + bin_width;
}
}
return bin_points;


@ -9,7 +9,7 @@ from pytorch3d import _C
# TODO make the epsilon user configurable
kEpsilon = 1e-30
kEpsilon = 1e-8
def rasterize_meshes(


@ -19,12 +19,28 @@ class PointFragments(NamedTuple):
# Class to store the point rasterization params with defaults
class PointsRasterizationSettings(NamedTuple):
image_size: int = 256
radius: float = 0.01
points_per_pixel: int = 8
bin_size: Optional[int] = None
max_points_per_bin: Optional[int] = None
class PointsRasterizationSettings:
__slots__ = [
"image_size",
"radius",
"points_per_pixel",
"bin_size",
"max_points_per_bin",
]
def __init__(
self,
image_size: int = 256,
radius: float = 0.01,
points_per_pixel: int = 8,
bin_size: Optional[int] = None,
max_points_per_bin: Optional[int] = None,
):
self.image_size = image_size
self.radius = radius
self.points_per_pixel = points_per_pixel
self.bin_size = bin_size
self.max_points_per_bin = max_points_per_bin
class PointsRasterizer(nn.Module):
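
The NamedTuple is replaced by a plain __slots__ class because the new tests mutate the settings in place (toggling bin_size between naive and coarse-to-fine), and NamedTuple fields are read-only. A short usage sketch:

```python
settings = PointsRasterizationSettings(image_size=512, radius=0.01)
# With the old NamedTuple these assignments would raise AttributeError;
# the plain class allows the in-place toggle used by the tests below.
settings.bin_size = 0     # force naive rasterization
settings.bin_size = None  # let the rasterizer choose coarse-to-fine
```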


@ -1,10 +1,20 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from pathlib import Path
from typing import Callable, Optional, Union
import numpy as np
import torch
from PIL import Image
def load_rgb_image(filename: str, data_dir: Union[str, Path]):
filepath = data_dir / filename
with Image.open(filepath) as raw_image:
image = torch.from_numpy(np.array(raw_image) / 255.0)
image = image.to(dtype=torch.float32)
return image[..., :3]
TensorOrArray = Union[torch.Tensor, np.ndarray]
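
With the helper in common_testing, both the mesh and point rendering tests can share it. A hypothetical usage (the filename is for illustration only):

```python
import torch
from pathlib import Path

DATA_DIR = Path(__file__).resolve().parent / "data"
# Returns a float32 tensor scaled to [0, 1] with any alpha channel dropped.
image_ref = load_rgb_image("test_simple_sphere_light_phong.png", DATA_DIR)
assert image_ref.dtype == torch.float32 and image_ref.shape[-1] == 3
```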

(Two new binary reference images added: 74 KiB and 2.1 KiB.)


@ -896,10 +896,10 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
torch.ones((1, 2, 2, max_faces_per_bin), dtype=torch.int32, device=device)
* -1
)
bin_faces_expected[0, 0, 0, 0] = torch.tensor([1])
bin_faces_expected[0, 1, 0, 0:2] = torch.tensor([1, 2])
bin_faces_expected[0, 0, 1, 0:2] = torch.tensor([0, 1])
bin_faces_expected[0, 1, 1, 0] = torch.tensor([1])
bin_faces_expected[0, 0, 1, 0:2] = torch.tensor([1, 2])
bin_faces_expected[0, 1, 0, 0:2] = torch.tensor([0, 1])
bin_faces_expected[0, 0, 0, 0] = torch.tensor([1])
# +Y up, +X left, +Z in
bin_faces = _C._rasterize_meshes_coarse(
@ -911,7 +911,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
bin_size,
max_faces_per_bin,
)
# Flip x and y axis of output before comparing to expected
bin_faces_same = (bin_faces.squeeze() == bin_faces_expected).all()
self.assertTrue(bin_faces_same.item() == 1)


@ -434,23 +434,21 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
def _test_coarse_rasterize(self, device):
#
# Note that +Y is up and +X is left in the diagram below.
#
# (4) |2
# |
# |
# |
# |1
# |
# (1) |
# | (2)
# ____________(0)__(5)___________________
# 2 1 | -1 -2
# |
# (3) |
# |
# |-1
# |
# |2 (4)
# |
# |
# |
# |1
# |
# | (1)
# (2)|
# _________(5)___(0)_______________
# -1 | 1 2
# |
# | (3)
# |
# |-1
#
# Locations of the points are shown by o. The screen bounding box
# is between [-1, 1] in both the x and y directions.
@ -486,9 +484,9 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
# fit in one chunk. This will be the case for this small example, but
# to properly exercise coordinated writes among multiple chunks we need
# to use a bigger test case.
bin_points_expected[0, 1, 0, :2] = torch.tensor([0, 3])
bin_points_expected[0, 0, 1, 0] = torch.tensor([2])
bin_points_expected[0, 0, 0, :2] = torch.tensor([0, 1])
bin_points_expected[0, 0, 1, :2] = torch.tensor([0, 3])
bin_points_expected[0, 1, 0, 0] = torch.tensor([2])
bin_points_expected[0, 1, 1, :2] = torch.tensor([0, 1])
pointclouds = Pointclouds(points=[points])
args = (
@ -502,4 +500,5 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
)
bin_points = _C._rasterize_points_coarse(*args)
bin_points_same = (bin_points == bin_points_expected).all()
self.assertTrue(bin_points_same.item() == 1)


@ -9,6 +9,7 @@ from pathlib import Path
import numpy as np
import torch
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
@ -35,15 +36,7 @@ DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"
def load_rgb_image(filename, data_dir=DATA_DIR):
filepath = data_dir / filename
with Image.open(filepath) as raw_image:
image = torch.from_numpy(np.array(raw_image) / 255.0)
image = image.to(dtype=torch.float32)
return image[..., :3]
class TestRenderingMeshes(unittest.TestCase):
class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
def test_simple_sphere(self, elevated_camera=False):
"""
Test output of phong and gouraud shading matches a reference image using
@ -81,7 +74,7 @@ class TestRenderingMeshes(unittest.TestCase):
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
@ -96,14 +89,14 @@ class TestRenderingMeshes(unittest.TestCase):
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
image_ref = load_rgb_image("test_%s" % filename)
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_" % filename
filename = "DEBUG_%s" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
self.assertClose(rgb, image_ref, atol=0.05)
########################################################
# Move the light to the +z axis in world space so it is
@ -124,8 +117,10 @@ class TestRenderingMeshes(unittest.TestCase):
)
# Load reference image
image_ref_phong_dark = load_rgb_image("test_simple_sphere_dark%s.png" % postfix)
self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05))
image_ref_phong_dark = load_rgb_image(
"test_simple_sphere_dark%s.png" % postfix, DATA_DIR
)
self.assertClose(rgb, image_ref_phong_dark, atol=0.05)
def test_simple_sphere_elevated_camera(self):
"""
@ -160,7 +155,7 @@ class TestRenderingMeshes(unittest.TestCase):
R, T = look_at_view_transform(dist, elev, azim)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
# Init shader settings
@ -179,10 +174,12 @@ class TestRenderingMeshes(unittest.TestCase):
shader = shader_init(lights=lights, cameras=cameras, materials=materials)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_meshes)
image_ref = load_rgb_image("test_simple_sphere_light_%s.png" % name)
image_ref = load_rgb_image(
"test_simple_sphere_light_%s.png" % name, DATA_DIR
)
for i in range(batch_size):
rgb = images[i, ..., :3].squeeze().cpu()
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
self.assertClose(rgb, image_ref, atol=0.05)
def test_silhouette_with_grad(self):
"""
@ -200,7 +197,6 @@ class TestRenderingMeshes(unittest.TestCase):
image_size=512,
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
faces_per_pixel=80,
bin_size=0,
)
# Init rasterizer settings
@ -222,7 +218,7 @@ class TestRenderingMeshes(unittest.TestCase):
with Image.open(image_ref_filename) as raw_image_ref:
image_ref = torch.from_numpy(np.array(raw_image_ref))
image_ref = image_ref.to(dtype=torch.float32) / 255.0
self.assertTrue(torch.allclose(alpha, image_ref, atol=0.055))
self.assertClose(alpha, image_ref, atol=0.055)
# Check grad exist
verts.requires_grad = True
@ -237,8 +233,8 @@ class TestRenderingMeshes(unittest.TestCase):
The pupils in the eyes of the cow should always be looking to the left.
"""
device = torch.device("cuda:0")
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
obj_filename = DATA_DIR / "cow_mesh/cow.obj"
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
obj_filename = obj_dir / "cow_mesh/cow.obj"
# Load mesh + texture
mesh = load_objs_as_meshes([obj_filename], device=device)
@ -247,7 +243,7 @@ class TestRenderingMeshes(unittest.TestCase):
R, T = look_at_view_transform(2.7, 0, 0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
# Init shader settings
@ -265,22 +261,26 @@ class TestRenderingMeshes(unittest.TestCase):
lights=lights, cameras=cameras, materials=materials
),
)
images = renderer(mesh)
rgb = images[0, ..., :3].squeeze().cpu()
# Load reference image
image_ref = load_rgb_image("test_texture_map_back.png")
image_ref = load_rgb_image("test_texture_map_back.png", DATA_DIR)
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_texture_map_back.png"
)
for bin_size in [0, None]:
# Check both naive and coarse to fine produce the same output.
renderer.rasterizer.raster_settings.bin_size = bin_size
images = renderer(mesh)
rgb = images[0, ..., :3].squeeze().cpu()
# NOTE some pixels can be flaky and will not lead to
# `cond1` being true. Add `cond2` and check `cond1 or cond2`
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
self.assertTrue(cond1 or cond2)
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_texture_map_back.png"
)
# NOTE some pixels can be flaky and will not lead to
# `cond1` being true. Add `cond2` and check `cond1 or cond2`
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
self.assertTrue(cond1 or cond2)
# Check grad exists
[verts] = mesh.verts_list()
@ -299,16 +299,27 @@ class TestRenderingMeshes(unittest.TestCase):
# Move light to the front of the cow in world space
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
images = renderer(mesh, cameras=cameras, lights=lights)
rgb = images[0, ..., :3].squeeze().cpu()
# Load reference image
image_ref = load_rgb_image("test_texture_map_front.png")
image_ref = load_rgb_image("test_texture_map_front.png", DATA_DIR)
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_texture_map_front.png"
)
for bin_size in [0, None]:
# Check both naive and coarse to fine produce the same output.
renderer.rasterizer.raster_settings.bin_size = bin_size
images = renderer(mesh, cameras=cameras, lights=lights)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_texture_map_front.png"
)
# NOTE some pixels can be flaky and will not lead to
# `cond1` being true. Add `cond2` and check `cond1 or cond2`
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
self.assertTrue(cond1 or cond2)
#################################
# Add blurring to rasterization
@ -320,23 +331,26 @@ class TestRenderingMeshes(unittest.TestCase):
image_size=512,
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
faces_per_pixel=100,
bin_size=0,
)
images = renderer(
mesh.clone(),
cameras=cameras,
raster_settings=raster_settings,
blend_params=blend_params,
)
rgb = images[0, ..., :3].squeeze().cpu()
# Load reference image
image_ref = load_rgb_image("test_blurry_textured_rendering.png")
image_ref = load_rgb_image("test_blurry_textured_rendering.png", DATA_DIR)
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_blurry_textured_rendering.png"
for bin_size in [0, None]:
# Check both naive and coarse to fine produce the same output.
renderer.rasterizer.raster_settings.bin_size = bin_size
images = renderer(
mesh.clone(),
cameras=cameras,
raster_settings=raster_settings,
blend_params=blend_params,
)
rgb = images[0, ..., :3].squeeze().cpu()
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_blurry_textured_rendering.png"
)
self.assertClose(rgb, image_ref, atol=0.05)

tests/test_render_points.py (new file, 173 additions)

@ -0,0 +1,173 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
Sanity checks for output images from the pointcloud renderer.
"""
import unittest
import warnings
from os import path
from pathlib import Path
import numpy as np
import torch
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.renderer.cameras import (
OpenGLOrthographicCameras,
OpenGLPerspectiveCameras,
look_at_view_transform,
)
from pytorch3d.renderer.points import (
AlphaCompositor,
NormWeightedCompositor,
PointsRasterizationSettings,
PointsRasterizer,
PointsRenderer,
)
from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere
# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"
class TestRenderPoints(TestCaseMixin, unittest.TestCase):
def test_simple_sphere(self):
device = torch.device("cuda:0")
sphere_mesh = ico_sphere(1, device)
verts_padded = sphere_mesh.verts_padded()
# Shift vertices to check coordinate frames are correct.
verts_padded[..., 1] += 0.2
verts_padded[..., 0] += 0.2
pointclouds = Pointclouds(
points=verts_padded, features=torch.ones_like(verts_padded)
)
R, T = look_at_view_transform(2.7, 0.0, 0.0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
raster_settings = PointsRasterizationSettings(
image_size=256, radius=5e-2, points_per_pixel=1
)
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
compositor = NormWeightedCompositor()
renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
# Load reference image
filename = "simple_pointcloud_sphere.png"
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
for bin_size in [0, None]:
# Check both naive and coarse to fine produce the same output.
renderer.rasterizer.raster_settings.bin_size = bin_size
images = renderer(pointclouds)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_%s" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertClose(rgb, image_ref)
def test_pointcloud_with_features(self):
device = torch.device("cuda:0")
file_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
pointcloud_filename = file_dir / "PittsburghBridge/pointcloud.npz"
# Note, this file is too large to check in to the repo.
# Download the file to run the test locally.
if not path.exists(pointcloud_filename):
url = "https://dl.fbaipublicfiles.com/pytorch3d/data/PittsburghBridge/pointcloud.npz"
msg = (
"pointcloud.npz not found, download from %s, save it at the path %s, and rerun"
% (url, pointcloud_filename)
)
warnings.warn(msg)
return True
# Load point cloud
pointcloud = np.load(pointcloud_filename)
verts = torch.Tensor(pointcloud["verts"]).to(device)
rgb_feats = torch.Tensor(pointcloud["rgb"]).to(device)
verts.requires_grad = True
rgb_feats.requires_grad = True
point_cloud = Pointclouds(points=[verts], features=[rgb_feats])
R, T = look_at_view_transform(20, 10, 0)
cameras = OpenGLOrthographicCameras(device=device, R=R, T=T, znear=0.01)
raster_settings = PointsRasterizationSettings(
# Set image_size so it is not a multiple of 16 (min bin_size)
# in order to confirm that there are no errors in coarse rasterization.
image_size=500,
radius=0.003,
points_per_pixel=10,
)
renderer = PointsRenderer(
rasterizer=PointsRasterizer(
cameras=cameras, raster_settings=raster_settings
),
compositor=AlphaCompositor(),
)
images = renderer(point_cloud)
# Load reference image
filename = "bridge_pointcloud.png"
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
for bin_size in [0, None]:
# Check both naive and coarse to fine produce the same output.
renderer.rasterizer.raster_settings.bin_size = bin_size
images = renderer(point_cloud)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_%s" % filename
Image.fromarray((rgb.detach().numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertClose(rgb, image_ref, atol=0.015)
# Check grad exists.
grad_images = torch.randn_like(images)
images.backward(grad_images)
self.assertIsNotNone(verts.grad)
self.assertIsNotNone(rgb_feats.grad)
def test_simple_sphere_batched(self):
device = torch.device("cuda:0")
sphere_mesh = ico_sphere(1, device)
verts_padded = sphere_mesh.verts_padded()
verts_padded[..., 1] += 0.2
verts_padded[..., 0] += 0.2
pointclouds = Pointclouds(
points=verts_padded, features=torch.ones_like(verts_padded)
)
batch_size = 20
pointclouds = pointclouds.extend(batch_size)
R, T = look_at_view_transform(2.7, 0.0, 0.0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
raster_settings = PointsRasterizationSettings(
image_size=256, radius=5e-2, points_per_pixel=1
)
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
compositor = NormWeightedCompositor()
renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
# Load reference image
filename = "simple_pointcloud_sphere.png"
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
images = renderer(pointclouds)
for i in range(batch_size):
rgb = images[i, ..., :3].squeeze().cpu()
if i == 0 and DEBUG:
filename = "DEBUG_%s" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertClose(rgb, image_ref)