idr harmonic_fns and doc

Summary: Document the inputs of idr functions and distinguish n_harmonic_functions to be 0 (simple embedding) versus -1 (no embedding).

Reviewed By: davnov134

Differential Revision: D37209012

fbshipit-source-id: 6e5c3eae54c4e5e8c3f76cad1caf162c6c222d52
This commit is contained in:
Jeremy Reizenstein 2022-06-20 13:48:34 -07:00 committed by Facebook GitHub Bot
parent 28c1afaa9d
commit 81d63c6382
3 changed files with 70 additions and 12 deletions

View File

@ -201,7 +201,7 @@ generic_model_args:
bias: 1.0
skip_in: []
weight_norm: true
n_harmonic_functions_xyz: 0
n_harmonic_functions_xyz: -1
pooled_feature_dim: 0
encoding_dim: 0
implicit_function_NeRFormerImplicitFunction_args:

View File

@ -15,6 +15,41 @@ from .base import ImplicitFunctionBase
@registry.register
class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
"""
Implicit function as used in http://github.com/lioryariv/idr.
Members:
d_in: dimension of the input point.
n_harmonic_functions_xyz: If -1, do not embed the point.
If >=0, use a harmonic embedding with this number of
harmonic functions. (The harmonic embedding includes the input
itself, so a value of 0 means the point is used but without
any harmonic functions.)
d_out and feature_vector_size: Sum of these is the output
dimension. This implicit function thus returns a concatenation
of `d_out` signed distance function values and `feature_vector_size`
features (such as colors). When used in `GenericModel`,
`feature_vector_size` corresponds is automatically set to
`render_features_dimensions`.
dims: list of hidden layer sizes.
geometric_init: whether to use custom weight initialization
in linear layers. If False, pytorch default (uniform sampling)
is used.
bias: if geometric_init=True, initial value for bias subtracted
in the last layer.
skip_in: List of indices of layers that receive as input the initial
value concatenated with the output of the previous layers.
weight_norm: whether to apply weight normalization to each layer.
pooled_feature_dim: If view pooling is in use (provided as
fun_viewpool to forward()) this must be its number of features.
Otherwise this must be set to 0. (If used from GenericModel,
this config value will be overridden automatically.)
encoding_dim: If global coding is in use (provided as global_code
to forward()) this must be its number of featuress.
Otherwise this must be set to 0. (If used from GenericModel,
this config value will be overridden automatically.)
"""
feature_vector_size: int = 3
d_in: int = 3
d_out: int = 1
@ -23,7 +58,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
bias: float = 1.0
skip_in: Sequence[int] = ()
weight_norm: bool = True
n_harmonic_functions_xyz: int = 0
n_harmonic_functions_xyz: int = -1
pooled_feature_dim: int = 0
encoding_dim: int = 0
@ -33,7 +68,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
self.embed_fn = None
if self.n_harmonic_functions_xyz > 0:
if self.n_harmonic_functions_xyz >= 0:
self.embed_fn = HarmonicEmbedding(
self.n_harmonic_functions_xyz, append_input=True
)
@ -63,13 +98,13 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
std=0.0001,
)
torch.nn.init.constant_(lin.bias, -self.bias)
elif self.n_harmonic_functions_xyz > 0 and layer_idx == 0:
elif self.n_harmonic_functions_xyz >= 0 and layer_idx == 0:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
torch.nn.init.normal_(
lin.weight[:, :3], 0.0, 2**0.5 / out_dim**0.5
)
elif self.n_harmonic_functions_xyz > 0 and layer_idx in self.skip_in:
elif self.n_harmonic_functions_xyz >= 0 and layer_idx in self.skip_in:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5)
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
@ -110,27 +145,26 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
# Tensor]`.
).view(0, self.out_dim)
embedding = None
embeddings = []
if self.embed_fn is not None:
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
embedding = self.embed_fn(rays_points_world)
embeddings.append(self.embed_fn(rays_points_world))
if fun_viewpool is not None:
assert rays_points_world.ndim == 2
pooled_feature = fun_viewpool(rays_points_world[None])
# TODO: pooled features are 4D!
embedding = torch.cat((embedding, pooled_feature), dim=-1)
embeddings.append(pooled_feature)
if global_code is not None:
assert embedding.ndim == 2
assert global_code.shape[0] == 1 # TODO: generalize to batches!
# This will require changing raytracer code
# embedding = embedding[None].expand(global_code.shape[0], *embedding.shape)
embedding = torch.cat(
(embedding, global_code[0, None, :].expand(*embedding.shape[:-1], -1)),
dim=-1,
embeddings.append(
global_code[0, None, :].expand(rays_points_world.shape[0], -1)
)
embedding = torch.cat(embeddings, dim=-1)
x = embedding
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase.__s...
for layer_idx in range(self.num_layers - 1):

View File

@ -16,6 +16,30 @@ logger = logging.getLogger(__name__)
class RayNormalColoringNetwork(torch.nn.Module):
"""
Members:
d_in and feature_vector_size: Sum of these is the input
dimension. These must add up to the sum of
- 3 [for the points]
- 3 unless mode=no_normal [for the normals]
- 3 unless mode=no_view_dir [for view directions]
- the feature size, [number of channels in feature_vectors]
d_out: dimension of output.
mode: One of "idr", "no_view_dir" or "no_normal" to allow omitting
part of the network input.
dims: list of hidden layer sizes.
weight_norm: whether to apply weight normalization to each layer.
n_harmonic_functions_dir:
If >0, use a harmonic embedding with this number of
harmonic functions for the view direction. Otherwise view directions
are fed without embedding, unless mode is `no_view_dir`.
pooled_feature_dim: If a pooling function is in use (provided as
pooling_fn to forward()) this must be its number of features.
Otherwise this must be set to 0. (If used from GenericModel,
this will be set automatically.)
"""
def __init__(
self,
feature_vector_size: int = 3,