mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	clean IF args
Summary: continued - avoid duplicate inputs Reviewed By: davnov134 Differential Revision: D38248827 fbshipit-source-id: 91ed398e304496a936f66e7a70ab3d189eeb5c70
This commit is contained in:
		
							parent
							
								
									078846d166
								
							
						
					
					
						commit
						46e82efb4e
					
				@ -43,7 +43,6 @@ model_factory_ImplicitronModelFactory_args:
 | 
			
		||||
      n_layers_xyz: 8
 | 
			
		||||
      append_xyz:
 | 
			
		||||
      - 5
 | 
			
		||||
      latent_dim: 0
 | 
			
		||||
    raysampler_AdaptiveRaySampler_args:
 | 
			
		||||
      n_rays_per_image_sampled_from_mask: 1024
 | 
			
		||||
      scene_extent: 8.0
 | 
			
		||||
 | 
			
		||||
@ -313,7 +313,6 @@ model_factory_ImplicitronModelFactory_args:
 | 
			
		||||
        - AVG
 | 
			
		||||
        - STD
 | 
			
		||||
    implicit_function_IdrFeatureField_args:
 | 
			
		||||
      feature_vector_size: 3
 | 
			
		||||
      d_in: 3
 | 
			
		||||
      d_out: 1
 | 
			
		||||
      dims:
 | 
			
		||||
@ -331,15 +330,12 @@ model_factory_ImplicitronModelFactory_args:
 | 
			
		||||
      weight_norm: true
 | 
			
		||||
      n_harmonic_functions_xyz: 0
 | 
			
		||||
      pooled_feature_dim: 0
 | 
			
		||||
      encoding_dim: 0
 | 
			
		||||
    implicit_function_NeRFormerImplicitFunction_args:
 | 
			
		||||
      n_harmonic_functions_xyz: 10
 | 
			
		||||
      n_harmonic_functions_dir: 4
 | 
			
		||||
      n_hidden_neurons_dir: 128
 | 
			
		||||
      latent_dim: 0
 | 
			
		||||
      input_xyz: true
 | 
			
		||||
      xyz_ray_dir_in_camera_coords: false
 | 
			
		||||
      color_dim: 3
 | 
			
		||||
      transformer_dim_down_factor: 2.0
 | 
			
		||||
      n_hidden_neurons_xyz: 80
 | 
			
		||||
      n_layers_xyz: 2
 | 
			
		||||
@ -349,10 +345,8 @@ model_factory_ImplicitronModelFactory_args:
 | 
			
		||||
      n_harmonic_functions_xyz: 10
 | 
			
		||||
      n_harmonic_functions_dir: 4
 | 
			
		||||
      n_hidden_neurons_dir: 128
 | 
			
		||||
      latent_dim: 0
 | 
			
		||||
      input_xyz: true
 | 
			
		||||
      xyz_ray_dir_in_camera_coords: false
 | 
			
		||||
      color_dim: 3
 | 
			
		||||
      transformer_dim_down_factor: 1.0
 | 
			
		||||
      n_hidden_neurons_xyz: 256
 | 
			
		||||
      n_layers_xyz: 8
 | 
			
		||||
@ -367,8 +361,6 @@ model_factory_ImplicitronModelFactory_args:
 | 
			
		||||
        n_layers_hypernet: 1
 | 
			
		||||
        in_features: 3
 | 
			
		||||
        out_features: 256
 | 
			
		||||
        latent_dim_hypernet: 0
 | 
			
		||||
        latent_dim: 0
 | 
			
		||||
        xyz_in_camera_coords: false
 | 
			
		||||
      pixel_generator_args:
 | 
			
		||||
        n_harmonic_functions: 4
 | 
			
		||||
@ -385,7 +377,6 @@ model_factory_ImplicitronModelFactory_args:
 | 
			
		||||
        n_layers: 2
 | 
			
		||||
        in_features: 3
 | 
			
		||||
        out_features: 256
 | 
			
		||||
        latent_dim: 0
 | 
			
		||||
        xyz_in_camera_coords: false
 | 
			
		||||
        raymarch_function: null
 | 
			
		||||
      pixel_generator_args:
 | 
			
		||||
 | 
			
		||||
