mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
renderer: add support for rendering high dimensional textures for classification/segmentation use cases (#1248)
Summary: For 3D segmentation problems it's really useful to be able to train the models from multiple viewpoints using Pytorch3D as the renderer. Currently due to hardcoded assumptions in a few spots the mesh renderer only supports rendering RGB (3 dimensional) data. You can encode the classification information as 3 channel data but if you have more than 3 classes you're out of luck. This relaxes the assumptions to make rendering semantic classes work with `HardFlatShader` and `AmbientLights` with no diffusion/specular. The other shaders/lights don't make any sense for classification since they mutate the texture values in some way. This only requires changes in `Materials` and `AmbientLights`. The bulk of the code is the unit test. Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1248 Test Plan: Added unit test that renders a 5 dimensional texture and compare dimensions 2-5 to a stored picture. Reviewed By: bottler Differential Revision: D37764610 Pulled By: d4l3k fbshipit-source-id: 031895724d9318a6f6bab5b31055bb3f438176a5
This commit is contained in:
parent
aa8b03f31d
commit
8d10ba52b2
@ -292,6 +292,9 @@ class AmbientLights(TensorProperties):
|
|||||||
A light object representing the same color of light everywhere.
|
A light object representing the same color of light everywhere.
|
||||||
By default, this is white, which effectively means lighting is
|
By default, this is white, which effectively means lighting is
|
||||||
not used in rendering.
|
not used in rendering.
|
||||||
|
|
||||||
|
Unlike other lights this supports an arbitrary number of channels, not just 3 for RGB.
|
||||||
|
The ambient_color input determines the number of channels.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
|
def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
|
||||||
@ -304,9 +307,11 @@ class AmbientLights(TensorProperties):
|
|||||||
device: Device (as str or torch.device) on which the tensors should be located
|
device: Device (as str or torch.device) on which the tensors should be located
|
||||||
|
|
||||||
The ambient_color if provided, should be
|
The ambient_color if provided, should be
|
||||||
- 3 element tuple/list or list of lists
|
- tuple/list of C-element tuples of floats
|
||||||
- torch tensor of shape (1, 3)
|
- torch tensor of shape (1, C)
|
||||||
- torch tensor of shape (N, 3)
|
- torch tensor of shape (N, C)
|
||||||
|
where C is the number of channels and N is batch size.
|
||||||
|
For RGB, C is 3.
|
||||||
"""
|
"""
|
||||||
if ambient_color is None:
|
if ambient_color is None:
|
||||||
ambient_color = ((1.0, 1.0, 1.0),)
|
ambient_color = ((1.0, 1.0, 1.0),)
|
||||||
@ -317,10 +322,14 @@ class AmbientLights(TensorProperties):
|
|||||||
return super().clone(other)
|
return super().clone(other)
|
||||||
|
|
||||||
def diffuse(self, normals, points) -> torch.Tensor:
|
def diffuse(self, normals, points) -> torch.Tensor:
|
||||||
return torch.zeros_like(points)
|
return self._zeros_channels(points)
|
||||||
|
|
||||||
def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
|
def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
|
||||||
return torch.zeros_like(points)
|
return self._zeros_channels(points)
|
||||||
|
|
||||||
|
def _zeros_channels(self, points: torch.Tensor) -> torch.Tensor:
|
||||||
|
ch = self.ambient_color.shape[-1]
|
||||||
|
return torch.zeros(*points.shape[:-1], ch, device=points.device)
|
||||||
|
|
||||||
|
|
||||||
def _validate_light_properties(obj) -> None:
|
def _validate_light_properties(obj) -> None:
|
||||||
|
@ -27,9 +27,9 @@ class Materials(TensorProperties):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
ambient_color: RGB ambient reflectivity of the material
|
ambient_color: ambient reflectivity of the material
|
||||||
diffuse_color: RGB diffuse reflectivity of the material
|
diffuse_color: diffuse reflectivity of the material
|
||||||
specular_color: RGB specular reflectivity of the material
|
specular_color: specular reflectivity of the material
|
||||||
shininess: The specular exponent for the material. This defines
|
shininess: The specular exponent for the material. This defines
|
||||||
the focus of the specular highlight with a high value
|
the focus of the specular highlight with a high value
|
||||||
resulting in a concentrated highlight. Shininess values
|
resulting in a concentrated highlight. Shininess values
|
||||||
@ -37,7 +37,8 @@ class Materials(TensorProperties):
|
|||||||
device: Device (as str or torch.device) on which the tensors should be located
|
device: Device (as str or torch.device) on which the tensors should be located
|
||||||
|
|
||||||
ambient_color, diffuse_color and specular_color can be of shape
|
ambient_color, diffuse_color and specular_color can be of shape
|
||||||
(1, 3) or (N, 3). shininess can be of shape (1) or (N).
|
(1, C) or (N, C) where C is typically 3 (for RGB). shininess can be of shape (1,)
|
||||||
|
or (N,).
|
||||||
|
|
||||||
The colors and shininess are broadcast against each other so need to
|
The colors and shininess are broadcast against each other so need to
|
||||||
have either the same batch dimension or batch dimension = 1.
|
have either the same batch dimension or batch dimension = 1.
|
||||||
@ -49,11 +50,12 @@ class Materials(TensorProperties):
|
|||||||
specular_color=specular_color,
|
specular_color=specular_color,
|
||||||
shininess=shininess,
|
shininess=shininess,
|
||||||
)
|
)
|
||||||
|
C = self.ambient_color.shape[-1]
|
||||||
for n in ["ambient_color", "diffuse_color", "specular_color"]:
|
for n in ["ambient_color", "diffuse_color", "specular_color"]:
|
||||||
t = getattr(self, n)
|
t = getattr(self, n)
|
||||||
if t.shape[-1] != 3:
|
if t.shape[-1] != C:
|
||||||
msg = "Expected %s to have shape (N, 3); got %r"
|
msg = "Expected %s to have shape (N, %d); got %r"
|
||||||
raise ValueError(msg % (n, t.shape))
|
raise ValueError(msg % (n, C, t.shape))
|
||||||
if self.shininess.shape != torch.Size([self._N]):
|
if self.shininess.shape != torch.Size([self._N]):
|
||||||
msg = "shininess should have shape (N); got %r"
|
msg = "shininess should have shape (N); got %r"
|
||||||
raise ValueError(msg % repr(self.shininess.shape))
|
raise ValueError(msg % repr(self.shininess.shape))
|
||||||
|
BIN
tests/data/test_nd_sphere.png
Normal file
BIN
tests/data/test_nd_sphere.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 56 KiB |
@ -1236,3 +1236,81 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
"test_simple_sphere_light_phong_%s.png" % cam_type.__name__, DATA_DIR
|
"test_simple_sphere_light_phong_%s.png" % cam_type.__name__, DATA_DIR
|
||||||
)
|
)
|
||||||
self.assertClose(rgb, image_ref, atol=0.05)
|
self.assertClose(rgb, image_ref, atol=0.05)
|
||||||
|
|
||||||
|
def test_nd_sphere(self):
|
||||||
|
"""
|
||||||
|
Test that the render can handle textures with more than 3 channels and
|
||||||
|
not just 3 channel RGB.
|
||||||
|
"""
|
||||||
|
torch.manual_seed(1)
|
||||||
|
device = torch.device("cuda:0")
|
||||||
|
C = 5
|
||||||
|
WHITE = ((1.0,) * C,)
|
||||||
|
BLACK = ((0.0,) * C,)
|
||||||
|
|
||||||
|
# Init mesh
|
||||||
|
sphere_mesh = ico_sphere(5, device)
|
||||||
|
verts_padded = sphere_mesh.verts_padded()
|
||||||
|
faces_padded = sphere_mesh.faces_padded()
|
||||||
|
feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
|
||||||
|
n_verts = feats.shape[1]
|
||||||
|
# make some non-uniform pattern
|
||||||
|
feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
|
||||||
|
textures = TexturesVertex(verts_features=feats)
|
||||||
|
sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
|
||||||
|
|
||||||
|
# No elevation or azimuth rotation
|
||||||
|
R, T = look_at_view_transform(2.7, 0.0, 0.0)
|
||||||
|
|
||||||
|
cameras = PerspectiveCameras(device=device, R=R, T=T)
|
||||||
|
|
||||||
|
# Init shader settings
|
||||||
|
materials = Materials(
|
||||||
|
device=device,
|
||||||
|
ambient_color=WHITE,
|
||||||
|
diffuse_color=WHITE,
|
||||||
|
specular_color=WHITE,
|
||||||
|
)
|
||||||
|
lights = AmbientLights(
|
||||||
|
device=device,
|
||||||
|
ambient_color=WHITE,
|
||||||
|
)
|
||||||
|
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
||||||
|
|
||||||
|
raster_settings = RasterizationSettings(
|
||||||
|
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||||
|
)
|
||||||
|
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||||
|
blend_params = BlendParams(
|
||||||
|
1e-4,
|
||||||
|
1e-4,
|
||||||
|
background_color=BLACK[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
# only test HardFlatShader since that's the only one that makes
|
||||||
|
# sense for classification
|
||||||
|
shader = HardFlatShader(
|
||||||
|
lights=lights,
|
||||||
|
cameras=cameras,
|
||||||
|
materials=materials,
|
||||||
|
blend_params=blend_params,
|
||||||
|
)
|
||||||
|
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||||
|
images = renderer(sphere_mesh)
|
||||||
|
|
||||||
|
self.assertEqual(images.shape[-1], C + 1)
|
||||||
|
self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
|
||||||
|
self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)
|
||||||
|
|
||||||
|
# grab last 3 color channels
|
||||||
|
rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
|
||||||
|
filename = "test_nd_sphere.png"
|
||||||
|
|
||||||
|
if DEBUG:
|
||||||
|
debug_filename = "DEBUG_%s" % filename
|
||||||
|
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||||
|
DATA_DIR / debug_filename
|
||||||
|
)
|
||||||
|
|
||||||
|
image_ref = load_rgb_image(filename, DATA_DIR)
|
||||||
|
self.assertClose(rgb, image_ref, atol=0.05)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user