coarse rasterization bug fix

Summary:
Fix a bug which resulted in rendering artifacts if the image size was not a multiple of 16.
Fix: revert coarse rasterization to the original implementation and only update fine rasterization to reverse the ordering of the Y and X axes. This is much simpler than the previous approach!
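
For context, a rough Python sketch of the coarse-to-fine scheme being fixed (the helper names here are hypothetical and the real kernels are C++/CUDA): the coarse pass buckets primitives into a ceil-divided grid of bins, which is why image sizes that are not a multiple of the minimum bin size (16) are the interesting case, and the fine pass then reads back the bin covering each pixel, where the Y/X ordering is what this commit adjusts.

import math

import torch


def coarse_rasterize_points(points_ndc, image_size, bin_size, max_points_per_bin):
    # Hypothetical sketch: bucket each point into the bin covering its pixel.
    # Ceil division so a partial bin covers the leftover pixels when
    # image_size is not a multiple of bin_size.
    num_bins = int(math.ceil(image_size / bin_size))
    bins = torch.full((num_bins, num_bins, max_points_per_bin), -1, dtype=torch.int64)
    counts = torch.zeros((num_bins, num_bins), dtype=torch.int64)
    for p, (x, y) in enumerate(points_ndc.tolist()):
        # NDC-to-pixel conversion assuming the +X left, +Y up convention
        # shown in the test diagrams below.
        px = (1.0 - x) * 0.5 * image_size
        py = (1.0 - y) * 0.5 * image_size
        bx = min(int(px // bin_size), num_bins - 1)
        by = min(int(py // bin_size), num_bins - 1)
        k = int(counts[by, bx])
        if k < max_points_per_bin:
            bins[by, bx, k] = p
            counts[by, bx] += 1
    return bins


def fine_lookup(bins, yi, xi, bin_size):
    # The fine pass reads the bin for pixel (yi, xi); getting this Y/X
    # ordering right is the point of the fix.
    return bins[yi // bin_size, xi // bin_size]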

Additional changes:
- updated mesh rendering end-to-end tests to check outputs from both naive and coarse-to-fine rasterization (a sketch of the shared check appears after this list).
- added pointcloud rendering end-to-end tests
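
The updated tests follow roughly this pattern (a sketch; `renderer`, `scene`, and `image_ref` are assumed to be constructed as in the test files below): render once with bin_size=0 (naive) and once with bin_size=None (coarse-to-fine with a heuristically chosen bin size), and compare both results to the same reference image.

import torch


def check_naive_and_coarse_to_fine(renderer, scene, image_ref, atol=0.05):
    # Sketch of the shared test pattern used in the mesh and pointcloud tests.
    for bin_size in [0, None]:
        # bin_size=0 selects the naive path; None lets the rasterizer pick
        # a bin size for the coarse-to-fine path.
        renderer.rasterizer.raster_settings.bin_size = bin_size
        images = renderer(scene)
        rgb = images[0, ..., :3].squeeze().cpu()
        assert torch.allclose(rgb, image_ref, atol=atol)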

Reviewed By: gkioxari

Differential Revision: D21102725

fbshipit-source-id: 2e7e1b013dd6dd12b3a00b79eb8167deddb2e89a
Author: Nikhila Ravi, 2020-04-20 14:51:19 -07:00
Committed by: Facebook GitHub Bot
Parent: 1e4749602d
Commit: 9ef1ee8455
15 changed files with 381 additions and 173 deletions


@@ -1,10 +1,20 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
from pathlib import Path
from typing import Callable, Optional, Union

import numpy as np
import torch
from PIL import Image


def load_rgb_image(filename: str, data_dir: Union[str, Path]):
    filepath = data_dir / filename
    with Image.open(filepath) as raw_image:
        image = torch.from_numpy(np.array(raw_image) / 255.0)
    image = image.to(dtype=torch.float32)
    return image[..., :3]


TensorOrArray = Union[torch.Tensor, np.ndarray]
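
A minimal usage sketch of this shared helper, mirroring how the render tests below call it (the filename here is illustrative, not a real asset name):

from pathlib import Path

from common_testing import load_rgb_image

DATA_DIR = Path(__file__).resolve().parent / "data"
# Loads an RGB reference image as a float32 tensor with values in [0, 1].
image_ref = load_rgb_image("test_some_reference.png", DATA_DIR)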

New binary file (not shown in diff): image, 74 KiB.

New binary file (not shown in diff): image, 2.1 KiB.


@@ -896,10 +896,10 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
torch.ones((1, 2, 2, max_faces_per_bin), dtype=torch.int32, device=device)
* -1
)
bin_faces_expected[0, 0, 0, 0] = torch.tensor([1])
bin_faces_expected[0, 1, 0, 0:2] = torch.tensor([1, 2])
bin_faces_expected[0, 0, 1, 0:2] = torch.tensor([0, 1])
bin_faces_expected[0, 1, 1, 0] = torch.tensor([1])
bin_faces_expected[0, 0, 1, 0:2] = torch.tensor([1, 2])
bin_faces_expected[0, 1, 0, 0:2] = torch.tensor([0, 1])
bin_faces_expected[0, 0, 0, 0] = torch.tensor([1])
# +Y up, +X left, +Z in
bin_faces = _C._rasterize_meshes_coarse(
@@ -911,7 +911,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
bin_size,
max_faces_per_bin,
)
# Flip x and y axis of output before comparing to expected
bin_faces_same = (bin_faces.squeeze() == bin_faces_expected).all()
self.assertTrue(bin_faces_same.item() == 1)


@@ -434,23 +434,21 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
def _test_coarse_rasterize(self, device):
#
# Note that +Y is up and +X is left in the diagram below.
#
# (4) |2
# |
# |
# |
# |1
# |
# (1) |
# | (2)
# ____________(0)__(5)___________________
# 2 1 | -1 -2
# |
# (3) |
# |
# |-1
# |
# |2 (4)
# |
# |
# |
# |1
# |
# | (1)
# (2)|
# _________(5)___(0)_______________
# -1 | 1 2
# |
# | (3)
# |
# |-1
#
# Locations of the points are shown by o. The screen bounding box
# is between [-1, 1] in both the x and y directions.
@@ -486,9 +484,9 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
# fit in one chunk. This will be the case for this small example, but
# to properly exercise coordinated writes among multiple chunks we need
# to use a bigger test case.
bin_points_expected[0, 1, 0, :2] = torch.tensor([0, 3])
bin_points_expected[0, 0, 1, 0] = torch.tensor([2])
bin_points_expected[0, 0, 0, :2] = torch.tensor([0, 1])
bin_points_expected[0, 0, 1, :2] = torch.tensor([0, 3])
bin_points_expected[0, 1, 0, 0] = torch.tensor([2])
bin_points_expected[0, 1, 1, :2] = torch.tensor([0, 1])
pointclouds = Pointclouds(points=[points])
args = (
@@ -502,4 +500,5 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
)
bin_points = _C._rasterize_points_coarse(*args)
bin_points_same = (bin_points == bin_points_expected).all()
self.assertTrue(bin_points_same.item() == 1)


@@ -9,6 +9,7 @@ from pathlib import Path
import numpy as np
import torch
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
@@ -35,15 +36,7 @@ DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"
def load_rgb_image(filename, data_dir=DATA_DIR):
filepath = data_dir / filename
with Image.open(filepath) as raw_image:
image = torch.from_numpy(np.array(raw_image) / 255.0)
image = image.to(dtype=torch.float32)
return image[..., :3]
class TestRenderingMeshes(unittest.TestCase):
class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
def test_simple_sphere(self, elevated_camera=False):
"""
Test output of phong and gouraud shading matches a reference image using
@@ -81,7 +74,7 @@ class TestRenderingMeshes(unittest.TestCase):
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
@@ -96,14 +89,14 @@ class TestRenderingMeshes(unittest.TestCase):
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
image_ref = load_rgb_image("test_%s" % filename)
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
filename = "DEBUG_" % filename
filename = "DEBUG_%s" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / filename
)
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
self.assertClose(rgb, image_ref, atol=0.05)
########################################################
# Move the light to the +z axis in world space so it is
@@ -124,8 +117,10 @@ class TestRenderingMeshes(unittest.TestCase):
)
# Load reference image
image_ref_phong_dark = load_rgb_image("test_simple_sphere_dark%s.png" % postfix)
self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05))
image_ref_phong_dark = load_rgb_image(
"test_simple_sphere_dark%s.png" % postfix, DATA_DIR
)
self.assertClose(rgb, image_ref_phong_dark, atol=0.05)
def test_simple_sphere_elevated_camera(self):
"""
@@ -160,7 +155,7 @@ class TestRenderingMeshes(unittest.TestCase):
R, T = look_at_view_transform(dist, elev, azim)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
# Init shader settings
@@ -179,10 +174,12 @@ class TestRenderingMeshes(unittest.TestCase):
shader = shader_init(lights=lights, cameras=cameras, materials=materials)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_meshes)
image_ref = load_rgb_image("test_simple_sphere_light_%s.png" % name)
image_ref = load_rgb_image(
"test_simple_sphere_light_%s.png" % name, DATA_DIR
)
for i in range(batch_size):
rgb = images[i, ..., :3].squeeze().cpu()
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
self.assertClose(rgb, image_ref, atol=0.05)
def test_silhouette_with_grad(self):
"""
@@ -200,7 +197,6 @@ class TestRenderingMeshes(unittest.TestCase):
image_size=512,
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
faces_per_pixel=80,
bin_size=0,
)
# Init rasterizer settings
@@ -222,7 +218,7 @@ class TestRenderingMeshes(unittest.TestCase):
with Image.open(image_ref_filename) as raw_image_ref:
image_ref = torch.from_numpy(np.array(raw_image_ref))
image_ref = image_ref.to(dtype=torch.float32) / 255.0
self.assertTrue(torch.allclose(alpha, image_ref, atol=0.055))
self.assertClose(alpha, image_ref, atol=0.055)
# Check grad exist
verts.requires_grad = True
@@ -237,8 +233,8 @@ class TestRenderingMeshes(unittest.TestCase):
The pupils in the eyes of the cow should always be looking to the left.
"""
device = torch.device("cuda:0")
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
obj_filename = DATA_DIR / "cow_mesh/cow.obj"
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
obj_filename = obj_dir / "cow_mesh/cow.obj"
# Load mesh + texture
mesh = load_objs_as_meshes([obj_filename], device=device)
@@ -247,7 +243,7 @@ class TestRenderingMeshes(unittest.TestCase):
R, T = look_at_view_transform(2.7, 0, 0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
# Init shader settings
@@ -265,22 +261,26 @@ class TestRenderingMeshes(unittest.TestCase):
lights=lights, cameras=cameras, materials=materials
),
)
images = renderer(mesh)
rgb = images[0, ..., :3].squeeze().cpu()
# Load reference image
image_ref = load_rgb_image("test_texture_map_back.png")
image_ref = load_rgb_image("test_texture_map_back.png", DATA_DIR)
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_texture_map_back.png"
)
for bin_size in [0, None]:
# Check both naive and coarse to fine produce the same output.
renderer.rasterizer.raster_settings.bin_size = bin_size
images = renderer(mesh)
rgb = images[0, ..., :3].squeeze().cpu()
# NOTE some pixels can be flaky and will not lead to
# `cond1` being true. Add `cond2` and check `cond1 or cond2`
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
self.assertTrue(cond1 or cond2)
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_texture_map_back.png"
)
# NOTE some pixels can be flaky and will not lead to
# `cond1` being true. Add `cond2` and check `cond1 or cond2`
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
self.assertTrue(cond1 or cond2)
# Check grad exists
[verts] = mesh.verts_list()
@@ -299,16 +299,27 @@ class TestRenderingMeshes(unittest.TestCase):
# Move light to the front of the cow in world space
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
images = renderer(mesh, cameras=cameras, lights=lights)
rgb = images[0, ..., :3].squeeze().cpu()
# Load reference image
image_ref = load_rgb_image("test_texture_map_front.png")
image_ref = load_rgb_image("test_texture_map_front.png", DATA_DIR)
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_texture_map_front.png"
)
for bin_size in [0, None]:
# Check both naive and coarse to fine produce the same output.
renderer.rasterizer.raster_settings.bin_size = bin_size
images = renderer(mesh, cameras=cameras, lights=lights)
rgb = images[0, ..., :3].squeeze().cpu()
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_texture_map_front.png"
)
# NOTE some pixels can be flaky and will not lead to
# `cond1` being true. Add `cond2` and check `cond1 or cond2`
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
self.assertTrue(cond1 or cond2)
#################################
# Add blurring to rasterization
@@ -320,23 +331,26 @@ class TestRenderingMeshes(unittest.TestCase):
image_size=512,
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
faces_per_pixel=100,
bin_size=0,
)
images = renderer(
mesh.clone(),
cameras=cameras,
raster_settings=raster_settings,
blend_params=blend_params,
)
rgb = images[0, ..., :3].squeeze().cpu()
# Load reference image
image_ref = load_rgb_image("test_blurry_textured_rendering.png")
image_ref = load_rgb_image("test_blurry_textured_rendering.png", DATA_DIR)
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_blurry_textured_rendering.png"
for bin_size in [0, None]:
# Check both naive and coarse to fine produce the same output.
renderer.rasterizer.raster_settings.bin_size = bin_size
images = renderer(
mesh.clone(),
cameras=cameras,
raster_settings=raster_settings,
blend_params=blend_params,
)
rgb = images[0, ..., :3].squeeze().cpu()
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
if DEBUG:
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / "DEBUG_blurry_textured_rendering.png"
)
self.assertClose(rgb, image_ref, atol=0.05)

tests/test_render_points.py (new file, 173 lines added)

@@ -0,0 +1,173 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

"""
Sanity checks for output images from the pointcloud renderer.
"""
import unittest
import warnings
from os import path
from pathlib import Path

import numpy as np
import torch
from common_testing import TestCaseMixin, load_rgb_image
from PIL import Image
from pytorch3d.renderer.cameras import (
    OpenGLOrthographicCameras,
    OpenGLPerspectiveCameras,
    look_at_view_transform,
)
from pytorch3d.renderer.points import (
    AlphaCompositor,
    NormWeightedCompositor,
    PointsRasterizationSettings,
    PointsRasterizer,
    PointsRenderer,
)
from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere

# If DEBUG=True, save out images generated in the tests for debugging.
# All saved images have prefix DEBUG_
DEBUG = False
DATA_DIR = Path(__file__).resolve().parent / "data"


class TestRenderPoints(TestCaseMixin, unittest.TestCase):
    def test_simple_sphere(self):
        device = torch.device("cuda:0")
        sphere_mesh = ico_sphere(1, device)
        verts_padded = sphere_mesh.verts_padded()
        # Shift vertices to check coordinate frames are correct.
        verts_padded[..., 1] += 0.2
        verts_padded[..., 0] += 0.2
        pointclouds = Pointclouds(
            points=verts_padded, features=torch.ones_like(verts_padded)
        )
        R, T = look_at_view_transform(2.7, 0.0, 0.0)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        raster_settings = PointsRasterizationSettings(
            image_size=256, radius=5e-2, points_per_pixel=1
        )
        rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
        compositor = NormWeightedCompositor()
        renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)

        # Load reference image
        filename = "simple_pointcloud_sphere.png"
        image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)

        for bin_size in [0, None]:
            # Check both naive and coarse to fine produce the same output.
            renderer.rasterizer.raster_settings.bin_size = bin_size
            images = renderer(pointclouds)
            rgb = images[0, ..., :3].squeeze().cpu()
            if DEBUG:
                filename = "DEBUG_%s" % filename
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / filename
                )
            self.assertClose(rgb, image_ref)

    def test_pointcloud_with_features(self):
        device = torch.device("cuda:0")
        file_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
        pointcloud_filename = file_dir / "PittsburghBridge/pointcloud.npz"

        # Note, this file is too large to check in to the repo.
        # Download the file to run the test locally.
        if not path.exists(pointcloud_filename):
            url = "https://dl.fbaipublicfiles.com/pytorch3d/data/PittsburghBridge/pointcloud.npz"
            msg = (
                "pointcloud.npz not found, download from %s, save it at the path %s, and rerun"
                % (url, pointcloud_filename)
            )
            warnings.warn(msg)
            return True

        # Load point cloud
        pointcloud = np.load(pointcloud_filename)
        verts = torch.Tensor(pointcloud["verts"]).to(device)
        rgb_feats = torch.Tensor(pointcloud["rgb"]).to(device)

        verts.requires_grad = True
        rgb_feats.requires_grad = True
        point_cloud = Pointclouds(points=[verts], features=[rgb_feats])

        R, T = look_at_view_transform(20, 10, 0)
        cameras = OpenGLOrthographicCameras(device=device, R=R, T=T, znear=0.01)

        raster_settings = PointsRasterizationSettings(
            # Set image_size so it is not a multiple of 16 (min bin_size)
            # in order to confirm that there are no errors in coarse rasterization.
            image_size=500,
            radius=0.003,
            points_per_pixel=10,
        )

        renderer = PointsRenderer(
            rasterizer=PointsRasterizer(
                cameras=cameras, raster_settings=raster_settings
            ),
            compositor=AlphaCompositor(),
        )
        images = renderer(point_cloud)

        # Load reference image
        filename = "bridge_pointcloud.png"
        image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)

        for bin_size in [0, None]:
            # Check both naive and coarse to fine produce the same output.
            renderer.rasterizer.raster_settings.bin_size = bin_size
            images = renderer(point_cloud)
            rgb = images[0, ..., :3].squeeze().cpu()
            if DEBUG:
                filename = "DEBUG_%s" % filename
                Image.fromarray((rgb.detach().numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / filename
                )
            self.assertClose(rgb, image_ref, atol=0.015)

        # Check grad exists.
        grad_images = torch.randn_like(images)
        images.backward(grad_images)
        self.assertIsNotNone(verts.grad)
        self.assertIsNotNone(rgb_feats.grad)

    def test_simple_sphere_batched(self):
        device = torch.device("cuda:0")
        sphere_mesh = ico_sphere(1, device)
        verts_padded = sphere_mesh.verts_padded()
        verts_padded[..., 1] += 0.2
        verts_padded[..., 0] += 0.2
        pointclouds = Pointclouds(
            points=verts_padded, features=torch.ones_like(verts_padded)
        )
        batch_size = 20
        pointclouds = pointclouds.extend(batch_size)
        R, T = look_at_view_transform(2.7, 0.0, 0.0)
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
        raster_settings = PointsRasterizationSettings(
            image_size=256, radius=5e-2, points_per_pixel=1
        )
        rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
        compositor = NormWeightedCompositor()
        renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)

        # Load reference image
        filename = "simple_pointcloud_sphere.png"
        image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
        images = renderer(pointclouds)

        for i in range(batch_size):
            rgb = images[i, ..., :3].squeeze().cpu()
            if i == 0 and DEBUG:
                filename = "DEBUG_%s" % filename
                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                    DATA_DIR / filename
                )
            self.assertClose(rgb, image_ref)