@ -677,6 +677,18 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
			
		||||
        """
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def implicit_function_tweak_args(cls, type, args: DictConfig) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        We don't expose certain implicit_function fields because we want to set
 | 
			
		||||
        them based on other inputs.
 | 
			
		||||
        """
 | 
			
		||||
        args.pop("feature_vector_size", None)
 | 
			
		||||
        args.pop("encoding_dim", None)
 | 
			
		||||
        args.pop("latent_dim", None)
 | 
			
		||||
        args.pop("latent_dim_hypernet", None)
 | 
			
		||||
        args.pop("color_dim", None)
 | 
			
		||||
 | 
			
		||||
    def _construct_implicit_functions(self):
 | 
			
		||||
        """
 | 
			
		||||
        After run_auto_creation has been called, the arguments
 | 
			
		||||
@ -686,32 +698,31 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
			
		||||
        implicit function method. Then the required implicit
 | 
			
		||||
        function(s) are initialized.
 | 
			
		||||
        """
 | 
			
		||||
        # nerf preprocessing
 | 
			
		||||
        nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
 | 
			
		||||
        nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
 | 
			
		||||
        nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
 | 
			
		||||
            self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
 | 
			
		||||
        )
 | 
			
		||||
        nerf_args["color_dim"] = nerformer_args[
 | 
			
		||||
            "color_dim"
 | 
			
		||||
        ] = self.render_features_dimensions
 | 
			
		||||
        extra_args = {}
 | 
			
		||||
        if self.implicit_function_class_type in (
 | 
			
		||||
            "NeuralRadianceFieldImplicitFunction",
 | 
			
		||||
            "NeRFormerImplicitFunction",
 | 
			
		||||
        ):
 | 
			
		||||
            extra_args["latent_dim"] = (
 | 
			
		||||
                self._get_viewpooled_feature_dim()
 | 
			
		||||
                + self._get_global_encoder_encoding_dim()
 | 
			
		||||
            )
 | 
			
		||||
            extra_args["color_dim"] = self.render_features_dimensions
 | 
			
		||||
 | 
			
		||||
        # idr preprocessing
 | 
			
		||||
        idr = self.implicit_function_IdrFeatureField_args
 | 
			
		||||
        idr["feature_vector_size"] = self.render_features_dimensions
 | 
			
		||||
        idr["encoding_dim"] = self._get_global_encoder_encoding_dim()
 | 
			
		||||
        if self.implicit_function_class_type == "IdrFeatureField":
 | 
			
		||||
            extra_args["feature_vector_size"] = self.render_features_dimensions
 | 
			
		||||
            extra_args["encoding_dim"] = self._get_global_encoder_encoding_dim()
 | 
			
		||||
 | 
			
		||||
        # srn preprocessing
 | 
			
		||||
        srn = self.implicit_function_SRNImplicitFunction_args
 | 
			
		||||
        srn.raymarch_function_args.latent_dim = (
 | 
			
		||||
            self._get_viewpooled_feature_dim() + self._get_global_encoder_encoding_dim()
 | 
			
		||||
        )
 | 
			
		||||
        if self.implicit_function_class_type == "SRNImplicitFunction":
 | 
			
		||||
            extra_args["latent_dim"] = (
 | 
			
		||||
                self._get_viewpooled_feature_dim()
 | 
			
		||||
                + self._get_global_encoder_encoding_dim()
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        # srn_hypernet preprocessing
 | 
			
		||||
        srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
 | 
			
		||||
        srn_hypernet_args = srn_hypernet.hypernet_args
 | 
			
		||||
        srn_hypernet_args.latent_dim_hypernet = self._get_global_encoder_encoding_dim()
 | 
			
		||||
        srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
 | 
			
		||||
        if self.implicit_function_class_type == "SRNHyperNetImplicitFunction":
 | 
			
		||||
            extra_args["latent_dim"] = self._get_viewpooled_feature_dim()
 | 
			
		||||
            extra_args["latent_dim_hypernet"] = self._get_global_encoder_encoding_dim()
 | 
			
		||||
 | 
			
		||||
        # check that for srn, srn_hypernet, idr we have self.num_passes=1
 | 
			
		||||
        implicit_function_type = registry.get(
 | 
			
		||||
@ -734,7 +745,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
			
		||||
        if config is None:
 | 
			
		||||
            raise ValueError(f"{config_name} not present")
 | 
			
		||||
        implicit_functions_list = [
 | 
			
		||||
            ImplicitFunctionWrapper(implicit_function_type(**config))
 | 
			
		||||
            ImplicitFunctionWrapper(implicit_function_type(**config, **extra_args))
 | 
			
		||||
            for _ in range(self.num_passes)
 | 
			
		||||
        ]
 | 
			
		||||
        return torch.nn.ModuleList(implicit_functions_list)
 | 
			
		||||
 | 
			
		||||
@ -4,6 +4,7 @@
 | 
			
		||||
from typing import Any, cast, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from omegaconf import DictConfig
 | 
			
		||||
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
 | 
			
		||||
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
 | 
			
		||||
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
 | 
			
		||||
@ -327,6 +328,7 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
 | 
			
		||||
@registry.register
 | 
			
		||||
# pyre-fixme[13]: Uninitialized attribute
 | 
			
		||||
class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
    latent_dim: int = 0
 | 
			
		||||
    raymarch_function: SRNRaymarchFunction
 | 
			
		||||
    pixel_generator: SRNPixelGenerator
 | 
			
		||||
 | 
			
		||||
@ -334,6 +336,17 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
 | 
			
		||||
    def create_raymarch_function(self) -> None:
 | 
			
		||||
        self.raymarch_function = SRNRaymarchFunction(
 | 
			
		||||
            latent_dim=self.latent_dim,
 | 
			
		||||
            # pyre-ignore[32]
 | 
			
		||||
            **self.raymarch_function_args,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def raymarch_function_tweak_args(cls, type, args: DictConfig) -> None:
 | 
			
		||||
        args.pop("latent_dim", None)
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        ray_bundle: RayBundle,
 | 
			
		||||
@ -371,6 +384,8 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
    the cache.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    latent_dim_hypernet: int = 0
 | 
			
		||||
    latent_dim: int = 0
 | 
			
		||||
    hypernet: SRNRaymarchHyperNet
 | 
			
		||||
    pixel_generator: SRNPixelGenerator
 | 
			
		||||
 | 
			
		||||
@ -378,6 +393,19 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
 | 
			
		||||
    def create_hypernet(self) -> None:
 | 
			
		||||
        self.hypernet = SRNRaymarchHyperNet(
 | 
			
		||||
            latent_dim=self.latent_dim,
 | 
			
		||||
            latent_dim_hypernet=self.latent_dim_hypernet,
 | 
			
		||||
            # pyre-ignore[32]
 | 
			
		||||
            **self.hypernet_args,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def hypernet_tweak_args(cls, type, args: DictConfig) -> None:
 | 
			
		||||
        args.pop("latent_dim", None)
 | 
			
		||||
        args.pop("latent_dim_hypernet", None)
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
        self,
 | 
			
		||||
        ray_bundle: RayBundle,
 | 
			
		||||
 | 
			
		||||
@ -103,7 +103,6 @@ view_pooler_args:
 | 
			
		||||
    weight_by_ray_angle_gamma: 1.0
 | 
			
		||||
    min_ray_angle_weight: 0.1
 | 
			
		||||
implicit_function_IdrFeatureField_args:
 | 
			
		||||
  feature_vector_size: 3
 | 
			
		||||
  d_in: 3
 | 
			
		||||
  d_out: 1
 | 
			
		||||
  dims:
 | 
			
		||||
@ -121,6 +120,5 @@ implicit_function_IdrFeatureField_args:
 | 
			
		||||
  weight_norm: true
 | 
			
		||||
  n_harmonic_functions_xyz: 1729
 | 
			
		||||
  pooled_feature_dim: 0
 | 
			
		||||
  encoding_dim: 0
 | 
			
		||||
view_metrics_ViewMetrics_args: {}
 | 
			
		||||
regularization_metrics_RegularizationMetrics_args: {}
 | 
			
		||||
 | 
			
		||||
@ -55,9 +55,10 @@ class TestSRN(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
    def test_srn_hypernet_implicit_function(self):
 | 
			
		||||
        # TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
 | 
			
		||||
        latent_dim_hypernet = 39
 | 
			
		||||
        hypernet_args = {"latent_dim_hypernet": latent_dim_hypernet}
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hypernet_args)
 | 
			
		||||
        implicit_function = SRNHyperNetImplicitFunction(
 | 
			
		||||
            latent_dim_hypernet=latent_dim_hypernet
 | 
			
		||||
        )
 | 
			
		||||
        implicit_function.to(device)
 | 
			
		||||
        global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
 | 
			
		||||
        bundle = self._get_bundle(device=device)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user