renderer: add support for rendering high dimensional textures for classification/segmentation use cases (#1248)

Summary:
For 3D segmentation problems it's really useful to be able to train the models from multiple viewpoints using Pytorch3D as the renderer. Currently due to hardcoded assumptions in a few spots the mesh renderer only supports rendering RGB (3 dimensional) data. You can encode the classification information as 3 channel data but if you have more than 3 classes you're out of luck.

This relaxes the assumptions to make rendering semantic classes work with `HardFlatShader` and `AmbientLights` with no diffusion/specular. The other shaders/lights don't make any sense for classification since they mutate the texture values in some way.

This only requires changes in `Materials` and `AmbientLights`. The bulk of the code is the unit test.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1248

Test Plan: Added unit test that renders a 5 dimensional texture and compare dimensions 2-5 to a stored picture.

Reviewed By: bottler

Differential Revision: D37764610

Pulled By: d4l3k

fbshipit-source-id: 031895724d9318a6f6bab5b31055bb3f438176a5
This commit is contained in:
Tristan Rice 2022-07-11 21:22:45 -07:00 committed by Facebook GitHub Bot
parent aa8b03f31d
commit 8d10ba52b2
4 changed files with 101 additions and 12 deletions

View File

@ -292,6 +292,9 @@ class AmbientLights(TensorProperties):
A light object representing the same color of light everywhere. A light object representing the same color of light everywhere.
By default, this is white, which effectively means lighting is By default, this is white, which effectively means lighting is
not used in rendering. not used in rendering.
Unlike other lights this supports an arbitrary number of channels, not just 3 for RGB.
The ambient_color input determines the number of channels.
""" """
def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None: def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
@ -304,9 +307,11 @@ class AmbientLights(TensorProperties):
device: Device (as str or torch.device) on which the tensors should be located device: Device (as str or torch.device) on which the tensors should be located
The ambient_color if provided, should be The ambient_color if provided, should be
- 3 element tuple/list or list of lists - tuple/list of C-element tuples of floats
- torch tensor of shape (1, 3) - torch tensor of shape (1, C)
- torch tensor of shape (N, 3) - torch tensor of shape (N, C)
where C is the number of channels and N is batch size.
For RGB, C is 3.
""" """
if ambient_color is None: if ambient_color is None:
ambient_color = ((1.0, 1.0, 1.0),) ambient_color = ((1.0, 1.0, 1.0),)
@ -317,10 +322,14 @@ class AmbientLights(TensorProperties):
return super().clone(other) return super().clone(other)
def diffuse(self, normals, points) -> torch.Tensor: def diffuse(self, normals, points) -> torch.Tensor:
return torch.zeros_like(points) return self._zeros_channels(points)
def specular(self, normals, points, camera_position, shininess) -> torch.Tensor: def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
return torch.zeros_like(points) return self._zeros_channels(points)
def _zeros_channels(self, points: torch.Tensor) -> torch.Tensor:
ch = self.ambient_color.shape[-1]
return torch.zeros(*points.shape[:-1], ch, device=points.device)
def _validate_light_properties(obj) -> None: def _validate_light_properties(obj) -> None:

View File

@ -27,9 +27,9 @@ class Materials(TensorProperties):
) -> None: ) -> None:
""" """
Args: Args:
ambient_color: RGB ambient reflectivity of the material ambient_color: ambient reflectivity of the material
diffuse_color: RGB diffuse reflectivity of the material diffuse_color: diffuse reflectivity of the material
specular_color: RGB specular reflectivity of the material specular_color: specular reflectivity of the material
shininess: The specular exponent for the material. This defines shininess: The specular exponent for the material. This defines
the focus of the specular highlight with a high value the focus of the specular highlight with a high value
resulting in a concentrated highlight. Shininess values resulting in a concentrated highlight. Shininess values
@ -37,7 +37,8 @@ class Materials(TensorProperties):
device: Device (as str or torch.device) on which the tensors should be located device: Device (as str or torch.device) on which the tensors should be located
ambient_color, diffuse_color and specular_color can be of shape ambient_color, diffuse_color and specular_color can be of shape
(1, 3) or (N, 3). shininess can be of shape (1) or (N). (1, C) or (N, C) where C is typically 3 (for RGB). shininess can be of shape (1,)
or (N,).
The colors and shininess are broadcast against each other so need to The colors and shininess are broadcast against each other so need to
have either the same batch dimension or batch dimension = 1. have either the same batch dimension or batch dimension = 1.
@ -49,11 +50,12 @@ class Materials(TensorProperties):
specular_color=specular_color, specular_color=specular_color,
shininess=shininess, shininess=shininess,
) )
C = self.ambient_color.shape[-1]
for n in ["ambient_color", "diffuse_color", "specular_color"]: for n in ["ambient_color", "diffuse_color", "specular_color"]:
t = getattr(self, n) t = getattr(self, n)
if t.shape[-1] != 3: if t.shape[-1] != C:
msg = "Expected %s to have shape (N, 3); got %r" msg = "Expected %s to have shape (N, %d); got %r"
raise ValueError(msg % (n, t.shape)) raise ValueError(msg % (n, C, t.shape))
if self.shininess.shape != torch.Size([self._N]): if self.shininess.shape != torch.Size([self._N]):
msg = "shininess should have shape (N); got %r" msg = "shininess should have shape (N); got %r"
raise ValueError(msg % repr(self.shininess.shape)) raise ValueError(msg % repr(self.shininess.shape))

