From 81d63c63823e146e74d7be367d19314ab16d6815 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Mon, 20 Jun 2022 13:48:34 -0700 Subject: [PATCH] idr harmonic_fns and doc Summary: Document the inputs of idr functions and distinguish n_harmonic_functions to be 0 (simple embedding) versus -1 (no embedding). Reviewed By: davnov134 Differential Revision: D37209012 fbshipit-source-id: 6e5c3eae54c4e5e8c3f76cad1caf162c6c222d52 --- .../implicitron_trainer/tests/experiment.yaml | 2 +- .../implicit_function/idr_feature_field.py | 56 +++++++++++++++---- .../implicitron/models/renderer/rgb_net.py | 24 ++++++++ 3 files changed, 70 insertions(+), 12 deletions(-) diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index eda38f3a..267dd5c3 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -201,7 +201,7 @@ generic_model_args: bias: 1.0 skip_in: [] weight_norm: true - n_harmonic_functions_xyz: 0 + n_harmonic_functions_xyz: -1 pooled_feature_dim: 0 encoding_dim: 0 implicit_function_NeRFormerImplicitFunction_args: diff --git a/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py b/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py index d54ef78c..19d3b39d 100644 --- a/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py +++ b/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py @@ -15,6 +15,41 @@ from .base import ImplicitFunctionBase @registry.register class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module): + """ + Implicit function as used in http://github.com/lioryariv/idr. + + Members: + d_in: dimension of the input point. + n_harmonic_functions_xyz: If -1, do not embed the point. + If >=0, use a harmonic embedding with this number of + harmonic functions. (The harmonic embedding includes the input + itself, so a value of 0 means the point is used but without + any harmonic functions.) + d_out and feature_vector_size: Sum of these is the output + dimension. This implicit function thus returns a concatenation + of `d_out` signed distance function values and `feature_vector_size` + features (such as colors). When used in `GenericModel`, + `feature_vector_size` corresponds is automatically set to + `render_features_dimensions`. + dims: list of hidden layer sizes. + geometric_init: whether to use custom weight initialization + in linear layers. If False, pytorch default (uniform sampling) + is used. + bias: if geometric_init=True, initial value for bias subtracted + in the last layer. + skip_in: List of indices of layers that receive as input the initial + value concatenated with the output of the previous layers. + weight_norm: whether to apply weight normalization to each layer. + pooled_feature_dim: If view pooling is in use (provided as + fun_viewpool to forward()) this must be its number of features. + Otherwise this must be set to 0. (If used from GenericModel, + this config value will be overridden automatically.) + encoding_dim: If global coding is in use (provided as global_code + to forward()) this must be its number of featuress. + Otherwise this must be set to 0. (If used from GenericModel, + this config value will be overridden automatically.) + """ + feature_vector_size: int = 3 d_in: int = 3 d_out: int = 1 @@ -23,7 +58,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module): bias: float = 1.0 skip_in: Sequence[int] = () weight_norm: bool = True - n_harmonic_functions_xyz: int = 0 + n_harmonic_functions_xyz: int = -1 pooled_feature_dim: int = 0 encoding_dim: int = 0 @@ -33,7 +68,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module): dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size] self.embed_fn = None - if self.n_harmonic_functions_xyz > 0: + if self.n_harmonic_functions_xyz >= 0: self.embed_fn = HarmonicEmbedding( self.n_harmonic_functions_xyz, append_input=True ) @@ -63,13 +98,13 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module): std=0.0001, ) torch.nn.init.constant_(lin.bias, -self.bias) - elif self.n_harmonic_functions_xyz > 0 and layer_idx == 0: + elif self.n_harmonic_functions_xyz >= 0 and layer_idx == 0: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.constant_(lin.weight[:, 3:], 0.0) torch.nn.init.normal_( lin.weight[:, :3], 0.0, 2**0.5 / out_dim**0.5 ) - elif self.n_harmonic_functions_xyz > 0 and layer_idx in self.skip_in: + elif self.n_harmonic_functions_xyz >= 0 and layer_idx in self.skip_in: torch.nn.init.constant_(lin.bias, 0.0) torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5) torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0) @@ -110,27 +145,26 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module): # Tensor]`. ).view(0, self.out_dim) - embedding = None + embeddings = [] if self.embed_fn is not None: # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. - embedding = self.embed_fn(rays_points_world) + embeddings.append(self.embed_fn(rays_points_world)) if fun_viewpool is not None: assert rays_points_world.ndim == 2 pooled_feature = fun_viewpool(rays_points_world[None]) # TODO: pooled features are 4D! - embedding = torch.cat((embedding, pooled_feature), dim=-1) + embeddings.append(pooled_feature) if global_code is not None: - assert embedding.ndim == 2 assert global_code.shape[0] == 1 # TODO: generalize to batches! # This will require changing raytracer code # embedding = embedding[None].expand(global_code.shape[0], *embedding.shape) - embedding = torch.cat( - (embedding, global_code[0, None, :].expand(*embedding.shape[:-1], -1)), - dim=-1, + embeddings.append( + global_code[0, None, :].expand(rays_points_world.shape[0], -1) ) + embedding = torch.cat(embeddings, dim=-1) x = embedding # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase.__s... for layer_idx in range(self.num_layers - 1): diff --git a/pytorch3d/implicitron/models/renderer/rgb_net.py b/pytorch3d/implicitron/models/renderer/rgb_net.py index e1b81587..47609e83 100644 --- a/pytorch3d/implicitron/models/renderer/rgb_net.py +++ b/pytorch3d/implicitron/models/renderer/rgb_net.py @@ -16,6 +16,30 @@ logger = logging.getLogger(__name__) class RayNormalColoringNetwork(torch.nn.Module): + """ + Members: + d_in and feature_vector_size: Sum of these is the input + dimension. These must add up to the sum of + - 3 [for the points] + - 3 unless mode=no_normal [for the normals] + - 3 unless mode=no_view_dir [for view directions] + - the feature size, [number of channels in feature_vectors] + + d_out: dimension of output. + mode: One of "idr", "no_view_dir" or "no_normal" to allow omitting + part of the network input. + dims: list of hidden layer sizes. + weight_norm: whether to apply weight normalization to each layer. + n_harmonic_functions_dir: + If >0, use a harmonic embedding with this number of + harmonic functions for the view direction. Otherwise view directions + are fed without embedding, unless mode is `no_view_dir`. + pooled_feature_dim: If a pooling function is in use (provided as + pooling_fn to forward()) this must be its number of features. + Otherwise this must be set to 0. (If used from GenericModel, + this will be set automatically.) + """ + def __init__( self, feature_vector_size: int = 3,