# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/lioryariv/idr/blob/main/code/model/
# implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv
import functools
import math
from typing import List, Optional, Tuple

import torch
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import get_default_args_field, registry
from pytorch3d.implicitron.tools.utils import evaluating
from pytorch3d.renderer import RayBundle

from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
from .ray_tracing import RayTracing
from .rgb_net import RayNormalColoringNetwork


@registry.register
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
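    """
    A renderer in the style of IDR (Yariv et al., 2020). Rays from the input
    `RayBundle` are intersected with the zero level set of a signed distance
    function by the `RayTracing` module; the colors of the resulting surface
    points are predicted by a `RayNormalColoringNetwork` from the points, their
    SDF normals and the viewing directions, and a soft, differentiable
    silhouette is obtained by passing the SDF values of rays that miss the
    surface through a scaled sigmoid.

    Members:
        render_features_dimensions: the number of channels of the rendered
            features (3 for RGB).
        ray_tracer_args: configuration of the `RayTracing` module which finds
            the ray/surface intersections.
        ray_normal_coloring_network_args: configuration of the
            `RayNormalColoringNetwork` used to shade the surface points.
        bg_color: the constant feature value(s) rendered for rays that miss
            the surface; either a single value or one per feature channel.
        soft_mask_alpha: the sharpness of the sigmoid used for the soft
            silhouette approximation.

    A rough usage sketch (in practice this renderer is usually constructed and
    called by the surrounding Implicitron pipeline rather than by hand, and the
    class must first be processed by `expand_args_fields`):

        from pytorch3d.implicitron.tools.config import expand_args_fields

        expand_args_fields(SignedDistanceFunctionRenderer)
        renderer = SignedDistanceFunctionRenderer()
        out = renderer(
            ray_bundle=ray_bundle,                   # sampled rendering rays
            implicit_functions=[implicit_function],  # single-element list
            evaluation_mode=EvaluationMode.EVALUATION,
            object_mask=object_mask,                 # bool silhouette of the object
        )
        rgb, soft_mask = out.features, out.masks
    """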
    render_features_dimensions: int = 3
    ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
    ray_normal_coloring_network_args: DictConfig = get_default_args_field(
        RayNormalColoringNetwork
    )
    bg_color: Tuple[float, ...] = (0.0,)
    soft_mask_alpha: float = 50.0

    def __post_init__(
        self,
    ):
        super().__init__()
        render_features_dimensions = self.render_features_dimensions
        if len(self.bg_color) not in [1, render_features_dimensions]:
            raise ValueError(
                f"Background color should have 1 or {render_features_dimensions} entries."
            )

        self.ray_tracer = RayTracing(**self.ray_tracer_args)
        self.object_bounding_sphere = self.ray_tracer_args.get("object_bounding_sphere")

        self.ray_normal_coloring_network_args[
            "feature_vector_size"
        ] = render_features_dimensions
        self._rgb_network = RayNormalColoringNetwork(
            **self.ray_normal_coloring_network_args
        )

        self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)

    def requires_object_mask(self) -> bool:
        return True

    def forward(
        self,
        ray_bundle: RayBundle,
        implicit_functions: List[ImplicitFunctionWrapper],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        object_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> RendererOutput:
        """
        Args:
            ray_bundle: A `RayBundle` object containing the parametrizations of the
                sampled rendering rays.
            implicit_functions: A single-element list of `ImplicitFunctionWrapper`s
                which defines the implicit function to be used.
            evaluation_mode: One of `EvaluationMode.TRAINING` or
                `EvaluationMode.EVALUATION`, which determines the settings used for
                rendering.
            object_mask: A boolean tensor denoting the silhouette of the object.
                Although it defaults to None, it is a required argument for
                SignedDistanceFunctionRenderer.

        Returns:
            An instance of `RendererOutput`.
        """
        if len(implicit_functions) != 1:
            raise ValueError(
                "SignedDistanceFunctionRenderer supports only a single pass."
            )

        if object_mask is None:
            raise ValueError("Expected object_mask to be provided.")
        object_mask = object_mask.bool()

        implicit_function = implicit_functions[0]
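        # `gradient` (defined at the bottom of this file) differentiates the first,
        # SDF channel of the implicit function w.r.t. its input points via autograd;
        # this partial is used below both for surface normals and for the eikonal
        # sample gradients.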
        implicit_function_gradient = functools.partial(gradient, implicit_function)

        # object_mask: silhouette of the object
        batch_size, *spatial_size, _ = ray_bundle.lengths.shape
        num_pixels = math.prod(spatial_size)

        cam_loc = ray_bundle.origins.reshape(batch_size, -1, 3)
        ray_dirs = ray_bundle.directions.reshape(batch_size, -1, 3)
        object_mask = object_mask.reshape(batch_size, -1)

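        # Find ray/surface intersections with the (non-differentiable) ray tracer.
        # This runs under no_grad and with the implicit function in eval mode;
        # gradients w.r.t. the network parameters are re-attached further below
        # via `_sample_network`.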
        with torch.no_grad(), evaluating(implicit_function):
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
            points, network_object_mask, dists = self.ray_tracer(
                sdf=lambda x: implicit_function(x)[
                    :, 0
                ],  # TODO: get rid of this wrapper
                cam_loc=cam_loc,
                object_mask=object_mask,
                ray_directions=ray_dirs,
            )

        # TODO: below, cam_loc might as well be different
        depth = dists.reshape(batch_size, num_pixels, 1)
        points = (cam_loc + depth * ray_dirs).reshape(-1, 3)

        sdf_output = implicit_function(points)[:, 0:1]
        # NOTE most of the intermediate variables are flattened for
        # no apparent reason (here and in the ray tracer)
        ray_dirs = ray_dirs.reshape(-1, 3)
        object_mask = object_mask.reshape(-1)

        # TODO: move it to loss computation
        if evaluation_mode == EvaluationMode.TRAINING:
            surface_mask = network_object_mask & object_mask
            surface_points = points[surface_mask]
            surface_dists = dists[surface_mask].unsqueeze(-1)
            surface_ray_dirs = ray_dirs[surface_mask]
            surface_cam_loc = cam_loc.reshape(-1, 3)[surface_mask]
            surface_output = sdf_output[surface_mask]
            N = surface_points.shape[0]

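            # The eikonal term regularises the implicit function towards a valid
            # SDF by encouraging ||grad f|| ~= 1; the gradients computed at these
            # sample points are returned via aux["grad_theta"] so that the loss
            # can be computed outside of the renderer.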
            # Sample points for the eikonal loss
            # pyre-fixme[9]
            eik_bounding_box: float = self.object_bounding_sphere
            n_eik_points = batch_size * num_pixels // 2
            eikonal_points = torch.empty(
                n_eik_points,
                3,
                # pyre-fixme[6]: For 3rd param expected `Union[None, str, device]`
                #  but got `Union[device, Tensor, Module]`.
                device=self._bg_color.device,
            ).uniform_(-eik_bounding_box, eik_bounding_box)
            eikonal_pixel_points = points.clone()
            eikonal_pixel_points = eikonal_pixel_points.detach()
            eikonal_points = torch.cat([eikonal_points, eikonal_pixel_points], 0)

            points_all = torch.cat([surface_points, eikonal_points], dim=0)

            output = implicit_function(surface_points)
            surface_sdf_values = output[
                :N, 0:1
            ].detach()  # how is it different from sdf_output?

            g = implicit_function_gradient(points_all)
            surface_points_grad = g[:N, 0, :].clone().detach()
            grad_theta = g[N:, 0, :]

            differentiable_surface_points = _sample_network(
                surface_output,
                surface_sdf_values,
                surface_points_grad,
                surface_dists,
                surface_cam_loc,
                surface_ray_dirs,
            )

        else:
            surface_mask = network_object_mask
            differentiable_surface_points = points[surface_mask]
            grad_theta = None

        empty_render = differentiable_surface_points.shape[0] == 0
        features = implicit_function(differentiable_surface_points)[None, :, 1:]
        normals_full = features.new_zeros(
            batch_size, *spatial_size, 3, requires_grad=empty_render
        )
        render_full = (
            features.new_ones(
                batch_size,
                *spatial_size,
                self.render_features_dimensions,
                requires_grad=empty_render,
            )
            * self._bg_color
        )
        mask_full = features.new_ones(
            batch_size, *spatial_size, 1, requires_grad=empty_render
        )
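        # Scatter the per-surface-point predictions into the full, image-shaped
        # buffers initialised above; rays outside `surface_mask` keep the
        # background color and zero normals, and their mask entries are replaced
        # by a soft sigmoid of the SDF value.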
        if not empty_render:
            normals = implicit_function_gradient(differentiable_surface_points)[
                None, :, 0, :
            ]
            normals_full.view(-1, 3)[surface_mask] = normals
            render_full.view(-1, self.render_features_dimensions)[
                surface_mask
            ] = self._rgb_network(  # pyre-fixme[29]:
                features,
                differentiable_surface_points[None],
                normals,
                ray_bundle,
                surface_mask[None, :, None],
                pooling_fn=None,  # TODO
            )
            mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
                # pyre-fixme[6]: For 1st param expected `Tensor` but got `float`.
                -self.soft_mask_alpha
                * sdf_output[~surface_mask]
            )

        # scatter points with surface_mask
        points_full = ray_bundle.origins.detach().clone()
        points_full.view(-1, 3)[surface_mask] = differentiable_surface_points

        # TODO: it is sparse here but otherwise dense
        return RendererOutput(
            features=render_full,
            normals=normals_full,
            depths=depth.reshape(batch_size, *spatial_size, 1),
            masks=mask_full,  # this is a differentiable approximation, see (7) in the paper
            points=points_full,
            aux={"grad_theta": grad_theta},  # TODO: will be moved to eikonal loss
            # TODO: do we need sdf_output, grad_theta? Only for loss probably
        )


def _sample_network(
    surface_output,
    surface_sdf_values,
    surface_points_grad,
    surface_dists,
    surface_cam_loc,
    surface_ray_dirs,
    eps: float = 1e-4,
):
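    """
    The "sample network" trick of IDR (Yariv et al., 2020), which makes the
    ray/surface intersection differentiable w.r.t. the implicit function
    parameters. Given a fixed intersection distance t0 along a ray with origin
    c and direction v, it forms the first-order correction

        t(theta) = t0 - (f(x0; theta) - f(x0; theta0)) / (grad f(x0; theta0) . v)

    and returns the corresponding points x(theta) = c + t(theta) * v, so that
    gradients flow through the SDF values while the detached distances and
    gradients act as constants. `eps` clips the denominator away from zero.
    """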
    # t -> t(theta)
    surface_ray_dirs_0 = surface_ray_dirs.detach()
    surface_points_dot = torch.bmm(
        surface_points_grad.view(-1, 1, 3), surface_ray_dirs_0.view(-1, 3, 1)
    ).squeeze(-1)
    dot_sign = (surface_points_dot >= 0).to(surface_points_dot) * 2 - 1
    surface_dists_theta = surface_dists - (surface_output - surface_sdf_values) / (
        surface_points_dot.abs().clip(eps) * dot_sign
    )

    # t(theta) -> x(theta,c,v)
    surface_points_theta_c_v = surface_cam_loc + surface_dists_theta * surface_ray_dirs

    return surface_points_theta_c_v


@torch.enable_grad()
def gradient(module, x):
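    """
    Computes the gradient of the first (SDF) output channel of `module` with
    respect to the input points `x` using autograd, keeping the graph so that
    higher-order terms (e.g. the eikonal loss) remain differentiable.

    Returns a tensor of shape (num_points, 1, 3).
    """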
    x.requires_grad_(True)
    y = module.forward(x)[:, :1]
    d_output = torch.ones_like(y, requires_grad=False, device=y.device)
    gradients = torch.autograd.grad(
        outputs=y,
        inputs=x,
        grad_outputs=d_output,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    return gradients.unsqueeze(1)