Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-02 03:42:50 +08:00
clean IF args
Summary: continued - avoid duplicate inputs

Reviewed By: davnov134

Differential Revision: D38248827

fbshipit-source-id: 91ed398e304496a936f66e7a70ab3d189eeb5c70
This commit is contained in:
parent 078846d166
commit 46e82efb4e
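The motivation, in miniature: before this change, fields such as latent_dim were listed in the auto-generated implicit-function configs and then also set programmatically by GenericModel, so the same input effectively arrived twice. The toy sketch below (plain Python; make_implicit_function is a hypothetical stand-in for the real constructor call) shows why the duplicate has to be removed from the exposed config before the owner can pass its computed value:

# Toy sketch of the "duplicate inputs" hazard this commit removes.
def make_implicit_function(**kwargs):
    return kwargs

config = {"latent_dim": 0, "n_layers_xyz": 8}  # auto-generated defaults
extra_args = {"latent_dim": 256}               # value computed by the owner

try:
    make_implicit_function(**config, **extra_args)
except TypeError as err:
    print(err)  # got multiple values for keyword argument 'latent_dim'

# The fix in this diff: pop owner-computed fields from the exposed config
# (implicit_function_tweak_args below), then pass them exactly once.
config.pop("latent_dim", None)
print(make_implicit_function(**config, **extra_args))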
@@ -43,7 +43,6 @@ model_factory_ImplicitronModelFactory_args:
       n_layers_xyz: 8
       append_xyz:
       - 5
-      latent_dim: 0
     raysampler_AdaptiveRaySampler_args:
       n_rays_per_image_sampled_from_mask: 1024
       scene_extent: 8.0
@@ -313,7 +313,6 @@ model_factory_ImplicitronModelFactory_args:
       - AVG
       - STD
     implicit_function_IdrFeatureField_args:
-      feature_vector_size: 3
       d_in: 3
       d_out: 1
       dims:
@@ -331,15 +330,12 @@ model_factory_ImplicitronModelFactory_args:
       weight_norm: true
       n_harmonic_functions_xyz: 0
       pooled_feature_dim: 0
-      encoding_dim: 0
     implicit_function_NeRFormerImplicitFunction_args:
       n_harmonic_functions_xyz: 10
       n_harmonic_functions_dir: 4
       n_hidden_neurons_dir: 128
-      latent_dim: 0
       input_xyz: true
       xyz_ray_dir_in_camera_coords: false
-      color_dim: 3
       transformer_dim_down_factor: 2.0
       n_hidden_neurons_xyz: 80
       n_layers_xyz: 2
@@ -349,10 +345,8 @@ model_factory_ImplicitronModelFactory_args:
       n_harmonic_functions_xyz: 10
       n_harmonic_functions_dir: 4
       n_hidden_neurons_dir: 128
-      latent_dim: 0
       input_xyz: true
       xyz_ray_dir_in_camera_coords: false
-      color_dim: 3
       transformer_dim_down_factor: 1.0
       n_hidden_neurons_xyz: 256
       n_layers_xyz: 8
@@ -367,8 +361,6 @@ model_factory_ImplicitronModelFactory_args:
       n_layers_hypernet: 1
       in_features: 3
       out_features: 256
-      latent_dim_hypernet: 0
-      latent_dim: 0
       xyz_in_camera_coords: false
     pixel_generator_args:
       n_harmonic_functions: 4
@@ -385,7 +377,6 @@ model_factory_ImplicitronModelFactory_args:
       n_layers: 2
       in_features: 3
       out_features: 256
-      latent_dim: 0
       xyz_in_camera_coords: false
     raymarch_function: null
     pixel_generator_args:
@@ -677,6 +677,18 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
         """
         pass
 
+    @classmethod
+    def implicit_function_tweak_args(cls, type, args: DictConfig) -> None:
+        """
+        We don't expose certain implicit_function fields because we want to set
+        them based on other inputs.
+        """
+        args.pop("feature_vector_size", None)
+        args.pop("encoding_dim", None)
+        args.pop("latent_dim", None)
+        args.pop("latent_dim_hypernet", None)
+        args.pop("color_dim", None)
+
     def _construct_implicit_functions(self):
         """
         After run_auto_creation has been called, the arguments
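A minimal sketch of the contract this hook fulfils, assuming only omegaconf. The real wiring, in which the config system invokes a member's <name>_tweak_args classmethod on that member's auto-generated DictConfig, lives in pytorch3d.implicitron.tools.config; build_member_args here is a hypothetical stand-in:

# Toy version of the *_tweak_args mechanism used throughout this commit.
from omegaconf import DictConfig, OmegaConf

def build_member_args(defaults: dict, tweak_hook) -> DictConfig:
    """Create the member's arg config, then let the owner prune it."""
    args = OmegaConf.create(defaults)
    tweak_hook(args)  # the owner removes fields it will supply itself
    return args

defaults = {"latent_dim": 0, "color_dim": 3, "n_layers_xyz": 8}
args = build_member_args(
    defaults, lambda a: (a.pop("latent_dim", None), a.pop("color_dim", None))
)
print(OmegaConf.to_container(args))  # {'n_layers_xyz': 8}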
@@ -686,32 +698,31 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
         implicit function method. Then the required implicit
         function(s) are initialized.
         """
-        # nerf preprocessing
-        nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
-        nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
-        nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
-            self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
-        )
-        nerf_args["color_dim"] = nerformer_args[
-            "color_dim"
-        ] = self.render_features_dimensions
+        extra_args = {}
+        if self.implicit_function_class_type in (
+            "NeuralRadianceFieldImplicitFunction",
+            "NeRFormerImplicitFunction",
+        ):
+            extra_args["latent_dim"] = (
+                self._get_viewpooled_feature_dim()
+                + self._get_global_encoder_encoding_dim()
+            )
+            extra_args["color_dim"] = self.render_features_dimensions
 
-        # idr preprocessing
-        idr = self.implicit_function_IdrFeatureField_args
-        idr["feature_vector_size"] = self.render_features_dimensions
-        idr["encoding_dim"] = self._get_global_encoder_encoding_dim()
+        if self.implicit_function_class_type == "IdrFeatureField":
+            extra_args["feature_vector_size"] = self.render_features_dimensions
+            extra_args["encoding_dim"] = self._get_global_encoder_encoding_dim()
 
-        # srn preprocessing
-        srn = self.implicit_function_SRNImplicitFunction_args
-        srn.raymarch_function_args.latent_dim = (
-            self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
-        )
+        if self.implicit_function_class_type == "SRNImplicitFunction":
+            extra_args["latent_dim"] = (
+                self._get_viewpooled_feature_dim()
+                + self._get_global_encoder_encoding_dim()
+            )
 
-        # srn_hypernet preprocessing
-        srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
-        srn_hypernet_args = srn_hypernet.hypernet_args
-        srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim()
-        srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
+        if self.implicit_function_class_type == "SRNHyperNetImplicitFunction":
+            extra_args["latent_dim"] = self._get_viewpooled_feature_dim()
+            extra_args["latent_dim_hypernet"] = self._get_global_encoder_encoding_dim()
 
         # check that for srn, srn_hypernet, idr we have self.num_passes=1
         implicit_function_type = registry.get(
@@ -734,7 +745,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
         if config is None:
             raise ValueError(f"{config_name} not present")
         implicit_functions_list = [
-            ImplicitFunctionWrapper(implicit_function_type(**config))
+            ImplicitFunctionWrapper(implicit_function_type(**config, **extra_args))
             for _ in range(self.num_passes)
         ]
         return torch.nn.ModuleList(implicit_functions_list)
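The double-splat above is only safe because implicit_function_tweak_args removed the overlapping keys first. A hypothetical guard (not in the commit) makes the invariant explicit:

# The keys computed in _construct_implicit_functions must stay disjoint
# from the exposed config, otherwise **config, **extra_args raises TypeError.
def check_disjoint(config: dict, extra_args: dict) -> None:
    overlap = set(config) & set(extra_args)
    assert not overlap, f"duplicate implicit-function inputs: {overlap}"

check_disjoint({"n_layers_xyz": 8}, {"latent_dim": 256, "color_dim": 3})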
@@ -4,6 +4,7 @@
 from typing import Any, cast, Optional, Tuple
 
 import torch
+from omegaconf import DictConfig
 from pytorch3d.common.linear_with_repeat import LinearWithRepeat
 from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
 from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
@@ -327,6 +328,7 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
 
 @registry.register
 # pyre-fixme[13]: Uninitialized attribute
 class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
+    latent_dim: int = 0
     raymarch_function: SRNRaymarchFunction
     pixel_generator: SRNPixelGenerator
@@ -334,6 +336,17 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
         super().__init__()
         run_auto_creation(self)
 
+    def create_raymarch_function(self) -> None:
+        self.raymarch_function = SRNRaymarchFunction(
+            latent_dim=self.latent_dim,
+            # pyre-ignore[32]
+            **self.raymarch_function_args,
+        )
+
+    @classmethod
+    def raymarch_function_tweak_args(cls, type, args: DictConfig) -> None:
+        args.pop("latent_dim", None)
+
     def forward(
         self,
         ray_bundle: RayBundle,
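A hedged usage sketch of the new surface (assumes pytorch3d with this patch installed, and the module path as in the repo's implicit_function package): latent_dim now lives on SRNImplicitFunction itself and create_raymarch_function forwards it into the inner SRNRaymarchFunction, so it can no longer be set, or duplicated, through raymarch_function_args.

from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (
    SRNImplicitFunction,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(SRNImplicitFunction)  # prepare the Configurable dataclass
f = SRNImplicitFunction(latent_dim=64)
# the wrapper's field is forwarded into the inner SRNRaymarchFunction
assert f.raymarch_function.latent_dim == 64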
@@ -371,6 +384,8 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
     the cache.
     """
 
+    latent_dim_hypernet: int = 0
+    latent_dim: int = 0
     hypernet: SRNRaymarchHyperNet
     pixel_generator: SRNPixelGenerator
 
@@ -378,6 +393,19 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
         super().__init__()
         run_auto_creation(self)
 
+    def create_hypernet(self) -> None:
+        self.hypernet = SRNRaymarchHyperNet(
+            latent_dim=self.latent_dim,
+            latent_dim_hypernet=self.latent_dim_hypernet,
+            # pyre-ignore[32]
+            **self.hypernet_args,
+        )
+
+    @classmethod
+    def hypernet_tweak_args(cls, type, args: DictConfig) -> None:
+        args.pop("latent_dim", None)
+        args.pop("latent_dim_hypernet", None)
+
     def forward(
         self,
         ray_bundle: RayBundle,
@@ -103,7 +103,6 @@ view_pooler_args:
   weight_by_ray_angle_gamma: 1.0
   min_ray_angle_weight: 0.1
 implicit_function_IdrFeatureField_args:
-  feature_vector_size: 3
   d_in: 3
   d_out: 1
   dims:
@@ -121,6 +120,5 @@ implicit_function_IdrFeatureField_args:
   weight_norm: true
   n_harmonic_functions_xyz: 1729
   pooled_feature_dim: 0
-  encoding_dim: 0
 view_metrics_ViewMetrics_args: {}
 regularization_metrics_RegularizationMetrics_args: {}
@@ -55,9 +55,10 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
     def test_srn_hypernet_implicit_function(self):
         # TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
         latent_dim_hypernet = 39
-        hypernet_args = {"latent_dim_hypernet": latent_dim_hypernet}
         device = torch.device("cuda:0")
-        implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hypernet_args)
+        implicit_function = SRNHyperNetImplicitFunction(
+            latent_dim_hypernet=latent_dim_hypernet
+        )
         implicit_function.to(device)
         global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
         bundle = self._get_bundle(device=device)
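Mirroring the test change above, a quick hedged check of the commit's net effect on the generated config surface (get_default_args is part of pytorch3d.implicitron.tools.config; the assertions assume this patch is applied):

from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (
    SRNHyperNetImplicitFunction,
)
from pytorch3d.implicitron.tools.config import get_default_args

args = get_default_args(SRNHyperNetImplicitFunction)
# the popped fields are gone from the member config...
assert "latent_dim" not in args.hypernet_args
assert "latent_dim_hypernet" not in args.hypernet_args
# ...and are now plain fields on the wrapper class itself
assert args.latent_dim == 0
assert args.latent_dim_hypernet == 0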