mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
Lighting broadcasting bug fix
Summary: Fixed multiple issues with shape broadcasting in lighting, shading and blending, and updated the tests.

Reviewed By: bottler
Differential Revision: D28997941
fbshipit-source-id: d3ef93f979344076b1d9098a86178b4da63844c8

This commit is contained in:
parent 9de627e01b
commit bc8361fa47
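For context, the underlying problem is a tensor-shape mismatch: interpolated points and normals coming out of rasterization have shape (N, H, W, K, 3), while light locations and ambient colors are stored per batch element as (N, 3). A minimal sketch of the mismatch and of the reshape idea used throughout this commit (illustrative shapes only, not code from the diff):

import torch

N, H, W, K = 2, 4, 4, 3
location = torch.rand(N, 3)          # one light location per batch element
points = torch.rand(N, H, W, K, 3)   # per-pixel points from rasterization

# Direct subtraction fails: (N, 3) does not broadcast against (N, H, W, K, 3).
# direction = location - points  # RuntimeError

# Inserting singleton dims, as the fix below does, lets broadcasting work.
direction = location[:, None, None, None, :] - points
print(direction.shape)  # torch.Size([2, 4, 4, 3, 3])
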
pytorch3d/renderer/blending.py

@@ -1,12 +1,11 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-
-from typing import NamedTuple, Sequence
+from typing import NamedTuple, Sequence, Union
 
 import torch
 from pytorch3d import _C  # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
 
 
 # Example functions for blending the top K colors per pixel using the outputs
 # from rasterization.
 # NOTE: All blending function should return an RGBA image per batch element
@@ -117,7 +116,11 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
 
 
 def softmax_rgb_blend(
-    colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100
+    colors,
+    fragments,
+    blend_params,
+    znear: Union[float, torch.Tensor] = 1.0,
+    zfar: Union[float, torch.Tensor] = 100,
 ) -> torch.Tensor:
     """
     RGB and alpha channel blending to return an RGBA image based on the method
@@ -184,11 +187,16 @@ def softmax_rgb_blend(
     # overflow. zbuf shape (N, H, W, K), find max over K.
     # TODO: there may still be some instability in the exponent calculation.
 
+    # Reshape to be compatible with (N, H, W, K) values in fragments
+    if torch.is_tensor(zfar):
+        # pyre-fixme[16]
+        zfar = zfar[:, None, None, None]
+    if torch.is_tensor(znear):
+        znear = znear[:, None, None, None]
+
     z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask
     # pyre-fixme[16]: `Tuple` has no attribute `values`.
-    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
     z_inv_max = torch.max(z_inv, dim=-1).values[..., None].clamp(min=eps)
-    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
     weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)
 
     # Also apply exp normalize trick for the background color weight.
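A standalone sketch of the znear/zfar handling added above, with made-up shapes; in the renderer these values would typically come from a batched camera's near/far clipping planes:

import torch

N, H, W, K = 2, 8, 8, 4
zbuf = torch.rand(N, H, W, K) * 50.0 + 1.0   # stand-in for fragments.zbuf
mask = (zbuf > 0).float()

znear = torch.tensor([0.1, 1.0])     # per batch element, shape (N,)
zfar = torch.tensor([60.0, 100.0])   # per batch element, shape (N,)

# Same reshape as above: (N,) -> (N, 1, 1, 1) so the clipping planes
# broadcast against the (N, H, W, K) depth buffer.
if torch.is_tensor(zfar):
    zfar = zfar[:, None, None, None]
if torch.is_tensor(znear):
    znear = znear[:, None, None, None]

z_inv = (zfar - zbuf) / (zfar - znear) * mask
print(z_inv.shape)  # torch.Size([2, 8, 8, 4])
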
pytorch3d/renderer/lighting.py

@@ -253,12 +253,26 @@ class PointLights(TensorProperties):
         other = self.__class__(device=self.device)
         return super().clone(other)
 
+    def reshape_location(self, points) -> torch.Tensor:
+        """
+        Reshape the location tensor to have dimensions
+        compatible with the points which can either be of
+        shape (P, 3) or (N, H, W, K, 3).
+        """
+        if self.location.ndim == points.ndim:
+            # pyre-fixme[7]
+            return self.location
+        # pyre-fixme[29]
+        return self.location[:, None, None, None, :]
+
     def diffuse(self, normals, points) -> torch.Tensor:
-        direction = self.location - points
+        location = self.reshape_location(points)
+        direction = location - points
         return diffuse(normals=normals, color=self.diffuse_color, direction=direction)
 
     def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
-        direction = self.location - points
+        location = self.reshape_location(points)
+        direction = location - points
         return specular(
             points=points,
             normals=normals,
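A hedged usage sketch of the new reshape_location path (illustrative shapes): a PointLights object holding one location per batch element can now shade per-pixel points of shape (N, H, W, K, 3), while packed (P, 3) inputs are left untouched.

import torch
from pytorch3d.renderer import PointLights

N, H, W, K = 2, 8, 8, 1
lights = PointLights(location=torch.rand(N, 3))

# Per-pixel points: the location is lifted to (N, 1, 1, 1, 3) before subtraction.
points = torch.rand(N, H, W, K, 3)
location = lights.reshape_location(points)
print(location.shape)             # torch.Size([2, 1, 1, 1, 3])
print((location - points).shape)  # torch.Size([2, 8, 8, 1, 3])

# Packed points: ndim already matches, so the location is returned as-is.
packed = torch.rand(100, 3)
print(lights.reshape_location(packed).shape)  # torch.Size([2, 3])
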
pytorch3d/renderer/mesh/shading.py

@@ -14,8 +14,8 @@ def _apply_lighting(
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Args:
-        points: torch tensor of shape (N, P, 3) or (P, 3).
-        normals: torch tensor of shape (N, P, 3) or (P, 3)
+        points: torch tensor of shape (N, ..., 3) or (P, 3).
+        normals: torch tensor of shape (N, ..., 3) or (P, 3)
         lights: instance of the Lights class.
         cameras: instance of the Cameras class.
         materials: instance of the Materials class.
@@ -35,6 +35,7 @@ def _apply_lighting(
     ambient_color = materials.ambient_color * lights.ambient_color
     diffuse_color = materials.diffuse_color * light_diffuse
     specular_color = materials.specular_color * light_specular
+
     if normals.dim() == 2 and points.dim() == 2:
         # If given packed inputs remove batch dim in output.
         return (
@@ -42,6 +43,11 @@ def _apply_lighting(
             diffuse_color.squeeze(),
             specular_color.squeeze(),
         )
+
+    if ambient_color.ndim != diffuse_color.ndim:
+        # Reshape from (N, 3) to have dimensions compatible with
+        # diffuse_color which is of shape (N, H, W, K, 3)
+        ambient_color = ambient_color[:, None, None, None, :]
     return ambient_color, diffuse_color, specular_color
 
 
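The ambient term has no positional dependence, so it keeps shape (N, 3) while the diffuse and specular terms pick up the per-pixel (N, H, W, K, 3) shape; the reshape above keeps the three components broadcast-compatible when the shaders later combine them. A small standalone illustration with assumed shapes:

import torch

N, H, W, K = 2, 8, 8, 4
ambient_color = torch.rand(N, 3)           # no spatial dependence
diffuse_color = torch.rand(N, H, W, K, 3)  # per-pixel contribution

if ambient_color.ndim != diffuse_color.ndim:
    ambient_color = ambient_color[:, None, None, None, :]

combined = ambient_color + diffuse_color   # broadcasts to (N, H, W, K, 3)
print(combined.shape)  # torch.Size([2, 8, 8, 4, 3])
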
tests/test_render_meshes.py

@@ -6,6 +6,7 @@ Sanity checks for output images from the renderer.
 """
 import os
 import unittest
+from collections import namedtuple
 
 import numpy as np
 import torch
@@ -53,6 +54,8 @@ DEBUG = False
 DATA_DIR = get_tests_dir() / "data"
 TUTORIAL_DATA_DIR = get_pytorch3d_dir() / "docs/tutorials/data"
 
+ShaderTest = namedtuple("ShaderTest", ["shader", "reference_name", "debug_name"])
+
 
 class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
     def test_simple_sphere(self, elevated_camera=False, check_depth=False):
@@ -107,13 +110,13 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
 
             # Test several shaders
-            shaders = {
-                "phong": HardPhongShader,
-                "gouraud": HardGouraudShader,
-                "flat": HardFlatShader,
-            }
-            for (name, shader_init) in shaders.items():
-                shader = shader_init(
+            shader_tests = [
+                ShaderTest(HardPhongShader, "phong", "hard_phong"),
+                ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
+                ShaderTest(HardFlatShader, "flat", "hard_flat"),
+            ]
+            for test in shader_tests:
+                shader = test.shader(
                     lights=lights,
                     cameras=cameras,
                     materials=materials,
@@ -135,7 +138,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 
                 rgb = images[0, ..., :3].squeeze().cpu()
                 filename = "simple_sphere_light_%s%s%s.png" % (
-                    name,
+                    test.reference_name,
                     postfix,
                     cam_type.__name__,
                 )
@@ -144,7 +147,12 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
                 self.assertClose(rgb, image_ref, atol=0.05)
 
                 if DEBUG:
-                    filename = "DEBUG_%s" % filename
+                    debug_filename = "simple_sphere_light_%s%s%s.png" % (
+                        test.debug_name,
+                        postfix,
+                        cam_type.__name__,
+                    )
+                    filename = "DEBUG_%s" % debug_filename
                     Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                         DATA_DIR / filename
                     )
@@ -269,7 +277,8 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
     def test_simple_sphere_batched(self):
         """
         Test a mesh with vertex textures can be extended to form a batch, and
-        is rendered correctly with Phong, Gouraud and Flat Shaders.
+        is rendered correctly with Phong, Gouraud and Flat Shaders with batched
+        lighting and hard and soft blending.
         """
         batch_size = 5
         device = torch.device("cuda:0")
@@ -291,24 +300,28 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         R, T = look_at_view_transform(dist, elev, azim)
         cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
         raster_settings = RasterizationSettings(
-            image_size=512, blur_radius=0.0, faces_per_pixel=1
+            image_size=512, blur_radius=0.0, faces_per_pixel=4
         )
 
         # Init shader settings
         materials = Materials(device=device)
-        lights = PointLights(device=device)
-        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
+        lights_location = torch.tensor([0.0, 0.0, +2.0], device=device)
+        lights_location = lights_location[None].expand(batch_size, -1)
+        lights = PointLights(device=device, location=lights_location)
         blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
 
         # Init renderer
         rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
-        shaders = {
-            "phong": HardPhongShader,
-            "gouraud": HardGouraudShader,
-            "flat": HardFlatShader,
-        }
-        for (name, shader_init) in shaders.items():
-            shader = shader_init(
+        shader_tests = [
+            ShaderTest(HardPhongShader, "phong", "hard_phong"),
+            ShaderTest(SoftPhongShader, "phong", "soft_phong"),
+            ShaderTest(HardGouraudShader, "gouraud", "hard_gouraud"),
+            ShaderTest(HardFlatShader, "flat", "hard_flat"),
+        ]
+        for test in shader_tests:
+            reference_name = test.reference_name
+            debug_name = test.debug_name
+            shader = test.shader(
                 lights=lights,
                 cameras=cameras,
                 materials=materials,
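A sketch of the batched light setup introduced above (CPU stand-in; the test itself runs on cuda:0): the single location is expanded to one row per mesh in the batch so that the new (N, 3) broadcasting path is exercised.

import torch
from pytorch3d.renderer import PointLights

batch_size = 5
lights_location = torch.tensor([0.0, 0.0, +2.0])
lights_location = lights_location[None].expand(batch_size, -1)  # (5, 3)
lights = PointLights(location=lights_location)
print(lights.location.shape)  # torch.Size([5, 3])
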
@@ -317,14 +330,15 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
             images = renderer(sphere_meshes)
             image_ref = load_rgb_image(
-                "test_simple_sphere_light_%s_%s.png" % (name, type(cameras).__name__),
+                "test_simple_sphere_light_%s_%s.png"
+                % (reference_name, type(cameras).__name__),
                 DATA_DIR,
             )
             for i in range(batch_size):
                 rgb = images[i, ..., :3].squeeze().cpu()
                 if i == 0 and DEBUG:
                     filename = "DEBUG_simple_sphere_batched_%s_%s.png" % (
-                        name,
+                        debug_name,
                         type(cameras).__name__,
                     )
                     Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(