mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
idr harmonic_fns and doc
Summary: Document the inputs of idr functions and distinguish n_harmonic_functions to be 0 (simple embedding) versus -1 (no embedding). Reviewed By: davnov134 Differential Revision: D37209012 fbshipit-source-id: 6e5c3eae54c4e5e8c3f76cad1caf162c6c222d52
This commit is contained in:
parent
28c1afaa9d
commit
81d63c6382
@ -201,7 +201,7 @@ generic_model_args:
|
|||||||
bias: 1.0
|
bias: 1.0
|
||||||
skip_in: []
|
skip_in: []
|
||||||
weight_norm: true
|
weight_norm: true
|
||||||
n_harmonic_functions_xyz: 0
|
n_harmonic_functions_xyz: -1
|
||||||
pooled_feature_dim: 0
|
pooled_feature_dim: 0
|
||||||
encoding_dim: 0
|
encoding_dim: 0
|
||||||
implicit_function_NeRFormerImplicitFunction_args:
|
implicit_function_NeRFormerImplicitFunction_args:
|
||||||
|
@ -15,6 +15,41 @@ from .base import ImplicitFunctionBase
|
|||||||
|
|
||||||
@registry.register
|
@registry.register
|
||||||
class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Implicit function as used in http://github.com/lioryariv/idr.
|
||||||
|
|
||||||
|
Members:
|
||||||
|
d_in: dimension of the input point.
|
||||||
|
n_harmonic_functions_xyz: If -1, do not embed the point.
|
||||||
|
If >=0, use a harmonic embedding with this number of
|
||||||
|
harmonic functions. (The harmonic embedding includes the input
|
||||||
|
itself, so a value of 0 means the point is used but without
|
||||||
|
any harmonic functions.)
|
||||||
|
d_out and feature_vector_size: Sum of these is the output
|
||||||
|
dimension. This implicit function thus returns a concatenation
|
||||||
|
of `d_out` signed distance function values and `feature_vector_size`
|
||||||
|
features (such as colors). When used in `GenericModel`,
|
||||||
|
`feature_vector_size` corresponds is automatically set to
|
||||||
|
`render_features_dimensions`.
|
||||||
|
dims: list of hidden layer sizes.
|
||||||
|
geometric_init: whether to use custom weight initialization
|
||||||
|
in linear layers. If False, pytorch default (uniform sampling)
|
||||||
|
is used.
|
||||||
|
bias: if geometric_init=True, initial value for bias subtracted
|
||||||
|
in the last layer.
|
||||||
|
skip_in: List of indices of layers that receive as input the initial
|
||||||
|
value concatenated with the output of the previous layers.
|
||||||
|
weight_norm: whether to apply weight normalization to each layer.
|
||||||
|
pooled_feature_dim: If view pooling is in use (provided as
|
||||||
|
fun_viewpool to forward()) this must be its number of features.
|
||||||
|
Otherwise this must be set to 0. (If used from GenericModel,
|
||||||
|
this config value will be overridden automatically.)
|
||||||
|
encoding_dim: If global coding is in use (provided as global_code
|
||||||
|
to forward()) this must be its number of featuress.
|
||||||
|
Otherwise this must be set to 0. (If used from GenericModel,
|
||||||
|
this config value will be overridden automatically.)
|
||||||
|
"""
|
||||||
|
|
||||||
feature_vector_size: int = 3
|
feature_vector_size: int = 3
|
||||||
d_in: int = 3
|
d_in: int = 3
|
||||||
d_out: int = 1
|
d_out: int = 1
|
||||||
@ -23,7 +58,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
bias: float = 1.0
|
bias: float = 1.0
|
||||||
skip_in: Sequence[int] = ()
|
skip_in: Sequence[int] = ()
|
||||||
weight_norm: bool = True
|
weight_norm: bool = True
|
||||||
n_harmonic_functions_xyz: int = 0
|
n_harmonic_functions_xyz: int = -1
|
||||||
pooled_feature_dim: int = 0
|
pooled_feature_dim: int = 0
|
||||||
encoding_dim: int = 0
|
encoding_dim: int = 0
|
||||||
|
|
||||||
@ -33,7 +68,7 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
|
dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
|
||||||
|
|
||||||
self.embed_fn = None
|
self.embed_fn = None
|
||||||
if self.n_harmonic_functions_xyz > 0:
|
if self.n_harmonic_functions_xyz >= 0:
|
||||||
self.embed_fn = HarmonicEmbedding(
|
self.embed_fn = HarmonicEmbedding(
|
||||||
self.n_harmonic_functions_xyz, append_input=True
|
self.n_harmonic_functions_xyz, append_input=True
|
||||||
)
|
)
|
||||||
@ -63,13 +98,13 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
std=0.0001,
|
std=0.0001,
|
||||||
)
|
)
|
||||||
torch.nn.init.constant_(lin.bias, -self.bias)
|
torch.nn.init.constant_(lin.bias, -self.bias)
|
||||||
elif self.n_harmonic_functions_xyz > 0 and layer_idx == 0:
|
elif self.n_harmonic_functions_xyz >= 0 and layer_idx == 0:
|
||||||
torch.nn.init.constant_(lin.bias, 0.0)
|
torch.nn.init.constant_(lin.bias, 0.0)
|
||||||
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
|
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
|
||||||
torch.nn.init.normal_(
|
torch.nn.init.normal_(
|
||||||
lin.weight[:, :3], 0.0, 2**0.5 / out_dim**0.5
|
lin.weight[:, :3], 0.0, 2**0.5 / out_dim**0.5
|
||||||
)
|
)
|
||||||
elif self.n_harmonic_functions_xyz > 0 and layer_idx in self.skip_in:
|
elif self.n_harmonic_functions_xyz >= 0 and layer_idx in self.skip_in:
|
||||||
torch.nn.init.constant_(lin.bias, 0.0)
|
torch.nn.init.constant_(lin.bias, 0.0)
|
||||||
torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5)
|
torch.nn.init.normal_(lin.weight, 0.0, 2**0.5 / out_dim**0.5)
|
||||||
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
|
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
|
||||||
@ -110,27 +145,26 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
# Tensor]`.
|
# Tensor]`.
|
||||||
).view(0, self.out_dim)
|
).view(0, self.out_dim)
|
||||||
|
|
||||||
embedding = None
|
embeddings = []
|
||||||
if self.embed_fn is not None:
|
if self.embed_fn is not None:
|
||||||
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
|
||||||
embedding = self.embed_fn(rays_points_world)
|
embeddings.append(self.embed_fn(rays_points_world))
|
||||||
|
|
||||||
if fun_viewpool is not None:
|
if fun_viewpool is not None:
|
||||||
assert rays_points_world.ndim == 2
|
assert rays_points_world.ndim == 2
|
||||||
pooled_feature = fun_viewpool(rays_points_world[None])
|
pooled_feature = fun_viewpool(rays_points_world[None])
|
||||||
# TODO: pooled features are 4D!
|
# TODO: pooled features are 4D!
|
||||||
embedding = torch.cat((embedding, pooled_feature), dim=-1)
|
embeddings.append(pooled_feature)
|
||||||
|
|
||||||
if global_code is not None:
|
if global_code is not None:
|
||||||
assert embedding.ndim == 2
|
|
||||||
assert global_code.shape[0] == 1 # TODO: generalize to batches!
|
assert global_code.shape[0] == 1 # TODO: generalize to batches!
|
||||||
# This will require changing raytracer code
|
# This will require changing raytracer code
|
||||||
# embedding = embedding[None].expand(global_code.shape[0], *embedding.shape)
|
# embedding = embedding[None].expand(global_code.shape[0], *embedding.shape)
|
||||||
embedding = torch.cat(
|
embeddings.append(
|
||||||
(embedding, global_code[0, None, :].expand(*embedding.shape[:-1], -1)),
|
global_code[0, None, :].expand(rays_points_world.shape[0], -1)
|
||||||
dim=-1,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
embedding = torch.cat(embeddings, dim=-1)
|
||||||
x = embedding
|
x = embedding
|
||||||
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase.__s...
|
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase.__s...
|
||||||
for layer_idx in range(self.num_layers - 1):
|
for layer_idx in range(self.num_layers - 1):
|
||||||
|
@ -16,6 +16,30 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class RayNormalColoringNetwork(torch.nn.Module):
|
class RayNormalColoringNetwork(torch.nn.Module):
|
||||||
|
"""
|
||||||
|
Members:
|
||||||
|
d_in and feature_vector_size: Sum of these is the input
|
||||||
|
dimension. These must add up to the sum of
|
||||||
|
- 3 [for the points]
|
||||||
|
- 3 unless mode=no_normal [for the normals]
|
||||||
|
- 3 unless mode=no_view_dir [for view directions]
|
||||||
|
- the feature size, [number of channels in feature_vectors]
|
||||||
|
|
||||||
|
d_out: dimension of output.
|
||||||
|
mode: One of "idr", "no_view_dir" or "no_normal" to allow omitting
|
||||||
|
part of the network input.
|
||||||
|
dims: list of hidden layer sizes.
|
||||||
|
weight_norm: whether to apply weight normalization to each layer.
|
||||||
|
n_harmonic_functions_dir:
|
||||||
|
If >0, use a harmonic embedding with this number of
|
||||||
|
harmonic functions for the view direction. Otherwise view directions
|
||||||
|
are fed without embedding, unless mode is `no_view_dir`.
|
||||||
|
pooled_feature_dim: If a pooling function is in use (provided as
|
||||||
|
pooling_fn to forward()) this must be its number of features.
|
||||||
|
Otherwise this must be set to 0. (If used from GenericModel,
|
||||||
|
this will be set automatically.)
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
feature_vector_size: int = 3,
|
feature_vector_size: int = 3,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user