mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
coarse rasterization bug fix
Summary: Fix a bug which resulted in a rendering artifacts if the image size was not a multiple of 16. Fix: Revert coarse rasterization to original implementation and only update fine rasterization to reverse the ordering of Y and X axis. This is much simpler than the previous approach! Additional changes: - updated mesh rendering end-end tests to check outputs from both naive and coarse to fine rasterization. - added pointcloud rendering end-end tests Reviewed By: gkioxari Differential Revision: D21102725 fbshipit-source-id: 2e7e1b013dd6dd12b3a00b79eb8167deddb2e89a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
1e4749602d
commit
9ef1ee8455
@@ -1,10 +1,20 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def load_rgb_image(filename: str, data_dir: Union[str, Path]):
|
||||
filepath = data_dir / filename
|
||||
with Image.open(filepath) as raw_image:
|
||||
image = torch.from_numpy(np.array(raw_image) / 255.0)
|
||||
image = image.to(dtype=torch.float32)
|
||||
return image[..., :3]
|
||||
|
||||
|
||||
TensorOrArray = Union[torch.Tensor, np.ndarray]
|
||||
|
||||
BIN
tests/data/test_bridge_pointcloud.png
Normal file
BIN
tests/data/test_bridge_pointcloud.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 74 KiB |
BIN
tests/data/test_simple_pointcloud_sphere.png
Normal file
BIN
tests/data/test_simple_pointcloud_sphere.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.1 KiB |
@@ -896,10 +896,10 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
||||
torch.ones((1, 2, 2, max_faces_per_bin), dtype=torch.int32, device=device)
|
||||
* -1
|
||||
)
|
||||
bin_faces_expected[0, 0, 0, 0] = torch.tensor([1])
|
||||
bin_faces_expected[0, 1, 0, 0:2] = torch.tensor([1, 2])
|
||||
bin_faces_expected[0, 0, 1, 0:2] = torch.tensor([0, 1])
|
||||
bin_faces_expected[0, 1, 1, 0] = torch.tensor([1])
|
||||
bin_faces_expected[0, 0, 1, 0:2] = torch.tensor([1, 2])
|
||||
bin_faces_expected[0, 1, 0, 0:2] = torch.tensor([0, 1])
|
||||
bin_faces_expected[0, 0, 0, 0] = torch.tensor([1])
|
||||
|
||||
# +Y up, +X left, +Z in
|
||||
bin_faces = _C._rasterize_meshes_coarse(
|
||||
@@ -911,7 +911,7 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
||||
bin_size,
|
||||
max_faces_per_bin,
|
||||
)
|
||||
# Flip x and y axis of output before comparing to expected
|
||||
|
||||
bin_faces_same = (bin_faces.squeeze() == bin_faces_expected).all()
|
||||
self.assertTrue(bin_faces_same.item() == 1)
|
||||
|
||||
|
||||
@@ -434,23 +434,21 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
def _test_coarse_rasterize(self, device):
|
||||
#
|
||||
# Note that +Y is up and +X is left in the diagram below.
|
||||
#
|
||||
# (4) |2
|
||||
# |
|
||||
# |
|
||||
# |
|
||||
# |1
|
||||
# |
|
||||
# (1) |
|
||||
# | (2)
|
||||
# ____________(0)__(5)___________________
|
||||
# 2 1 | -1 -2
|
||||
# |
|
||||
# (3) |
|
||||
# |
|
||||
# |-1
|
||||
# |
|
||||
# |2 (4)
|
||||
# |
|
||||
# |
|
||||
# |
|
||||
# |1
|
||||
# |
|
||||
# | (1)
|
||||
# (2)|
|
||||
# _________(5)___(0)_______________
|
||||
# -1 | 1 2
|
||||
# |
|
||||
# | (3)
|
||||
# |
|
||||
# |-1
|
||||
#
|
||||
# Locations of the points are shown by o. The screen bounding box
|
||||
# is between [-1, 1] in both the x and y directions.
|
||||
@@ -486,9 +484,9 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
||||
# fit in one chunk. This will the the case for this small example, but
|
||||
# to properly exercise coordianted writes among multiple chunks we need
|
||||
# to use a bigger test case.
|
||||
bin_points_expected[0, 1, 0, :2] = torch.tensor([0, 3])
|
||||
bin_points_expected[0, 0, 1, 0] = torch.tensor([2])
|
||||
bin_points_expected[0, 0, 0, :2] = torch.tensor([0, 1])
|
||||
bin_points_expected[0, 0, 1, :2] = torch.tensor([0, 3])
|
||||
bin_points_expected[0, 1, 0, 0] = torch.tensor([2])
|
||||
bin_points_expected[0, 1, 1, :2] = torch.tensor([0, 1])
|
||||
|
||||
pointclouds = Pointclouds(points=[points])
|
||||
args = (
|
||||
@@ -502,4 +500,5 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
bin_points = _C._rasterize_points_coarse(*args)
|
||||
bin_points_same = (bin_points == bin_points_expected).all()
|
||||
|
||||
self.assertTrue(bin_points_same.item() == 1)
|
||||
|
||||
@@ -9,6 +9,7 @@ from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin, load_rgb_image
|
||||
from PIL import Image
|
||||
from pytorch3d.io import load_objs_as_meshes
|
||||
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
|
||||
@@ -35,15 +36,7 @@ DEBUG = False
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
|
||||
|
||||
def load_rgb_image(filename, data_dir=DATA_DIR):
|
||||
filepath = data_dir / filename
|
||||
with Image.open(filepath) as raw_image:
|
||||
image = torch.from_numpy(np.array(raw_image) / 255.0)
|
||||
image = image.to(dtype=torch.float32)
|
||||
return image[..., :3]
|
||||
|
||||
|
||||
class TestRenderingMeshes(unittest.TestCase):
|
||||
class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
def test_simple_sphere(self, elevated_camera=False):
|
||||
"""
|
||||
Test output of phong and gouraud shading matches a reference image using
|
||||
@@ -81,7 +74,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
||||
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||
)
|
||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
|
||||
@@ -96,14 +89,14 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||
images = renderer(sphere_mesh)
|
||||
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
|
||||
image_ref = load_rgb_image("test_%s" % filename)
|
||||
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
filename = "DEBUG_" % filename
|
||||
filename = "DEBUG_%s" % filename
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / filename
|
||||
)
|
||||
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
########################################################
|
||||
# Move the light to the +z axis in world space so it is
|
||||
@@ -124,8 +117,10 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
)
|
||||
|
||||
# Load reference image
|
||||
image_ref_phong_dark = load_rgb_image("test_simple_sphere_dark%s.png" % postfix)
|
||||
self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05))
|
||||
image_ref_phong_dark = load_rgb_image(
|
||||
"test_simple_sphere_dark%s.png" % postfix, DATA_DIR
|
||||
)
|
||||
self.assertClose(rgb, image_ref_phong_dark, atol=0.05)
|
||||
|
||||
def test_simple_sphere_elevated_camera(self):
|
||||
"""
|
||||
@@ -160,7 +155,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
R, T = look_at_view_transform(dist, elev, azim)
|
||||
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||
)
|
||||
|
||||
# Init shader settings
|
||||
@@ -179,10 +174,12 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
shader = shader_init(lights=lights, cameras=cameras, materials=materials)
|
||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||
images = renderer(sphere_meshes)
|
||||
image_ref = load_rgb_image("test_simple_sphere_light_%s.png" % name)
|
||||
image_ref = load_rgb_image(
|
||||
"test_simple_sphere_light_%s.png" % name, DATA_DIR
|
||||
)
|
||||
for i in range(batch_size):
|
||||
rgb = images[i, ..., :3].squeeze().cpu()
|
||||
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
def test_silhouette_with_grad(self):
|
||||
"""
|
||||
@@ -200,7 +197,6 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
image_size=512,
|
||||
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
|
||||
faces_per_pixel=80,
|
||||
bin_size=0,
|
||||
)
|
||||
|
||||
# Init rasterizer settings
|
||||
@@ -222,7 +218,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
with Image.open(image_ref_filename) as raw_image_ref:
|
||||
image_ref = torch.from_numpy(np.array(raw_image_ref))
|
||||
image_ref = image_ref.to(dtype=torch.float32) / 255.0
|
||||
self.assertTrue(torch.allclose(alpha, image_ref, atol=0.055))
|
||||
self.assertClose(alpha, image_ref, atol=0.055)
|
||||
|
||||
# Check grad exist
|
||||
verts.requires_grad = True
|
||||
@@ -237,8 +233,8 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
The pupils in the eyes of the cow should always be looking to the left.
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
||||
obj_filename = DATA_DIR / "cow_mesh/cow.obj"
|
||||
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
||||
obj_filename = obj_dir / "cow_mesh/cow.obj"
|
||||
|
||||
# Load mesh + texture
|
||||
mesh = load_objs_as_meshes([obj_filename], device=device)
|
||||
@@ -247,7 +243,7 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
R, T = look_at_view_transform(2.7, 0, 0)
|
||||
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||
)
|
||||
|
||||
# Init shader settings
|
||||
@@ -265,22 +261,26 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
lights=lights, cameras=cameras, materials=materials
|
||||
),
|
||||
)
|
||||
images = renderer(mesh)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
|
||||
# Load reference image
|
||||
image_ref = load_rgb_image("test_texture_map_back.png")
|
||||
image_ref = load_rgb_image("test_texture_map_back.png", DATA_DIR)
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_texture_map_back.png"
|
||||
)
|
||||
for bin_size in [0, None]:
|
||||
# Check both naive and coarse to fine produce the same output.
|
||||
renderer.rasterizer.raster_settings.bin_size = bin_size
|
||||
images = renderer(mesh)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
|
||||
# NOTE some pixels can be flaky and will not lead to
|
||||
# `cond1` being true. Add `cond2` and check `cond1 or cond2`
|
||||
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
|
||||
cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
|
||||
self.assertTrue(cond1 or cond2)
|
||||
if DEBUG:
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_texture_map_back.png"
|
||||
)
|
||||
|
||||
# NOTE some pixels can be flaky and will not lead to
|
||||
# `cond1` being true. Add `cond2` and check `cond1 or cond2`
|
||||
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
|
||||
cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
|
||||
self.assertTrue(cond1 or cond2)
|
||||
|
||||
# Check grad exists
|
||||
[verts] = mesh.verts_list()
|
||||
@@ -299,16 +299,27 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
|
||||
# Move light to the front of the cow in world space
|
||||
lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
|
||||
images = renderer(mesh, cameras=cameras, lights=lights)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
|
||||
# Load reference image
|
||||
image_ref = load_rgb_image("test_texture_map_front.png")
|
||||
image_ref = load_rgb_image("test_texture_map_front.png", DATA_DIR)
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_texture_map_front.png"
|
||||
)
|
||||
for bin_size in [0, None]:
|
||||
# Check both naive and coarse to fine produce the same output.
|
||||
renderer.rasterizer.raster_settings.bin_size = bin_size
|
||||
|
||||
images = renderer(mesh, cameras=cameras, lights=lights)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_texture_map_front.png"
|
||||
)
|
||||
|
||||
# NOTE some pixels can be flaky and will not lead to
|
||||
# `cond1` being true. Add `cond2` and check `cond1 or cond2`
|
||||
cond1 = torch.allclose(rgb, image_ref, atol=0.05)
|
||||
cond2 = ((rgb - image_ref).abs() > 0.05).sum() < 5
|
||||
self.assertTrue(cond1 or cond2)
|
||||
|
||||
#################################
|
||||
# Add blurring to rasterization
|
||||
@@ -320,23 +331,26 @@ class TestRenderingMeshes(unittest.TestCase):
|
||||
image_size=512,
|
||||
blur_radius=np.log(1.0 / 1e-4 - 1.0) * blend_params.sigma,
|
||||
faces_per_pixel=100,
|
||||
bin_size=0,
|
||||
)
|
||||
|
||||
images = renderer(
|
||||
mesh.clone(),
|
||||
cameras=cameras,
|
||||
raster_settings=raster_settings,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
|
||||
# Load reference image
|
||||
image_ref = load_rgb_image("test_blurry_textured_rendering.png")
|
||||
image_ref = load_rgb_image("test_blurry_textured_rendering.png", DATA_DIR)
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_blurry_textured_rendering.png"
|
||||
for bin_size in [0, None]:
|
||||
# Check both naive and coarse to fine produce the same output.
|
||||
renderer.rasterizer.raster_settings.bin_size = bin_size
|
||||
|
||||
images = renderer(
|
||||
mesh.clone(),
|
||||
cameras=cameras,
|
||||
raster_settings=raster_settings,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
|
||||
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
|
||||
if DEBUG:
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_blurry_textured_rendering.png"
|
||||
)
|
||||
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
173
tests/test_render_points.py
Normal file
173
tests/test_render_points.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
"""
|
||||
Sanity checks for output images from the pointcloud renderer.
|
||||
"""
|
||||
import unittest
|
||||
import warnings
|
||||
from os import path
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from common_testing import TestCaseMixin, load_rgb_image
|
||||
from PIL import Image
|
||||
from pytorch3d.renderer.cameras import (
|
||||
OpenGLOrthographicCameras,
|
||||
OpenGLPerspectiveCameras,
|
||||
look_at_view_transform,
|
||||
)
|
||||
from pytorch3d.renderer.points import (
|
||||
AlphaCompositor,
|
||||
NormWeightedCompositor,
|
||||
PointsRasterizationSettings,
|
||||
PointsRasterizer,
|
||||
PointsRenderer,
|
||||
)
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||
|
||||
|
||||
# If DEBUG=True, save out images generated in the tests for debugging.
|
||||
# All saved images have prefix DEBUG_
|
||||
DEBUG = False
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
|
||||
|
||||
class TestRenderPoints(TestCaseMixin, unittest.TestCase):
|
||||
def test_simple_sphere(self):
|
||||
device = torch.device("cuda:0")
|
||||
sphere_mesh = ico_sphere(1, device)
|
||||
verts_padded = sphere_mesh.verts_padded()
|
||||
# Shift vertices to check coordinate frames are correct.
|
||||
verts_padded[..., 1] += 0.2
|
||||
verts_padded[..., 0] += 0.2
|
||||
pointclouds = Pointclouds(
|
||||
points=verts_padded, features=torch.ones_like(verts_padded)
|
||||
)
|
||||
R, T = look_at_view_transform(2.7, 0.0, 0.0)
|
||||
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
|
||||
raster_settings = PointsRasterizationSettings(
|
||||
image_size=256, radius=5e-2, points_per_pixel=1
|
||||
)
|
||||
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
compositor = NormWeightedCompositor()
|
||||
renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
|
||||
|
||||
# Load reference image
|
||||
filename = "simple_pointcloud_sphere.png"
|
||||
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
|
||||
|
||||
for bin_size in [0, None]:
|
||||
# Check both naive and coarse to fine produce the same output.
|
||||
renderer.rasterizer.raster_settings.bin_size = bin_size
|
||||
images = renderer(pointclouds)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
filename = "DEBUG_%s" % filename
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / filename
|
||||
)
|
||||
self.assertClose(rgb, image_ref)
|
||||
|
||||
def test_pointcloud_with_features(self):
|
||||
device = torch.device("cuda:0")
|
||||
file_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
||||
pointcloud_filename = file_dir / "PittsburghBridge/pointcloud.npz"
|
||||
|
||||
# Note, this file is too large to check in to the repo.
|
||||
# Download the file to run the test locally.
|
||||
if not path.exists(pointcloud_filename):
|
||||
url = "https://dl.fbaipublicfiles.com/pytorch3d/data/PittsburghBridge/pointcloud.npz"
|
||||
msg = (
|
||||
"pointcloud.npz not found, download from %s, save it at the path %s, and rerun"
|
||||
% (url, pointcloud_filename)
|
||||
)
|
||||
warnings.warn(msg)
|
||||
return True
|
||||
|
||||
# Load point cloud
|
||||
pointcloud = np.load(pointcloud_filename)
|
||||
verts = torch.Tensor(pointcloud["verts"]).to(device)
|
||||
rgb_feats = torch.Tensor(pointcloud["rgb"]).to(device)
|
||||
|
||||
verts.requires_grad = True
|
||||
rgb_feats.requires_grad = True
|
||||
point_cloud = Pointclouds(points=[verts], features=[rgb_feats])
|
||||
|
||||
R, T = look_at_view_transform(20, 10, 0)
|
||||
cameras = OpenGLOrthographicCameras(device=device, R=R, T=T, znear=0.01)
|
||||
|
||||
raster_settings = PointsRasterizationSettings(
|
||||
# Set image_size so it is not a multiple of 16 (min bin_size)
|
||||
# in order to confirm that there are no errors in coarse rasterization.
|
||||
image_size=500,
|
||||
radius=0.003,
|
||||
points_per_pixel=10,
|
||||
)
|
||||
|
||||
renderer = PointsRenderer(
|
||||
rasterizer=PointsRasterizer(
|
||||
cameras=cameras, raster_settings=raster_settings
|
||||
),
|
||||
compositor=AlphaCompositor(),
|
||||
)
|
||||
|
||||
images = renderer(point_cloud)
|
||||
|
||||
# Load reference image
|
||||
filename = "bridge_pointcloud.png"
|
||||
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
|
||||
|
||||
for bin_size in [0, None]:
|
||||
# Check both naive and coarse to fine produce the same output.
|
||||
renderer.rasterizer.raster_settings.bin_size = bin_size
|
||||
images = renderer(point_cloud)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
filename = "DEBUG_%s" % filename
|
||||
Image.fromarray((rgb.detach().numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / filename
|
||||
)
|
||||
self.assertClose(rgb, image_ref, atol=0.015)
|
||||
|
||||
# Check grad exists.
|
||||
grad_images = torch.randn_like(images)
|
||||
images.backward(grad_images)
|
||||
self.assertIsNotNone(verts.grad)
|
||||
self.assertIsNotNone(rgb_feats.grad)
|
||||
|
||||
def test_simple_sphere_batched(self):
|
||||
device = torch.device("cuda:0")
|
||||
sphere_mesh = ico_sphere(1, device)
|
||||
verts_padded = sphere_mesh.verts_padded()
|
||||
verts_padded[..., 1] += 0.2
|
||||
verts_padded[..., 0] += 0.2
|
||||
pointclouds = Pointclouds(
|
||||
points=verts_padded, features=torch.ones_like(verts_padded)
|
||||
)
|
||||
batch_size = 20
|
||||
pointclouds = pointclouds.extend(batch_size)
|
||||
R, T = look_at_view_transform(2.7, 0.0, 0.0)
|
||||
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
|
||||
raster_settings = PointsRasterizationSettings(
|
||||
image_size=256, radius=5e-2, points_per_pixel=1
|
||||
)
|
||||
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
compositor = NormWeightedCompositor()
|
||||
renderer = PointsRenderer(rasterizer=rasterizer, compositor=compositor)
|
||||
|
||||
# Load reference image
|
||||
filename = "simple_pointcloud_sphere.png"
|
||||
image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
|
||||
|
||||
images = renderer(pointclouds)
|
||||
for i in range(batch_size):
|
||||
rgb = images[i, ..., :3].squeeze().cpu()
|
||||
if i == 0 and DEBUG:
|
||||
filename = "DEBUG_%s" % filename
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / filename
|
||||
)
|
||||
self.assertClose(rgb, image_ref)
|
||||
Reference in New Issue
Block a user