clean IF args

Summary: continued - avoid duplicate inputs

Reviewed By: davnov134

Differential Revision: D38248827

fbshipit-source-id: 91ed398e304496a936f66e7a70ab3d189eeb5c70
This commit is contained in:
Jeremy Reizenstein 2022-08-03 12:37:31 -07:00 committed by Facebook GitHub Bot
parent 078846d166
commit 46e82efb4e
6 changed files with 65 additions and 37 deletions

View File

@ -43,7 +43,6 @@ model_factory_ImplicitronModelFactory_args:
n_layers_xyz: 8 n_layers_xyz: 8
append_xyz: append_xyz:
- 5 - 5
latent_dim: 0
raysampler_AdaptiveRaySampler_args: raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 1024 n_rays_per_image_sampled_from_mask: 1024
scene_extent: 8.0 scene_extent: 8.0

View File

@ -313,7 +313,6 @@ model_factory_ImplicitronModelFactory_args:
- AVG - AVG
- STD - STD
implicit_function_IdrFeatureField_args: implicit_function_IdrFeatureField_args:
feature_vector_size: 3
d_in: 3 d_in: 3
d_out: 1 d_out: 1
dims: dims:
@ -331,15 +330,12 @@ model_factory_ImplicitronModelFactory_args:
weight_norm: true weight_norm: true
n_harmonic_functions_xyz: 0 n_harmonic_functions_xyz: 0
pooled_feature_dim: 0 pooled_feature_dim: 0
encoding_dim: 0
implicit_function_NeRFormerImplicitFunction_args: implicit_function_NeRFormerImplicitFunction_args:
n_harmonic_functions_xyz: 10 n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4 n_harmonic_functions_dir: 4
n_hidden_neurons_dir: 128 n_hidden_neurons_dir: 128
latent_dim: 0
input_xyz: true input_xyz: true
xyz_ray_dir_in_camera_coords: false xyz_ray_dir_in_camera_coords: false
color_dim: 3
transformer_dim_down_factor: 2.0 transformer_dim_down_factor: 2.0
n_hidden_neurons_xyz: 80 n_hidden_neurons_xyz: 80
n_layers_xyz: 2 n_layers_xyz: 2
@ -349,10 +345,8 @@ model_factory_ImplicitronModelFactory_args:
n_harmonic_functions_xyz: 10 n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4 n_harmonic_functions_dir: 4
n_hidden_neurons_dir: 128 n_hidden_neurons_dir: 128
latent_dim: 0
input_xyz: true input_xyz: true
xyz_ray_dir_in_camera_coords: false xyz_ray_dir_in_camera_coords: false
color_dim: 3
transformer_dim_down_factor: 1.0 transformer_dim_down_factor: 1.0
n_hidden_neurons_xyz: 256 n_hidden_neurons_xyz: 256
n_layers_xyz: 8 n_layers_xyz: 8
@ -367,8 +361,6 @@ model_factory_ImplicitronModelFactory_args:
n_layers_hypernet: 1 n_layers_hypernet: 1
in_features: 3 in_features: 3
out_features: 256 out_features: 256
latent_dim_hypernet: 0
latent_dim: 0
xyz_in_camera_coords: false xyz_in_camera_coords: false
pixel_generator_args: pixel_generator_args:
n_harmonic_functions: 4 n_harmonic_functions: 4
@ -385,7 +377,6 @@ model_factory_ImplicitronModelFactory_args:
n_layers: 2 n_layers: 2
in_features: 3 in_features: 3
out_features: 256 out_features: 256
latent_dim: 0
xyz_in_camera_coords: false xyz_in_camera_coords: false
raymarch_function: null raymarch_function: null
pixel_generator_args: pixel_generator_args:

View File

