mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
renderer: add support for rendering high dimensional textures for classification/segmentation use cases (#1248)
Summary: For 3D segmentation problems it's really useful to be able to train the models from multiple viewpoints using Pytorch3D as the renderer. Currently due to hardcoded assumptions in a few spots the mesh renderer only supports rendering RGB (3 dimensional) data. You can encode the classification information as 3 channel data but if you have more than 3 classes you're out of luck. This relaxes the assumptions to make rendering semantic classes work with `HardFlatShader` and `AmbientLights` with no diffusion/specular. The other shaders/lights don't make any sense for classification since they mutate the texture values in some way. This only requires changes in `Materials` and `AmbientLights`. The bulk of the code is the unit test. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1248 Test Plan: Added unit test that renders a 5 dimensional texture and compare dimensions 2-5 to a stored picture. Reviewed By: bottler Differential Revision: D37764610 Pulled By: d4l3k fbshipit-source-id: 031895724d9318a6f6bab5b31055bb3f438176a5
This commit is contained in:
committed by
Facebook GitHub Bot
parent
aa8b03f31d
commit
8d10ba52b2
BIN
tests/data/test_nd_sphere.png
Normal file
BIN
tests/data/test_nd_sphere.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 56 KiB |
@@ -1236,3 +1236,81 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
"test_simple_sphere_light_phong_%s.png" % cam_type.__name__, DATA_DIR
|
||||
)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
def test_nd_sphere(self):
|
||||
"""
|
||||
Test that the render can handle textures with more than 3 channels and
|
||||
not just 3 channel RGB.
|
||||
"""
|
||||
torch.manual_seed(1)
|
||||
device = torch.device("cuda:0")
|
||||
C = 5
|
||||
WHITE = ((1.0,) * C,)
|
||||
BLACK = ((0.0,) * C,)
|
||||
|
||||
# Init mesh
|
||||
sphere_mesh = ico_sphere(5, device)
|
||||
verts_padded = sphere_mesh.verts_padded()
|
||||
faces_padded = sphere_mesh.faces_padded()
|
||||
feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
|
||||
n_verts = feats.shape[1]
|
||||
# make some non-uniform pattern
|
||||
feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
|
||||
textures = TexturesVertex(verts_features=feats)
|
||||
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
|
||||
|
||||
# No elevation or azimuth rotation
|
||||
R, T = look_at_view_transform(2.7, 0.0, 0.0)
|
||||
|
||||
cameras = PerspectiveCameras(device=device, R=R, T=T)
|
||||
|
||||
# Init shader settings
|
||||
materials = Materials(
|
||||
device=device,
|
||||
ambient_color=WHITE,
|
||||
diffuse_color=WHITE,
|
||||
specular_color=WHITE,
|
||||
)
|
||||
lights = AmbientLights(
|
||||
device=device,
|
||||
ambient_color=WHITE,
|
||||
)
|
||||
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
||||
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||
)
|
||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
blend_params = BlendParams(
|
||||
1e-4,
|
||||
1e-4,
|
||||
background_color=BLACK[0],
|
||||
)
|
||||
|
||||
# only test HardFlatShader since that's the only one that makes
|
||||
# sense for classification
|
||||
shader = HardFlatShader(
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||
images = renderer(sphere_mesh)
|
||||
|
||||
self.assertEqual(images.shape[-1], C + 1)
|
||||
self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
|
||||
self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)
|
||||
|
||||
# grab last 3 color channels
|
||||
rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
|
||||
filename = "test_nd_sphere.png"
|
||||
|
||||
if DEBUG:
|
||||
debug_filename = "DEBUG_%s" % filename
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / debug_filename
|
||||
)
|
||||
|
||||
image_ref = load_rgb_image(filename, DATA_DIR)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
Reference in New Issue
Block a user