mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
raybundle input to ImplicitFunctions -> api unification
Summary: Currently some implicit functions in implicitron take a raybundle, others take ray_points_world. raybundle is what they really need. However, the raybundle is going to become a bit more flexible later, as it will contain different numbers of rays for each camera. Reviewed By: bottler Differential Revision: D39173751 fbshipit-source-id: ebc038e426d22e831e67a18ba64655d8a61e1eb9
This commit is contained in:
parent
70dc9c451a
commit
72c3a0ebe5
@ -19,6 +19,7 @@ class ImplicitFunctionBase(ABC, ReplaceableBase):
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
ray_bundle: RayBundle,
|
||||
fun_viewpool=None,
|
||||
camera: Optional[CamerasBase] = None,
|
||||
|
@ -3,14 +3,15 @@
|
||||
# implicit_differentiable_renderer.py
|
||||
# Copyright (c) 2020 Lior Yariv
|
||||
import math
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import registry
|
||||
from pytorch3d.renderer.implicit import HarmonicEmbedding
|
||||
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
|
||||
from torch import nn
|
||||
|
||||
from .base import ImplicitFunctionBase
|
||||
from .utils import get_rays_points_world
|
||||
|
||||
|
||||
@registry.register
|
||||
@ -125,14 +126,16 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
||||
# inconsistently.
|
||||
def forward(
|
||||
self,
|
||||
# ray_bundle: RayBundle,
|
||||
rays_points_world: torch.Tensor, # TODO: unify the APIs
|
||||
*,
|
||||
ray_bundle: Optional[RayBundle] = None,
|
||||
rays_points_world: Optional[torch.Tensor] = None,
|
||||
fun_viewpool=None,
|
||||
global_code=None,
|
||||
**kwargs,
|
||||
):
|
||||
# this field only uses point locations
|
||||
# rays_points_world = ray_bundle_to_ray_points(ray_bundle)
|
||||
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
|
||||
rays_points_world = get_rays_points_world(ray_bundle, rays_points_world)
|
||||
|
||||
if rays_points_world.numel() == 0 or (
|
||||
self.embed_fn is None and fun_viewpool is None and global_code is None
|
||||
@ -179,4 +182,4 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||
x = self.softplus(x)
|
||||
|
||||
return x # TODO: unify the APIs
|
||||
return x
|
||||
|
@ -129,6 +129,7 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
ray_bundle: RayBundle,
|
||||
fun_viewpool=None,
|
||||
camera: Optional[CamerasBase] = None,
|
||||
|
@ -349,6 +349,7 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
ray_bundle: RayBundle,
|
||||
fun_viewpool=None,
|
||||
camera: Optional[CamerasBase] = None,
|
||||
@ -408,6 +409,7 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
*,
|
||||
ray_bundle: RayBundle,
|
||||
fun_viewpool=None,
|
||||
camera: Optional[CamerasBase] = None,
|
||||
|
@ -10,7 +10,9 @@ import torch
|
||||
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.common.compat import prod
|
||||
from pytorch3d.renderer import ray_bundle_to_ray_points
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
from pytorch3d.renderer.implicit import RayBundle
|
||||
|
||||
|
||||
def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
|
||||
@ -185,3 +187,31 @@ def interpolate_volume(
|
||||
**kwargs,
|
||||
)
|
||||
return out[:, :, :, 0, 0].permute(0, 2, 1)
|
||||
|
||||
|
||||
def get_rays_points_world(
|
||||
ray_bundle: Optional[RayBundle] = None,
|
||||
rays_points_world: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Converts the ray_bundle to rays_points_world if rays_points_world is not defined
|
||||
and raises error if both are defined.
|
||||
|
||||
Args:
|
||||
ray_bundle: A RayBundle object or None
|
||||
rays_points_world: A torch.Tensor representing ray points converted to
|
||||
world coordinates
|
||||
Returns:
|
||||
A torch.Tensor representing ray points converted to world coordinates
|
||||
of shape [minibatch x ... x pts_per_ray x 3].
|
||||
"""
|
||||
if rays_points_world is not None and ray_bundle is not None:
|
||||
raise ValueError(
|
||||
"Cannot define both rays_points_world and ray_bundle,"
|
||||
+ " one has to be None."
|
||||
)
|
||||
if rays_points_world is not None:
|
||||
return rays_points_world
|
||||
if ray_bundle is not None:
|
||||
return ray_bundle_to_ray_points(ray_bundle)
|
||||
raise ValueError("ray_bundle and rays_points_world cannot both be None")
|
||||
|
@ -118,7 +118,7 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
|
||||
|
||||
# eval the raymarching function
|
||||
raymarch_features, _ = implicit_function(
|
||||
ray_bundle_t,
|
||||
ray_bundle=ray_bundle_t,
|
||||
raymarch_features=None,
|
||||
)
|
||||
if self.verbose:
|
||||
|
@ -148,7 +148,7 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13
|
||||
)
|
||||
|
||||
output = self.raymarcher(
|
||||
*implicit_functions[0](ray_bundle),
|
||||
*implicit_functions[0](ray_bundle=ray_bundle),
|
||||
ray_lengths=ray_bundle.lengths,
|
||||
density_noise_std=density_noise_std,
|
||||
)
|
||||
|
@ -101,7 +101,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
|
||||
object_mask = object_mask.bool()
|
||||
|
||||
implicit_function = implicit_functions[0]
|
||||
implicit_function_gradient = functools.partial(gradient, implicit_function)
|
||||
implicit_function_gradient = functools.partial(_gradient, implicit_function)
|
||||
|
||||
# object_mask: silhouette of the object
|
||||
batch_size, *spatial_size, _ = ray_bundle.lengths.shape
|
||||
@ -113,7 +113,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
|
||||
|
||||
with torch.no_grad(), evaluating(implicit_function):
|
||||
points, network_object_mask, dists = self.ray_tracer(
|
||||
sdf=lambda x: implicit_function(x)[
|
||||
sdf=lambda x: implicit_function(rays_points_world=x)[
|
||||
:, 0
|
||||
], # TODO: get rid of this wrapper
|
||||
cam_loc=cam_loc,
|
||||
@ -125,7 +125,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
|
||||
depth = dists.reshape(batch_size, num_pixels, 1)
|
||||
points = (cam_loc + depth * ray_dirs).reshape(-1, 3)
|
||||
|
||||
sdf_output = implicit_function(points)[:, 0:1]
|
||||
sdf_output = implicit_function(rays_points_world=points)[:, 0:1]
|
||||
# NOTE most of the intermediate variables are flattened for
|
||||
# no apparent reason (here and in the ray tracer)
|
||||
ray_dirs = ray_dirs.reshape(-1, 3)
|
||||
@ -157,7 +157,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
|
||||
|
||||
points_all = torch.cat([surface_points, eikonal_points], dim=0)
|
||||
|
||||
output = implicit_function(surface_points)
|
||||
output = implicit_function(rays_points_world=surface_points)
|
||||
surface_sdf_values = output[
|
||||
:N, 0:1
|
||||
].detach() # how is it different from sdf_output?
|
||||
@ -181,7 +181,9 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign
|
||||
grad_theta = None
|
||||
|
||||
empty_render = differentiable_surface_points.shape[0] == 0
|
||||
features = implicit_function(differentiable_surface_points)[None, :, 1:]
|
||||
features = implicit_function(rays_points_world=differentiable_surface_points)[
|
||||
None, :, 1:
|
||||
]
|
||||
normals_full = features.new_zeros(
|
||||
batch_size, *spatial_size, 3, requires_grad=empty_render
|
||||
)
|
||||
@ -260,13 +262,13 @@ def _sample_network(
|
||||
|
||||
|
||||
@torch.enable_grad()
|
||||
def gradient(module, x):
|
||||
x.requires_grad_(True)
|
||||
y = module.forward(x)[:, :1]
|
||||
def _gradient(module, rays_points_world):
|
||||
rays_points_world.requires_grad_(True)
|
||||
y = module.forward(rays_points_world=rays_points_world)[:, :1]
|
||||
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
||||
gradients = torch.autograd.grad(
|
||||
outputs=y,
|
||||
inputs=x,
|
||||
inputs=rays_points_world,
|
||||
grad_outputs=d_output,
|
||||
create_graph=True,
|
||||
retain_graph=True,
|
||||
|
@ -44,7 +44,7 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
|
||||
implicit_function = SRNImplicitFunction()
|
||||
device = torch.device("cpu")
|
||||
bundle = self._get_bundle(device=device)
|
||||
rays_densities, rays_colors = implicit_function(bundle)
|
||||
rays_densities, rays_colors = implicit_function(ray_bundle=bundle)
|
||||
out_features = implicit_function.raymarch_function.out_features
|
||||
self.assertEqual(
|
||||
rays_densities.shape,
|
||||
@ -62,7 +62,9 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
|
||||
implicit_function.to(device)
|
||||
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
|
||||
bundle = self._get_bundle(device=device)
|
||||
rays_densities, rays_colors = implicit_function(bundle, global_code=global_code)
|
||||
rays_densities, rays_colors = implicit_function(
|
||||
ray_bundle=bundle, global_code=global_code
|
||||
)
|
||||
out_features = implicit_function.hypernet.out_features
|
||||
self.assertEqual(
|
||||
rays_densities.shape,
|
||||
|
Loading…
x
Reference in New Issue
Block a user