@ -677,6 +677,18 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
""" """
pass pass
@classmethod
def implicit_function_tweak_args(cls, type, args: DictConfig) -> None:
"""
We don't expose certain implicit_function fields because we want to set
them based on other inputs.
"""
args.pop("feature_vector_size", None)
args.pop("encoding_dim", None)
args.pop("latent_dim", None)
args.pop("latent_dim_hypernet", None)
args.pop("color_dim", None)
def _construct_implicit_functions(self): def _construct_implicit_functions(self):
""" """
After run_auto_creation has been called, the arguments After run_auto_creation has been called, the arguments
@ -686,32 +698,31 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
implicit function method. Then the required implicit implicit function method. Then the required implicit
function(s) are initialized. function(s) are initialized.
""" """
# nerf preprocessing extra_args = {}
nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args if self.implicit_function_class_type in (
nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args "NeuralRadianceFieldImplicitFunction",
nerf_args["latent_dim"] = nerformer_args["latent_dim"] = ( "NeRFormerImplicitFunction",
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim() ):
) extra_args["latent_dim"] = (
nerf_args["color_dim"] = nerformer_args[ self._get_viewpooled_feature_dim()
"color_dim" + self._get_global_encoder_encoding_dim()
] = self.render_features_dimensions )
extra_args["color_dim"] = self.render_features_dimensions
# idr preprocessing if self.implicit_function_class_type == "IdrFeatureField":
idr = self.implicit_function_IdrFeatureField_args extra_args["feature_vector_size"] = self.render_features_dimensions
idr["feature_vector_size"] = self.render_features_dimensions extra_args["encoding_dim"] = self._get_global_encoder_encoding_dim()
idr["encoding_dim"] = self._get_global_encoder_encoding_dim()
# srn preprocessing if self.implicit_function_class_type == "SRNImplicitFunction":
srn = self.implicit_function_SRNImplicitFunction_args extra_args["latent_dim"] = (
srn.raymarch_function_args.latent_dim = ( self._get_viewpooled_feature_dim()
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim() + self._get_global_encoder_encoding_dim()
) )
# srn_hypernet preprocessing # srn_hypernet preprocessing
srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args if self.implicit_function_class_type == "SRNHyperNetImplicitFunction":
srn_hypernet_args = srn_hypernet.hypernet_args extra_args["latent_dim"] = self._get_viewpooled_feature_dim()
srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim() extra_args["latent_dim_hypernet"] = self._get_global_encoder_encoding_dim()
srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
# check that for srn, srn_hypernet, idr we have self.num_passes=1 # check that for srn, srn_hypernet, idr we have self.num_passes=1
implicit_function_type = registry.get( implicit_function_type = registry.get(
@ -734,7 +745,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
if config is None: if config is None:
raise ValueError(f"{config_name} not present") raise ValueError(f"{config_name} not present")
implicit_functions_list = [ implicit_functions_list = [
ImplicitFunctionWrapper(implicit_function_type(**config)) ImplicitFunctionWrapper(implicit_function_type(**config, **extra_args))
for _ in range(self.num_passes) for _ in range(self.num_passes)
] ]
return torch.nn.ModuleList(implicit_functions_list) return torch.nn.ModuleList(implicit_functions_list)

View File

@ -4,6 +4,7 @@
from typing import Any, cast, Optional, Tuple from typing import Any, cast, Optional, Tuple
import torch import torch
from omegaconf import DictConfig
from pytorch3d.common.linear_with_repeat import LinearWithRepeat from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
@ -327,6 +328,7 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
@registry.register @registry.register
# pyre-fixme[13]: Uninitialized attribute # pyre-fixme[13]: Uninitialized attribute
class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module): class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
latent_dim: int = 0
raymarch_function: SRNRaymarchFunction raymarch_function: SRNRaymarchFunction
pixel_generator: SRNPixelGenerator pixel_generator: SRNPixelGenerator
@ -334,6 +336,17 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
super().__init__() super().__init__()
run_auto_creation(self) run_auto_creation(self)
def create_raymarch_function(self) -> None:
self.raymarch_function = SRNRaymarchFunction(
latent_dim=self.latent_dim,
# pyre-ignore[32]
**self.raymarch_function_args,
)
@classmethod
def raymarch_function_tweak_args(cls, type, args: DictConfig) -> None:
args.pop("latent_dim", None)
def forward( def forward(
self, self,
ray_bundle: RayBundle, ray_bundle: RayBundle,
@ -371,6 +384,8 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
the cache. the cache.
""" """
latent_dim_hypernet: int = 0
latent_dim: int = 0
hypernet: SRNRaymarchHyperNet hypernet: SRNRaymarchHyperNet
pixel_generator: SRNPixelGenerator pixel_generator: SRNPixelGenerator
@ -378,6 +393,19 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
super().__init__() super().__init__()
run_auto_creation(self) run_auto_creation(self)
def create_hypernet(self) -> None:
self.hypernet = SRNRaymarchHyperNet(
latent_dim=self.latent_dim,
latent_dim_hypernet=self.latent_dim_hypernet,
# pyre-ignore[32]
**self.hypernet_args,
)
@classmethod
def hypernet_tweak_args(cls, type, args: DictConfig) -> None:
args.pop("latent_dim", None)
args.pop("latent_dim_hypernet", None)
def forward( def forward(
self, self,
ray_bundle: RayBundle, ray_bundle: RayBundle,

View File

@ -103,7 +103,6 @@ view_pooler_args:
weight_by_ray_angle_gamma: 1.0 weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1 min_ray_angle_weight: 0.1
implicit_function_IdrFeatureField_args: implicit_function_IdrFeatureField_args:
feature_vector_size: 3
d_in: 3 d_in: 3
d_out: 1 d_out: 1
dims: dims:
@ -121,6 +120,5 @@ implicit_function_IdrFeatureField_args:
weight_norm: true weight_norm: true
n_harmonic_functions_xyz: 1729 n_harmonic_functions_xyz: 1729
pooled_feature_dim: 0 pooled_feature_dim: 0
encoding_dim: 0
view_metrics_ViewMetrics_args: {} view_metrics_ViewMetrics_args: {}
regularization_metrics_RegularizationMetrics_args: {} regularization_metrics_RegularizationMetrics_args: {}

View File

@ -55,9 +55,10 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
def test_srn_hypernet_implicit_function(self): def test_srn_hypernet_implicit_function(self):
# TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core? # TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
latent_dim_hypernet = 39 latent_dim_hypernet = 39
hypernet_args = {"latent_dim_hypernet": latent_dim_hypernet}
device = torch.device("cuda:0") device = torch.device("cuda:0")
implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hypernet_args) implicit_function = SRNHyperNetImplicitFunction(
latent_dim_hypernet=latent_dim_hypernet
)
implicit_function.to(device) implicit_function.to(device)
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device) global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
bundle = self._get_bundle(device=device) bundle = self._get_bundle(device=device)