Binary file not shown.

After

Width:  |  Height:  |  Size: 56 KiB

View File

@ -1236,3 +1236,81 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
"test_simple_sphere_light_phong_%s.png" % cam_type.__name__, DATA_DIR "test_simple_sphere_light_phong_%s.png" % cam_type.__name__, DATA_DIR
) )
self.assertClose(rgb, image_ref, atol=0.05) self.assertClose(rgb, image_ref, atol=0.05)
def test_nd_sphere(self):
"""
Test that the render can handle textures with more than 3 channels and
not just 3 channel RGB.
"""
torch.manual_seed(1)
device = torch.device("cuda:0")
C = 5
WHITE = ((1.0,) * C,)
BLACK = ((0.0,) * C,)
# Init mesh
sphere_mesh = ico_sphere(5, device)
verts_padded = sphere_mesh.verts_padded()
faces_padded = sphere_mesh.faces_padded()
feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
n_verts = feats.shape[1]
# make some non-uniform pattern
feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
textures = TexturesVertex(verts_features=feats)
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
# No elevation or azimuth rotation
R, T = look_at_view_transform(2.7, 0.0, 0.0)
cameras = PerspectiveCameras(device=device, R=R, T=T)
# Init shader settings
materials = Materials(
device=device,
ambient_color=WHITE,
diffuse_color=WHITE,
specular_color=WHITE,
)
lights = AmbientLights(
device=device,
ambient_color=WHITE,
)
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
raster_settings = RasterizationSettings(
image_size=512, blur_radius=0.0, faces_per_pixel=1
)
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
blend_params = BlendParams(
1e-4,
1e-4,
background_color=BLACK[0],
)
# only test HardFlatShader since that's the only one that makes
# sense for classification
shader = HardFlatShader(
lights=lights,
cameras=cameras,
materials=materials,
blend_params=blend_params,
)
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
images = renderer(sphere_mesh)
self.assertEqual(images.shape[-1], C + 1)
self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)
# grab last 3 color channels
rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
filename = "test_nd_sphere.png"
if DEBUG:
debug_filename = "DEBUG_%s" % filename
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
DATA_DIR / debug_filename
)
image_ref = load_rgb_image(filename, DATA_DIR)
self.assertClose(rgb, image_ref, atol=0.05)