raybundle input to ImplicitFunctions -> api unification

Summary: Currently some implicit functions in implicitron take a raybundle, others take ray_points_world. raybundle is what they really need. However, the raybundle is going to become a bit more flexible later, as it will contain different numbers of rays for each camera.

Reviewed By: bottler

Differential Revision: D39173751

fbshipit-source-id: ebc038e426d22e831e67a18ba64655d8a61e1eb9
This commit is contained in:
Darijan Gudelj 2022-09-05 06:26:06 -07:00 committed by Facebook GitHub Bot
parent 70dc9c451a
commit 72c3a0ebe5
9 changed files with 60 additions and 19 deletions

View File

@ -19,6 +19,7 @@ class ImplicitFunctionBase(ABC, ReplaceableBase):
@abstractmethod
def forward(
self,
*,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,

View File

@ -3,14 +3,15 @@
# implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv
import math
from typing import Tuple
from typing import Optional, Tuple
import torch
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer.implicit import HarmonicEmbedding
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
from torch import nn
from .base import ImplicitFunctionBase
from .utils import get_rays_points_world
@registry.register
@ -125,14 +126,16 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
# inconsistently.
def forward(
self,
# ray_bundle: RayBundle,
rays_points_world: torch.Tensor, # TODO: unify the APIs
*,
ray_bundle: Optional[RayBundle] = None,
rays_points_world: Optional[torch.Tensor] = None,
fun_viewpool=None,
global_code=None,
**kwargs,
):
# this field only uses point locations
# rays_points_world = ray_bundle_to_ray_points(ray_bundle)
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
rays_points_world = get_rays_points_world(ray_bundle, rays_points_world)
if rays_points_world.numel() == 0 or (
self.embed_fn is None and fun_viewpool is None and global_code is None
@ -179,4 +182,4 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
x = self.softplus(x)
return x # TODO: unify the APIs
return x

View File

@ -129,6 +129,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
def forward(
self,
*,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,

View File

@ -349,6 +349,7 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
def forward(
self,
*,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
@ -408,6 +409,7 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
def forward(
self,
*,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,

View File

@ -10,7 +10,9 @@ import torch
import torch.nn.functional as F
from pytorch3d.common.compat import prod
from pytorch3d.renderer import ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import RayBundle
def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
@ -185,3 +187,31 @@ def interpolate_volume(
**kwargs,
)
return out[:, :, :, 0, 0].permute(0, 2, 1)
def get_rays_points_world(
ray_bundle: Optional[RayBundle] = None,
rays_points_world: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Converts the ray_bundle to rays_points_world if rays_points_world is not defined
and raises error if both are defined.
Args:
ray_bundle: A RayBundle object or None
rays_points_world: A torch.Tensor representing ray points converted to
world coordinates
Returns:
A torch.Tensor representing ray points converted to world coordinates
of shape [minibatch x ... x pts_per_ray x 3].
"""
if rays_points_world is not None and ray_bundle is not None:
raise ValueError(
"Cannot define both rays_points_world and ray_bundle,"
+ " one has to be None."
)
if rays_points_world is not None:
return rays_points_world
if ray_bundle is not None:
return ray_bundle_to_ray_points(ray_bundle)
raise ValueError("ray_bundle and rays_points_world cannot both be None")

View File

@ -118,7 +118,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
# eval the raymarching function
raymarch_features, _ = implicit_function(
ray_bundle_t,
ray_bundle=ray_bundle_t,
raymarch_features=None,
)
if self.verbose:

View File

@ -148,7 +148,7 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
)
output = self.raymarcher(
*implicit_functions[0](ray_bundle),
*implicit_functions[0](ray_bundle=ray_bundle),
ray_lengths=ray_bundle.lengths,
density_noise_std=density_noise_std,
)

View File

@ -101,7 +101,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
object_mask = object_mask.bool()
implicit_function = implicit_functions[0]
implicit_function_gradient = functools.partial(gradient, implicit_function)
implicit_function_gradient = functools.partial(_gradient, implicit_function)
# object_mask: silhouette of the object
batch_size, *spatial_size, _ = ray_bundle.lengths.shape
@ -113,7 +113,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
with torch.no_grad(), evaluating(implicit_function):
points, network_object_mask, dists = self.ray_tracer(
sdf=lambda x: implicit_function(x)[
sdf=lambda x: implicit_function(rays_points_world=x)[
:, 0
], # TODO: get rid of this wrapper
cam_loc=cam_loc,
@ -125,7 +125,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
depth = dists.reshape(batch_size, num_pixels, 1)
points = (cam_loc + depth * ray_dirs).reshape(-1, 3)
sdf_output = implicit_function(points)[:, 0:1]
sdf_output = implicit_function(rays_points_world=points)[:, 0:1]
# NOTE most of the intermediate variables are flattened for
# no apparent reason (here and in the ray tracer)
ray_dirs = ray_dirs.reshape(-1, 3)
@ -157,7 +157,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
points_all = torch.cat([surface_points, eikonal_points], dim=0)
output = implicit_function(surface_points)
output = implicit_function(rays_points_world=surface_points)
surface_sdf_values = output[
:N, 0:1
].detach() # how is it different from sdf_output?
@ -181,7 +181,9 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
grad_theta = None
empty_render = differentiable_surface_points.shape[0] == 0
features = implicit_function(differentiable_surface_points)[None, :, 1:]
features = implicit_function(rays_points_world=differentiable_surface_points)[
None, :, 1:
]
normals_full = features.new_zeros(
batch_size, *spatial_size, 3, requires_grad=empty_render
)
@ -260,13 +262,13 @@ def _sample_network(
@torch.enable_grad()
def gradient(module, x):
x.requires_grad_(True)
y = module.forward(x)[:, :1]
def _gradient(module, rays_points_world):
rays_points_world.requires_grad_(True)
y = module.forward(rays_points_world=rays_points_world)[:, :1]
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
gradients = torch.autograd.grad(
outputs=y,
inputs=x,
inputs=rays_points_world,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,

View File

@ -44,7 +44,7 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
implicit_function = SRNImplicitFunction()
device = torch.device("cpu")
bundle = self._get_bundle(device=device)
rays_densities, rays_colors = implicit_function(bundle)
rays_densities, rays_colors = implicit_function(ray_bundle=bundle)
out_features = implicit_function.raymarch_function.out_features
self.assertEqual(
rays_densities.shape,
@ -62,7 +62,9 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
implicit_function.to(device)
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
bundle = self._get_bundle(device=device)
rays_densities, rays_colors = implicit_function(bundle, global_code=global_code)
rays_densities, rays_colors = implicit_function(
ray_bundle=bundle, global_code=global_code
)
out_features = implicit_function.hypernet.out_features
self.assertEqual(
rays_densities.shape,