mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
clean IF args
Summary: continued - avoid duplicate inputs Reviewed By: davnov134 Differential Revision: D38248827 fbshipit-source-id: 91ed398e304496a936f66e7a70ab3d189eeb5c70
This commit is contained in:
parent
078846d166
commit
46e82efb4e
@ -43,7 +43,6 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
n_layers_xyz: 8
|
n_layers_xyz: 8
|
||||||
append_xyz:
|
append_xyz:
|
||||||
- 5
|
- 5
|
||||||
latent_dim: 0
|
|
||||||
raysampler_AdaptiveRaySampler_args:
|
raysampler_AdaptiveRaySampler_args:
|
||||||
n_rays_per_image_sampled_from_mask: 1024
|
n_rays_per_image_sampled_from_mask: 1024
|
||||||
scene_extent: 8.0
|
scene_extent: 8.0
|
||||||
|
@ -313,7 +313,6 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
- AVG
|
- AVG
|
||||||
- STD
|
- STD
|
||||||
implicit_function_IdrFeatureField_args:
|
implicit_function_IdrFeatureField_args:
|
||||||
feature_vector_size: 3
|
|
||||||
d_in: 3
|
d_in: 3
|
||||||
d_out: 1
|
d_out: 1
|
||||||
dims:
|
dims:
|
||||||
@ -331,15 +330,12 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
weight_norm: true
|
weight_norm: true
|
||||||
n_harmonic_functions_xyz: 0
|
n_harmonic_functions_xyz: 0
|
||||||
pooled_feature_dim: 0
|
pooled_feature_dim: 0
|
||||||
encoding_dim: 0
|
|
||||||
implicit_function_NeRFormerImplicitFunction_args:
|
implicit_function_NeRFormerImplicitFunction_args:
|
||||||
n_harmonic_functions_xyz: 10
|
n_harmonic_functions_xyz: 10
|
||||||
n_harmonic_functions_dir: 4
|
n_harmonic_functions_dir: 4
|
||||||
n_hidden_neurons_dir: 128
|
n_hidden_neurons_dir: 128
|
||||||
latent_dim: 0
|
|
||||||
input_xyz: true
|
input_xyz: true
|
||||||
xyz_ray_dir_in_camera_coords: false
|
xyz_ray_dir_in_camera_coords: false
|
||||||
color_dim: 3
|
|
||||||
transformer_dim_down_factor: 2.0
|
transformer_dim_down_factor: 2.0
|
||||||
n_hidden_neurons_xyz: 80
|
n_hidden_neurons_xyz: 80
|
||||||
n_layers_xyz: 2
|
n_layers_xyz: 2
|
||||||
@ -349,10 +345,8 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
n_harmonic_functions_xyz: 10
|
n_harmonic_functions_xyz: 10
|
||||||
n_harmonic_functions_dir: 4
|
n_harmonic_functions_dir: 4
|
||||||
n_hidden_neurons_dir: 128
|
n_hidden_neurons_dir: 128
|
||||||
latent_dim: 0
|
|
||||||
input_xyz: true
|
input_xyz: true
|
||||||
xyz_ray_dir_in_camera_coords: false
|
xyz_ray_dir_in_camera_coords: false
|
||||||
color_dim: 3
|
|
||||||
transformer_dim_down_factor: 1.0
|
transformer_dim_down_factor: 1.0
|
||||||
n_hidden_neurons_xyz: 256
|
n_hidden_neurons_xyz: 256
|
||||||
n_layers_xyz: 8
|
n_layers_xyz: 8
|
||||||
@ -367,8 +361,6 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
n_layers_hypernet: 1
|
n_layers_hypernet: 1
|
||||||
in_features: 3
|
in_features: 3
|
||||||
out_features: 256
|
out_features: 256
|
||||||
latent_dim_hypernet: 0
|
|
||||||
latent_dim: 0
|
|
||||||
xyz_in_camera_coords: false
|
xyz_in_camera_coords: false
|
||||||
pixel_generator_args:
|
pixel_generator_args:
|
||||||
n_harmonic_functions: 4
|
n_harmonic_functions: 4
|
||||||
@ -385,7 +377,6 @@ model_factory_ImplicitronModelFactory_args:
|
|||||||
n_layers: 2
|
n_layers: 2
|
||||||
in_features: 3
|
in_features: 3
|
||||||
out_features: 256
|
out_features: 256
|
||||||
latent_dim: 0
|
|
||||||
xyz_in_camera_coords: false
|
xyz_in_camera_coords: false
|
||||||
raymarch_function: null
|
raymarch_function: null
|
||||||
pixel_generator_args:
|
pixel_generator_args:
|
||||||
|
@ -677,6 +677,18 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def implicit_function_tweak_args(cls, type, args: DictConfig) -> None:
|
||||||
|
"""
|
||||||
|
We don't expose certain implicit_function fields because we want to set
|
||||||
|
them based on other inputs.
|
||||||
|
"""
|
||||||
|
args.pop("feature_vector_size", None)
|
||||||
|
args.pop("encoding_dim", None)
|
||||||
|
args.pop("latent_dim", None)
|
||||||
|
args.pop("latent_dim_hypernet", None)
|
||||||
|
args.pop("color_dim", None)
|
||||||
|
|
||||||
def _construct_implicit_functions(self):
|
def _construct_implicit_functions(self):
|
||||||
"""
|
"""
|
||||||
After run_auto_creation has been called, the arguments
|
After run_auto_creation has been called, the arguments
|
||||||
@ -686,32 +698,31 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
implicit function method. Then the required implicit
|
implicit function method. Then the required implicit
|
||||||
function(s) are initialized.
|
function(s) are initialized.
|
||||||
"""
|
"""
|
||||||
# nerf preprocessing
|
extra_args = {}
|
||||||
nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
|
if self.implicit_function_class_type in (
|
||||||
nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
|
"NeuralRadianceFieldImplicitFunction",
|
||||||
nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
|
"NeRFormerImplicitFunction",
|
||||||
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
|
):
|
||||||
)
|
extra_args["latent_dim"] = (
|
||||||
nerf_args["color_dim"] = nerformer_args[
|
self._get_viewpooled_feature_dim()
|
||||||
"color_dim"
|
+ self._get_global_encoder_encoding_dim()
|
||||||
] = self.render_features_dimensions
|
)
|
||||||
|
extra_args["color_dim"] = self.render_features_dimensions
|
||||||
|
|
||||||
# idr preprocessing
|
if self.implicit_function_class_type == "IdrFeatureField":
|
||||||
idr = self.implicit_function_IdrFeatureField_args
|
extra_args["feature_vector_size"] = self.render_features_dimensions
|
||||||
idr["feature_vector_size"] = self.render_features_dimensions
|
extra_args["encoding_dim"] = self._get_global_encoder_encoding_dim()
|
||||||
idr["encoding_dim"] = self._get_global_encoder_encoding_dim()
|
|
||||||
|
|
||||||
# srn preprocessing
|
if self.implicit_function_class_type == "SRNImplicitFunction":
|
||||||
srn = self.implicit_function_SRNImplicitFunction_args
|
extra_args["latent_dim"] = (
|
||||||
srn.raymarch_function_args.latent_dim = (
|
self._get_viewpooled_feature_dim()
|
||||||
self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
|
+ self._get_global_encoder_encoding_dim()
|
||||||
)
|
)
|
||||||
|
|
||||||
# srn_hypernet preprocessing
|
# srn_hypernet preprocessing
|
||||||
srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
|
if self.implicit_function_class_type == "SRNHyperNetImplicitFunction":
|
||||||
srn_hypernet_args = srn_hypernet.hypernet_args
|
extra_args["latent_dim"] = self._get_viewpooled_feature_dim()
|
||||||
srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim()
|
extra_args["latent_dim_hypernet"] = self._get_global_encoder_encoding_dim()
|
||||||
srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
|
|
||||||
|
|
||||||
# check that for srn, srn_hypernet, idr we have self.num_passes=1
|
# check that for srn, srn_hypernet, idr we have self.num_passes=1
|
||||||
implicit_function_type = registry.get(
|
implicit_function_type = registry.get(
|
||||||
@ -734,7 +745,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
|
|||||||
if config is None:
|
if config is None:
|
||||||
raise ValueError(f"{config_name} not present")
|
raise ValueError(f"{config_name} not present")
|
||||||
implicit_functions_list = [
|
implicit_functions_list = [
|
||||||
ImplicitFunctionWrapper(implicit_function_type(**config))
|
ImplicitFunctionWrapper(implicit_function_type(**config, **extra_args))
|
||||||
for _ in range(self.num_passes)
|
for _ in range(self.num_passes)
|
||||||
]
|
]
|
||||||
return torch.nn.ModuleList(implicit_functions_list)
|
return torch.nn.ModuleList(implicit_functions_list)
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
from typing import Any, cast, Optional, Tuple
|
from typing import Any, cast, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from omegaconf import DictConfig
|
||||||
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
|
||||||
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
|
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
|
||||||
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
|
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
|
||||||
@ -327,6 +328,7 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
|
|||||||
@registry.register
|
@registry.register
|
||||||
# pyre-fixme[13]: Uninitialized attribute
|
# pyre-fixme[13]: Uninitialized attribute
|
||||||
class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||||
|
latent_dim: int = 0
|
||||||
raymarch_function: SRNRaymarchFunction
|
raymarch_function: SRNRaymarchFunction
|
||||||
pixel_generator: SRNPixelGenerator
|
pixel_generator: SRNPixelGenerator
|
||||||
|
|
||||||
@ -334,6 +336,17 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
|
def create_raymarch_function(self) -> None:
|
||||||
|
self.raymarch_function = SRNRaymarchFunction(
|
||||||
|
latent_dim=self.latent_dim,
|
||||||
|
# pyre-ignore[32]
|
||||||
|
**self.raymarch_function_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def raymarch_function_tweak_args(cls, type, args: DictConfig) -> None:
|
||||||
|
args.pop("latent_dim", None)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: RayBundle,
|
||||||
@ -371,6 +384,8 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
the cache.
|
the cache.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
latent_dim_hypernet: int = 0
|
||||||
|
latent_dim: int = 0
|
||||||
hypernet: SRNRaymarchHyperNet
|
hypernet: SRNRaymarchHyperNet
|
||||||
pixel_generator: SRNPixelGenerator
|
pixel_generator: SRNPixelGenerator
|
||||||
|
|
||||||
@ -378,6 +393,19 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
|
|
||||||
|
def create_hypernet(self) -> None:
|
||||||
|
self.hypernet = SRNRaymarchHyperNet(
|
||||||
|
latent_dim=self.latent_dim,
|
||||||
|
latent_dim_hypernet=self.latent_dim_hypernet,
|
||||||
|
# pyre-ignore[32]
|
||||||
|
**self.hypernet_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def hypernet_tweak_args(cls, type, args: DictConfig) -> None:
|
||||||
|
args.pop("latent_dim", None)
|
||||||
|
args.pop("latent_dim_hypernet", None)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
ray_bundle: RayBundle,
|
ray_bundle: RayBundle,
|
||||||
|
@ -103,7 +103,6 @@ view_pooler_args:
|
|||||||
weight_by_ray_angle_gamma: 1.0
|
weight_by_ray_angle_gamma: 1.0
|
||||||
min_ray_angle_weight: 0.1
|
min_ray_angle_weight: 0.1
|
||||||
implicit_function_IdrFeatureField_args:
|
implicit_function_IdrFeatureField_args:
|
||||||
feature_vector_size: 3
|
|
||||||
d_in: 3
|
d_in: 3
|
||||||
d_out: 1
|
d_out: 1
|
||||||
dims:
|
dims:
|
||||||
@ -121,6 +120,5 @@ implicit_function_IdrFeatureField_args:
|
|||||||
weight_norm: true
|
weight_norm: true
|
||||||
n_harmonic_functions_xyz: 1729
|
n_harmonic_functions_xyz: 1729
|
||||||
pooled_feature_dim: 0
|
pooled_feature_dim: 0
|
||||||
encoding_dim: 0
|
|
||||||
view_metrics_ViewMetrics_args: {}
|
view_metrics_ViewMetrics_args: {}
|
||||||
regularization_metrics_RegularizationMetrics_args: {}
|
regularization_metrics_RegularizationMetrics_args: {}
|
||||||
|
@ -55,9 +55,10 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
|
|||||||
def test_srn_hypernet_implicit_function(self):
|
def test_srn_hypernet_implicit_function(self):
|
||||||
# TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
|
# TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
|
||||||
latent_dim_hypernet = 39
|
latent_dim_hypernet = 39
|
||||||
hypernet_args = {"latent_dim_hypernet": latent_dim_hypernet}
|
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hypernet_args)
|
implicit_function = SRNHyperNetImplicitFunction(
|
||||||
|
latent_dim_hypernet=latent_dim_hypernet
|
||||||
|
)
|
||||||
implicit_function.to(device)
|
implicit_function.to(device)
|
||||||
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
|
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
|
||||||
bundle = self._get_bundle(device=device)
|
bundle = self._get_bundle(device=device)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user