Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)
renderer: add support for rendering high dimensional textures for classification/segmentation use cases (#1248)
Summary:
For 3D segmentation problems it is very useful to be able to train models from multiple viewpoints, using PyTorch3D as the renderer. Currently, due to hardcoded assumptions in a few spots, the mesh renderer only supports rendering RGB (3-channel) data. You can encode classification information as 3-channel data, but if you have more than 3 classes you are out of luck.

This relaxes those assumptions so that rendering semantic classes works with `HardFlatShader` and `AmbientLights` with no diffuse/specular component. The other shaders/lights don't make sense for classification since they mutate the texture values in some way. Only `Materials` and `AmbientLights` need changes; the bulk of the code is the unit test.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1248

Test Plan: Added a unit test that renders a 5-channel texture and compares dimensions 2-5 to a stored reference image.

Reviewed By: bottler

Differential Revision: D37764610

Pulled By: d4l3k

fbshipit-source-id: 031895724d9318a6f6bab5b31055bb3f438176a5
This commit is contained in:
parent aa8b03f31d
commit 8d10ba52b2
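As a rough illustration of the use case described in the summary (not part of this commit), per-vertex class labels can be packed into a K-channel `TexturesVertex` and rendered with `HardFlatShader` plus a purely ambient light, then read back with an argmax. The mesh, the random labels and the channel count `K` below are made up for the sketch; the API calls mirror the ones used in the new unit test.

import torch
import torch.nn.functional as F
from pytorch3d.renderer import (
    AmbientLights, BlendParams, HardFlatShader, Materials, MeshRasterizer,
    MeshRenderer, PerspectiveCameras, RasterizationSettings, TexturesVertex,
    look_at_view_transform,
)
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
K = 5  # number of semantic classes

mesh = ico_sphere(3, device)
verts, faces = mesh.verts_padded(), mesh.faces_padded()
# Fake per-vertex class labels; in practice these come from annotations.
labels = torch.randint(0, K, (verts.shape[1],), device=device)
feats = F.one_hot(labels, K).float()[None]  # (1, V, K) one-hot "texture"
mesh = Meshes(verts=verts, faces=faces, textures=TexturesVertex(feats))

R, T = look_at_view_transform(2.7, 0.0, 0.0)
cameras = PerspectiveCameras(device=device, R=R, T=T)
white = ((1.0,) * K,)  # K-channel ambient color; no diffuse/specular is added
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=RasterizationSettings(image_size=256),
    ),
    shader=HardFlatShader(
        device=device,
        cameras=cameras,
        lights=AmbientLights(device=device, ambient_color=white),
        materials=Materials(
            device=device,
            ambient_color=white,
            diffuse_color=white,
            specular_color=white,
        ),
        blend_params=BlendParams(background_color=(0.0,) * K),
    ),
)

images = renderer(mesh)                    # (1, 256, 256, K + 1): K classes + alpha
class_map = images[0, ..., :K].argmax(-1)  # per-pixel class id, (256, 256)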
pytorch3d/renderer/lighting.py

@@ -292,6 +292,9 @@ class AmbientLights(TensorProperties):
     A light object representing the same color of light everywhere.
     By default, this is white, which effectively means lighting is
     not used in rendering.
+
+    Unlike other lights this supports an arbitrary number of channels, not just 3 for RGB.
+    The ambient_color input determines the number of channels.
     """

     def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
@@ -304,9 +307,11 @@ class AmbientLights(TensorProperties):
             device: Device (as str or torch.device) on which the tensors should be located

         The ambient_color if provided, should be
-            - 3 element tuple/list or list of lists
-            - torch tensor of shape (1, 3)
-            - torch tensor of shape (N, 3)
+            - tuple/list of C-element tuples of floats
+            - torch tensor of shape (1, C)
+            - torch tensor of shape (N, C)
+        where C is the number of channels and N is batch size.
+        For RGB, C is 3.
         """
         if ambient_color is None:
             ambient_color = ((1.0, 1.0, 1.0),)
@@ -317,10 +322,14 @@ class AmbientLights(TensorProperties):
         return super().clone(other)

     def diffuse(self, normals, points) -> torch.Tensor:
-        return torch.zeros_like(points)
+        return self._zeros_channels(points)

     def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
-        return torch.zeros_like(points)
+        return self._zeros_channels(points)
+
+    def _zeros_channels(self, points: torch.Tensor) -> torch.Tensor:
+        ch = self.ambient_color.shape[-1]
+        return torch.zeros(*points.shape[:-1], ch, device=points.device)


 def _validate_light_properties(obj) -> None:
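For context, a tiny sketch (assuming the change above) of what the new `_zeros_channels` helper buys: an `AmbientLights` built from a 5-channel color now returns 5-channel zero diffuse/specular terms instead of zeros shaped like the 3-D points tensor, so every stage of the shader stays at C channels. The shapes below are made up.

import torch
from pytorch3d.renderer import AmbientLights

lights = AmbientLights(ambient_color=((1.0, 1.0, 1.0, 1.0, 1.0),))
points = torch.rand(2, 10, 3)   # xyz positions, always 3 coordinates
normals = torch.rand(2, 10, 3)

# Both contributions are all-zero, but now match the 5 texture channels.
print(lights.diffuse(normals, points).shape)  # torch.Size([2, 10, 5])
print(lights.specular(normals, points, camera_position=None, shininess=None).shape)  # torch.Size([2, 10, 5])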
pytorch3d/renderer/materials.py

@@ -27,9 +27,9 @@ class Materials(TensorProperties):
     ) -> None:
         """
         Args:
-            ambient_color: RGB ambient reflectivity of the material
-            diffuse_color: RGB diffuse reflectivity of the material
-            specular_color: RGB specular reflectivity of the material
+            ambient_color: ambient reflectivity of the material
+            diffuse_color: diffuse reflectivity of the material
+            specular_color: specular reflectivity of the material
             shininess: The specular exponent for the material. This defines
                 the focus of the specular highlight with a high value
                 resulting in a concentrated highlight. Shininess values
@@ -37,7 +37,8 @@ class Materials(TensorProperties):
            device: Device (as str or torch.device) on which the tensors should be located

        ambient_color, diffuse_color and specular_color can be of shape
-        (1, 3) or (N, 3). shininess can be of shape (1) or (N).
+        (1, C) or (N, C) where C is typically 3 (for RGB). shininess can be of shape (1,)
+        or (N,).

        The colors and shininess are broadcast against each other so need to
        have either the same batch dimension or batch dimension = 1.
@@ -49,11 +50,12 @@ class Materials(TensorProperties):
             specular_color=specular_color,
             shininess=shininess,
         )
+        C = self.ambient_color.shape[-1]
         for n in ["ambient_color", "diffuse_color", "specular_color"]:
             t = getattr(self, n)
-            if t.shape[-1] != 3:
-                msg = "Expected %s to have shape (N, 3); got %r"
-                raise ValueError(msg % (n, t.shape))
+            if t.shape[-1] != C:
+                msg = "Expected %s to have shape (N, %d); got %r"
+                raise ValueError(msg % (n, C, t.shape))
         if self.shininess.shape != torch.Size([self._N]):
             msg = "shininess should have shape (N); got %r"
             raise ValueError(msg % repr(self.shininess.shape))
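A quick illustrative sketch (not from the diff) of what this validation change permits: the channel count C is now inferred from ambient_color, so matching C-channel colors are accepted and mismatched ones still raise. The 5-channel tuple below is made up.

import torch
from pytorch3d.renderer import Materials

five = ((0.2, 0.4, 0.6, 0.8, 1.0),)  # one 5-channel color
ok = Materials(ambient_color=five, diffuse_color=five, specular_color=five)
print(ok.ambient_color.shape)  # torch.Size([1, 5])

try:
    # diffuse_color/specular_color default to 3 channels, so this mismatches C=5
    Materials(ambient_color=five)
except ValueError as err:
    print(err)  # Expected diffuse_color to have shape (N, 5); got torch.Size([1, 3])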
BIN  tests/data/test_nd_sphere.png (new file; binary file not shown; 56 KiB)
tests/test_render_meshes.py

@@ -1236,3 +1236,81 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             "test_simple_sphere_light_phong_%s.png" % cam_type.__name__, DATA_DIR
         )
         self.assertClose(rgb, image_ref, atol=0.05)
+
+    def test_nd_sphere(self):
+        """
+        Test that the render can handle textures with more than 3 channels and
+        not just 3 channel RGB.
+        """
+        torch.manual_seed(1)
+        device = torch.device("cuda:0")
+        C = 5
+        WHITE = ((1.0,) * C,)
+        BLACK = ((0.0,) * C,)
+
+        # Init mesh
+        sphere_mesh = ico_sphere(5, device)
+        verts_padded = sphere_mesh.verts_padded()
+        faces_padded = sphere_mesh.faces_padded()
+        feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
+        n_verts = feats.shape[1]
+        # make some non-uniform pattern
+        feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
+        textures = TexturesVertex(verts_features=feats)
+        sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
+
+        # No elevation or azimuth rotation
+        R, T = look_at_view_transform(2.7, 0.0, 0.0)
+
+        cameras = PerspectiveCameras(device=device, R=R, T=T)
+
+        # Init shader settings
+        materials = Materials(
+            device=device,
+            ambient_color=WHITE,
+            diffuse_color=WHITE,
+            specular_color=WHITE,
+        )
+        lights = AmbientLights(
+            device=device,
+            ambient_color=WHITE,
+        )
+        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
+
+        raster_settings = RasterizationSettings(
+            image_size=512, blur_radius=0.0, faces_per_pixel=1
+        )
+        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
+        blend_params = BlendParams(
+            1e-4,
+            1e-4,
+            background_color=BLACK[0],
+        )
+
+        # only test HardFlatShader since that's the only one that makes
+        # sense for classification
+        shader = HardFlatShader(
+            lights=lights,
+            cameras=cameras,
+            materials=materials,
+            blend_params=blend_params,
+        )
+        renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
+        images = renderer(sphere_mesh)
+
+        self.assertEqual(images.shape[-1], C + 1)
+        self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
+        self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)
+
+        # grab last 3 color channels
+        rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
+        filename = "test_nd_sphere.png"
+
+        if DEBUG:
+            debug_filename = "DEBUG_%s" % filename
+            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
+                DATA_DIR / debug_filename
+            )
+
+        image_ref = load_rgb_image(filename, DATA_DIR)
+        self.assertClose(rgb, image_ref, atol=0.05)