Lighting broadcasting bug fix

Summary: Fixed multiple issues with shape broadcasting in lighting, shading and blending and updated the tests.

Reviewed By: bottler

Differential Revision: D28997941

fbshipit-source-id: d3ef93f979344076b1d9098a86178b4da63844c8
Author: Nikhila Ravi
Date: 2021-06-14 11:47:35 -07:00
Committed by: Facebook GitHub Bot
parent 9de627e01b
commit bc8361fa47
4 changed files with 73 additions and 31 deletions


@@ -1,12 +1,11 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-from typing import NamedTuple, Sequence
+from typing import NamedTuple, Sequence, Union
 import torch
 from pytorch3d import _C  # pyre-fixme[21]: Could not find name `_C` in `pytorch3d`.
 # Example functions for blending the top K colors per pixel using the outputs
 # from rasterization.
 # NOTE: All blending function should return an RGBA image per batch element
@@ -117,7 +116,11 @@ def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
 def softmax_rgb_blend(
-    colors, fragments, blend_params, znear: float = 1.0, zfar: float = 100
+    colors,
+    fragments,
+    blend_params,
+    znear: Union[float, torch.Tensor] = 1.0,
+    zfar: Union[float, torch.Tensor] = 100,
 ) -> torch.Tensor:
     """
     RGB and alpha channel blending to return an RGBA image based on the method
@@ -184,11 +187,16 @@ def softmax_rgb_blend(
     # overflow. zbuf shape (N, H, W, K), find max over K.
     # TODO: there may still be some instability in the exponent calculation.
+    # Reshape to be compatible with (N, H, W, K) values in fragments
+    if torch.is_tensor(zfar):
+        # pyre-fixme[16]
+        zfar = zfar[:, None, None, None]
+    if torch.is_tensor(znear):
+        znear = znear[:, None, None, None]
     z_inv = (zfar - fragments.zbuf) / (zfar - znear) * mask
     # pyre-fixme[16]: `Tuple` has no attribute `values`.
     # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
     z_inv_max = torch.max(z_inv, dim=-1).values[..., None].clamp(min=eps)
     # pyre-fixme[6]: Expected `Tensor` for 1st param but got `float`.
     weights_num = prob_map * torch.exp((z_inv - z_inv_max) / blend_params.gamma)
     # Also apply exp normalize trick for the background color weight.

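The blending change above means znear and zfar may now arrive as per-batch tensors rather than plain floats, so they have to gain singleton dimensions before they can broadcast against fragments.zbuf, which has shape (N, H, W, K). A minimal standalone sketch of that broadcasting step, using made-up shapes and plain torch tensors rather than a real rasterizer output:

import torch

# Illustrative shapes only: N batches, H x W image, K faces per pixel.
N, H, W, K = 2, 4, 4, 3
zbuf = torch.rand(N, H, W, K)          # stand-in for fragments.zbuf
mask = (zbuf > 0).float()
znear = torch.full((N,), 1.0)          # per-batch near plane
zfar = torch.full((N,), 100.0)         # per-batch far plane

# Same reshape as in the diff: (N,) -> (N, 1, 1, 1) so it broadcasts over H, W, K.
if torch.is_tensor(zfar):
    zfar = zfar[:, None, None, None]
if torch.is_tensor(znear):
    znear = znear[:, None, None, None]

z_inv = (zfar - zbuf) / (zfar - znear) * mask
print(z_inv.shape)  # torch.Size([2, 4, 4, 3])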

@@ -253,12 +253,26 @@ class PointLights(TensorProperties):
         other = self.__class__(device=self.device)
         return super().clone(other)
+    def reshape_location(self, points) -> torch.Tensor:
+        """
+        Reshape the location tensor to have dimensions
+        compatible with the points which can either be of
+        shape (P, 3) or (N, H, W, K, 3).
+        """
+        if self.location.ndim == points.ndim:
+            # pyre-fixme[7]
+            return self.location
+        # pyre-fixme[29]
+        return self.location[:, None, None, None, :]
     def diffuse(self, normals, points) -> torch.Tensor:
-        direction = self.location - points
+        location = self.reshape_location(points)
+        direction = location - points
         return diffuse(normals=normals, color=self.diffuse_color, direction=direction)
     def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
-        direction = self.location - points
+        location = self.reshape_location(points)
+        direction = location - points
         return specular(
             points=points,
             normals=normals,

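The new PointLights.reshape_location above inserts singleton dimensions so a per-batch light location of shape (N, 3) can be subtracted from interpolated pixel points of shape (N, H, W, K, 3). A small self-contained sketch of that broadcast, with illustrative shapes and no dependency on the lights class itself:

import torch

# Illustrative shapes only.
N, H, W, K = 2, 4, 4, 3
location = torch.rand(N, 3)            # one light location per batch element
points = torch.rand(N, H, W, K, 3)     # per-pixel, per-face world coordinates

# Same logic as reshape_location: only add dims when the ranks differ,
# so packed (P, 3) inputs are left untouched.
if location.ndim != points.ndim:
    location = location[:, None, None, None, :]  # (N, 3) -> (N, 1, 1, 1, 3)

direction = location - points  # broadcasts to (N, H, W, K, 3)
print(direction.shape)  # torch.Size([2, 4, 4, 3, 3])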

@@ -14,8 +14,8 @@ def _apply_lighting(
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Args:
-        points: torch tensor of shape (N, P, 3) or (P, 3).
-        normals: torch tensor of shape (N, P, 3) or (P, 3)
+        points: torch tensor of shape (N, ..., 3) or (P, 3).
+        normals: torch tensor of shape (N, ..., 3) or (P, 3)
         lights: instance of the Lights class.
         cameras: instance of the Cameras class.
         materials: instance of the Materials class.
@@ -35,6 +35,7 @@ def _apply_lighting(
     ambient_color = materials.ambient_color * lights.ambient_color
     diffuse_color = materials.diffuse_color * light_diffuse
     specular_color = materials.specular_color * light_specular
     if normals.dim() == 2 and points.dim() == 2:
         # If given packed inputs remove batch dim in output.
         return (
@@ -42,6 +43,11 @@ def _apply_lighting(
             diffuse_color.squeeze(),
             specular_color.squeeze(),
         )
+    if ambient_color.ndim != diffuse_color.ndim:
+        # Reshape from (N, 3) to have dimensions compatible with
+        # diffuse_color which is of shape (N, H, W, K, 3)
+        ambient_color = ambient_color[:, None, None, None, :]
     return ambient_color, diffuse_color, specular_color
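
Finally, _apply_lighting above can produce an ambient_color of shape (N, 3) while diffuse_color and specular_color are full (N, H, W, K, 3) maps, so ambient_color gets the same singleton-dimension treatment before being returned. A standalone sketch of why that reshape is needed, again with made-up shapes:

import torch

# Illustrative shapes only.
N, H, W, K = 2, 4, 4, 3
ambient_color = torch.rand(N, 3)              # one ambient color per batch element
diffuse_color = torch.rand(N, H, W, K, 3)     # per-pixel diffuse contribution

# Same reshape as in the diff: (N, 3) -> (N, 1, 1, 1, 3) so it can be combined
# elementwise with per-pixel colors downstream.
if ambient_color.ndim != diffuse_color.ndim:
    ambient_color = ambient_color[:, None, None, None, :]

print((ambient_color + diffuse_color).shape)  # torch.Size([2, 4, 4, 3, 3])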