Add the OverfitModel

Summary: Introduces the OverfitModel for NeRF-style training that overfits to a single scene. It is a special case of GenericModel, disentangled from it to ease usage.

## General modification

1. Modularize a minimal GenericModel in order to introduce OverfitModel.
2. Introduce OverfitModel and ensure through unit testing that it behaves like GenericModel.

## Modularization

The following methods have been extracted from GenericModel to allow sharing them with OverfitModel:

- get_objective is now a call to weighted_sum_losses
- log_loss_weights
- prepare_inputs

The generic methods have been moved to a utils.py file, which simplifies the code enough to introduce OverfitModel. Private helpers like chunk_generator are now public and can be used by OverfitModel.

Reviewed By: shapovalov

Differential Revision: D43771992

fbshipit-source-id: 6102aeb21c7fdd56aa2ff9cd1dd23fd9fbf26315
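For orientation, the core of the refactor is small: the objective computation that used to live in `GenericModel._get_objective` is now a free function. Below is a minimal sketch of `weighted_sum_losses`, paraphrased from the removed method visible further down in this diff; treat it as illustrative rather than the verbatim `utils.py` source.

```
# Illustrative sketch of pytorch3d.implicitron.models.utils.weighted_sum_losses,
# reconstructed from the _get_objective body removed in this commit.
import warnings
from typing import Dict, Optional

import torch


def weighted_sum_losses(
    preds: Dict[str, torch.Tensor], loss_weights: Dict[str, float]
) -> Optional[torch.Tensor]:
    """Dot product of the individual losses with their configured weights."""
    losses_weighted = [
        preds[k] * float(w)
        for k, w in loss_weights.items()
        if k in preds and w != 0.0
    ]
    if len(losses_weighted) == 0:
        # no weighted loss is available in preds
        warnings.warn("No main objective found.")
        return None
    loss = sum(losses_weighted)
    assert torch.is_tensor(loss)
    return loss
```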
This commit is contained in:

parent 7d8b029aae
commit 813e941de5
@@ -248,7 +248,7 @@ The main object for this trainer loop is `Experiment`. It has four top-level rep
* `data_source`: This is a `DataSourceBase` which defaults to `ImplicitronDataSource`.
It constructs the data sets and dataloaders.
* `model_factory`: This is a `ModelFactoryBase` which defaults to `ImplicitronModelFactory`.
-It constructs the model, which is usually an instance of implicitron's main `GenericModel` class, and can load its weights from a checkpoint.
+It constructs the model, which is usually an instance of `OverfitModel` (for NeRF-style training that overfits to a single scene) or `GenericModel` (which generalizes to multiple scenes by NeRFormer-style conditioning on other scene views), and can load its weights from a checkpoint (see the configuration sketch after this list).
* `optimizer_factory`: This is an `OptimizerFactoryBase` which defaults to `ImplicitronOptimizerFactory`.
It constructs the optimizer and can load its weights from a checkpoint.
* `training_loop`: This is a `TrainingLoopBase` which defaults to `ImplicitronTrainingLoop` and defines the main training loop.
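Because `model_factory` selects the model class by name, switching a run to `OverfitModel` is a one-key config change. A minimal sketch (not part of this commit) that loads the `overfit_base.yaml` added below and checks the relevant key:

```
# Illustrative: load one of the configs added in this commit and verify the
# key the factory uses to pick the model class. OmegaConf only; no training.
from omegaconf import OmegaConf

cfg = OmegaConf.load("projects/implicitron_trainer/configs/overfit_base.yaml")
factory_args = cfg.model_factory_ImplicitronModelFactory_args
assert factory_args.model_class_type == "OverfitModel"

# Switching back to the multi-scene model is the same one-key change:
cfg = OmegaConf.merge(
    cfg,
    OmegaConf.create(
        {"model_factory_ImplicitronModelFactory_args": {"model_class_type": "GenericModel"}}
    ),
)
```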
@@ -292,6 +292,43 @@ model_GenericModel_args: GenericModel
        ╘== ReductionFeatureAggregator
```

Here is the class structure of OverfitModel:

```
model_OverfitModel_args: OverfitModel
└-- raysampler_*_args: RaySampler
    ╘== AdaptiveRaysampler
    ╘== NearFarRaysampler
└-- renderer_*_args: BaseRenderer
    ╘== MultiPassEmissionAbsorptionRenderer
    ╘== LSTMRenderer
    ╘== SignedDistanceFunctionRenderer
        └-- ray_tracer_args: RayTracing
        └-- ray_normal_coloring_network_args: RayNormalColoringNetwork
└-- implicit_function_*_args: ImplicitFunctionBase
    ╘== NeuralRadianceFieldImplicitFunction
    ╘== SRNImplicitFunction
        └-- raymarch_function_args: SRNRaymarchFunction
        └-- pixel_generator_args: SRNPixelGenerator
    ╘== SRNHyperNetImplicitFunction
        └-- hypernet_args: SRNRaymarchHyperNet
        └-- pixel_generator_args: SRNPixelGenerator
    ╘== IdrFeatureField
└-- coarse_implicit_function_*_args: ImplicitFunctionBase
    ╘== NeuralRadianceFieldImplicitFunction
    ╘== SRNImplicitFunction
        └-- raymarch_function_args: SRNRaymarchFunction
        └-- pixel_generator_args: SRNPixelGenerator
    ╘== SRNHyperNetImplicitFunction
        └-- hypernet_args: SRNRaymarchHyperNet
        └-- pixel_generator_args: SRNPixelGenerator
    ╘== IdrFeatureField
```

OverfitModel has been introduced as a simple class that disentangles NeRFs following the overfit pattern from GenericModel.

Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
`tests/experiment.yaml` shows every possible option if you have no user-defined classes.
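OverfitModel can also be built directly in Python. A minimal sketch, assuming the import path added by this commit and the standard implicitron `expand_args_fields` pattern; the keyword arguments mirror config keys shown above, and the defaults come from the `tests/experiment.yaml` dump below:

```
# Illustrative: construct an OverfitModel directly with its defaults.
from pytorch3d.implicitron.models.overfit_model import OverfitModel
from pytorch3d.implicitron.tools.config import expand_args_fields

# Populate the *_args dataclass fields of all replaceable sub-modules.
expand_args_fields(OverfitModel)

model = OverfitModel(
    render_image_height=400,
    render_image_width=400,
    raysampler_class_type="AdaptiveRaySampler",
    renderer_class_type="MultiPassEmissionAbsorptionRenderer",
)
print(sum(p.numel() for p in model.parameters()))  # inspect the model size
```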
projects/implicitron_trainer/configs/overfit_base.yaml (new file, 79 lines)
@@ -0,0 +1,79 @@
defaults:
- default_config
- _self_
exp_dir: ./data/exps/overfit_base/
training_loop_ImplicitronTrainingLoop_args:
  visdom_port: 8097
  visualize_interval: 0
  max_epochs: 1000
data_source_ImplicitronDataSource_args:
  data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
  dataset_map_provider_class_type: JsonIndexDatasetMapProvider
  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
    dataset_length_train: 1000
    dataset_length_val: 1
    num_workers: 8
  dataset_map_provider_JsonIndexDatasetMapProvider_args:
    dataset_root: ${oc.env:CO3D_DATASET_ROOT}
    n_frames_per_sequence: -1
    test_on_train: true
    test_restrict_sequence_id: 0
    dataset_JsonIndexDataset_args:
      load_point_clouds: false
      mask_depths: false
      mask_images: false
model_factory_ImplicitronModelFactory_args:
  model_class_type: "OverfitModel"
  model_OverfitModel_args:
    loss_weights:
      loss_mask_bce: 1.0
      loss_prev_stage_mask_bce: 1.0
      loss_autodecoder_norm: 0.01
      loss_rgb_mse: 1.0
      loss_prev_stage_rgb_mse: 1.0
    output_rasterized_mc: false
    chunk_size_grid: 102400
    render_image_height: 400
    render_image_width: 400
    share_implicit_function_across_passes: false
    implicit_function_class_type: "NeuralRadianceFieldImplicitFunction"
    implicit_function_NeuralRadianceFieldImplicitFunction_args:
      n_harmonic_functions_xyz: 10
      n_harmonic_functions_dir: 4
      n_hidden_neurons_xyz: 256
      n_hidden_neurons_dir: 128
      n_layers_xyz: 8
      append_xyz:
      - 5
    coarse_implicit_function_class_type: "NeuralRadianceFieldImplicitFunction"
    coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args:
      n_harmonic_functions_xyz: 10
      n_harmonic_functions_dir: 4
      n_hidden_neurons_xyz: 256
      n_hidden_neurons_dir: 128
      n_layers_xyz: 8
      append_xyz:
      - 5
    raysampler_AdaptiveRaySampler_args:
      n_rays_per_image_sampled_from_mask: 1024
      scene_extent: 8.0
      n_pts_per_ray_training: 64
      n_pts_per_ray_evaluation: 64
      stratified_point_sampling_training: true
      stratified_point_sampling_evaluation: false
    renderer_MultiPassEmissionAbsorptionRenderer_args:
      n_pts_per_ray_fine_training: 64
      n_pts_per_ray_fine_evaluation: 64
      append_coarse_samples_to_fine: true
      density_noise_std_train: 1.0
optimizer_factory_ImplicitronOptimizerFactory_args:
  breed: Adam
  weight_decay: 0.0
  lr_policy: MultiStepLR
  multistep_lr_milestones: []
  lr: 0.0005
  gamma: 0.1
  momentum: 0.9
  betas:
  - 0.9
  - 0.999
projects/implicitron_trainer/configs/overfit_singleseq_base.yaml (new file, 42 lines; the file name was not captured in this page and is inferred from the defaults of the next config)

@@ -0,0 +1,42 @@
defaults:
- overfit_base
- _self_
data_source_ImplicitronDataSource_args:
  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
    batch_size: 1
    dataset_length_train: 1000
    dataset_length_val: 1
    num_workers: 8
  dataset_map_provider_JsonIndexDatasetMapProvider_args:
    assert_single_seq: true
    n_frames_per_sequence: -1
    test_restrict_sequence_id: 0
    test_on_train: false
model_factory_ImplicitronModelFactory_args:
  model_class_type: "OverfitModel"
  model_OverfitModel_args:
    render_image_height: 800
    render_image_width: 800
    log_vars:
    - loss_rgb_psnr_fg
    - loss_rgb_psnr
    - loss_eikonal
    - loss_prev_stage_rgb_psnr
    - loss_mask_bce
    - loss_prev_stage_mask_bce
    - loss_rgb_mse
    - loss_prev_stage_rgb_mse
    - loss_depth_abs
    - loss_depth_abs_fg
    - loss_kl
    - loss_mask_neg_iou
    - objective
    - epoch
    - sec/it
optimizer_factory_ImplicitronOptimizerFactory_args:
  lr: 0.0005
  multistep_lr_milestones:
  - 200
  - 300
training_loop_ImplicitronTrainingLoop_args:
  max_epochs: 400
(new file, 56 lines; file name not captured in this page; a Blender single-sequence overfit config extending overfit_singleseq_base)

@@ -0,0 +1,56 @@
defaults:
- overfit_singleseq_base
- _self_
exp_dir: "./data/overfit_nerf_blender_repro/${oc.env:BLENDER_SINGLESEQ_CLASS}"
data_source_ImplicitronDataSource_args:
  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
    dataset_length_train: 100
  dataset_map_provider_class_type: BlenderDatasetMapProvider
  dataset_map_provider_BlenderDatasetMapProvider_args:
    base_dir: ${oc.env:BLENDER_DATASET_ROOT}/${oc.env:BLENDER_SINGLESEQ_CLASS}
    n_known_frames_for_test: null
    object_name: ${oc.env:BLENDER_SINGLESEQ_CLASS}
    path_manager_factory_class_type: PathManagerFactory
    path_manager_factory_PathManagerFactory_args:
      silence_logs: true

model_factory_ImplicitronModelFactory_args:
  model_class_type: "OverfitModel"
  model_OverfitModel_args:
    mask_images: false
    raysampler_class_type: AdaptiveRaySampler
    raysampler_AdaptiveRaySampler_args:
      n_pts_per_ray_training: 64
      n_pts_per_ray_evaluation: 64
      n_rays_per_image_sampled_from_mask: 4096
      stratified_point_sampling_training: true
      stratified_point_sampling_evaluation: false
      scene_extent: 2.0
      scene_center:
      - 0.0
      - 0.0
      - 0.0
    renderer_MultiPassEmissionAbsorptionRenderer_args:
      density_noise_std_train: 0.0
      n_pts_per_ray_fine_training: 128
      n_pts_per_ray_fine_evaluation: 128
      raymarcher_EmissionAbsorptionRaymarcher_args:
        blend_output: false
    loss_weights:
      loss_rgb_mse: 1.0
      loss_prev_stage_rgb_mse: 1.0
      loss_mask_bce: 0.0
      loss_prev_stage_mask_bce: 0.0
      loss_autodecoder_norm: 0.00

optimizer_factory_ImplicitronOptimizerFactory_args:
  exponential_lr_step_size: 3001
  lr_policy: LinearExponential
  linear_exponential_lr_milestone: 200

training_loop_ImplicitronTrainingLoop_args:
  max_epochs: 6000
  metric_print_interval: 10
  store_checkpoints_purge: 3
  test_when_finished: true
  validation_interval: 100
@@ -59,7 +59,7 @@ from pytorch3d.implicitron.dataset.data_source import (
    DataSourceBase,
    ImplicitronDataSource,
)
-from pytorch3d.implicitron.models.generic_model import ImplicitronModelBase
+from pytorch3d.implicitron.models.base_model import ImplicitronModelBase

from pytorch3d.implicitron.models.renderer.multipass_ea import (
    MultiPassEmissionAbsorptionRenderer,

@@ -561,6 +561,623 @@ model_factory_ImplicitronModelFactory_args:
          use_xavier_init: true
    view_metrics_ViewMetrics_args: {}
    regularization_metrics_RegularizationMetrics_args: {}
  model_OverfitModel_args:
    log_vars:
    - loss_rgb_psnr_fg
    - loss_rgb_psnr
    - loss_rgb_mse
    - loss_rgb_huber
    - loss_depth_abs
    - loss_depth_abs_fg
    - loss_mask_neg_iou
    - loss_mask_bce
    - loss_mask_beta_prior
    - loss_eikonal
    - loss_density_tv
    - loss_depth_neg_penalty
    - loss_autodecoder_norm
    - loss_prev_stage_rgb_mse
    - loss_prev_stage_rgb_psnr_fg
    - loss_prev_stage_rgb_psnr
    - loss_prev_stage_mask_bce
    - objective
    - epoch
    - sec/it
    mask_images: true
    mask_depths: true
    render_image_width: 400
    render_image_height: 400
    mask_threshold: 0.5
    output_rasterized_mc: false
    bg_color:
    - 0.0
    - 0.0
    - 0.0
    chunk_size_grid: 4096
    render_features_dimensions: 3
    tqdm_trigger_threshold: 16
    n_train_target_views: 1
    sampling_mode_training: mask_sample
    sampling_mode_evaluation: full_grid
    global_encoder_class_type: null
    raysampler_class_type: AdaptiveRaySampler
    renderer_class_type: MultiPassEmissionAbsorptionRenderer
    share_implicit_function_across_passes: false
    implicit_function_class_type: NeuralRadianceFieldImplicitFunction
    coarse_implicit_function_class_type: null
    view_metrics_class_type: ViewMetrics
    regularization_metrics_class_type: RegularizationMetrics
    loss_weights:
      loss_rgb_mse: 1.0
      loss_prev_stage_rgb_mse: 1.0
      loss_mask_bce: 0.0
      loss_prev_stage_mask_bce: 0.0
    global_encoder_HarmonicTimeEncoder_args:
      n_harmonic_functions: 10
      append_input: true
      time_divisor: 1.0
    global_encoder_SequenceAutodecoder_args:
      autodecoder_args:
        encoding_dim: 0
        n_instances: 1
        init_scale: 1.0
        ignore_input: false
    raysampler_AdaptiveRaySampler_args:
      n_pts_per_ray_training: 64
      n_pts_per_ray_evaluation: 64
      n_rays_per_image_sampled_from_mask: 1024
      n_rays_total_training: null
      stratified_point_sampling_training: true
      stratified_point_sampling_evaluation: false
      scene_extent: 8.0
      scene_center:
      - 0.0
      - 0.0
      - 0.0
    raysampler_NearFarRaySampler_args:
      n_pts_per_ray_training: 64
      n_pts_per_ray_evaluation: 64
      n_rays_per_image_sampled_from_mask: 1024
      n_rays_total_training: null
      stratified_point_sampling_training: true
      stratified_point_sampling_evaluation: false
      min_depth: 0.1
      max_depth: 8.0
    renderer_LSTMRenderer_args:
      num_raymarch_steps: 10
      init_depth: 17.0
      init_depth_noise_std: 0.0005
      hidden_size: 16
      n_feature_channels: 256
      bg_color: null
      verbose: false
    renderer_MultiPassEmissionAbsorptionRenderer_args:
      raymarcher_class_type: EmissionAbsorptionRaymarcher
      n_pts_per_ray_fine_training: 64
      n_pts_per_ray_fine_evaluation: 64
      stratified_sampling_coarse_training: true
      stratified_sampling_coarse_evaluation: false
      append_coarse_samples_to_fine: true
      density_noise_std_train: 0.0
      return_weights: false
      raymarcher_CumsumRaymarcher_args:
        surface_thickness: 1
        bg_color:
        - 0.0
        replicate_last_interval: false
        background_opacity: 0.0
        density_relu: true
        blend_output: false
      raymarcher_EmissionAbsorptionRaymarcher_args:
        surface_thickness: 1
        bg_color:
        - 0.0
        replicate_last_interval: false
        background_opacity: 10000000000.0
        density_relu: true
        blend_output: false
    renderer_SignedDistanceFunctionRenderer_args:
      ray_normal_coloring_network_args:
        feature_vector_size: 3
        mode: idr
        d_in: 9
        d_out: 3
        dims:
        - 512
        - 512
        - 512
        - 512
        weight_norm: true
        n_harmonic_functions_dir: 0
        pooled_feature_dim: 0
      bg_color:
      - 0.0
      soft_mask_alpha: 50.0
      ray_tracer_args:
        sdf_threshold: 5.0e-05
        line_search_step: 0.5
        line_step_iters: 1
        sphere_tracing_iters: 10
        n_steps: 100
        n_secant_steps: 8
    implicit_function_IdrFeatureField_args:
      d_in: 3
      d_out: 1
      dims:
      - 512
      - 512
      - 512
      - 512
      - 512
      - 512
      - 512
      - 512
      geometric_init: true
      bias: 1.0
      skip_in: []
      weight_norm: true
      n_harmonic_functions_xyz: 0
      pooled_feature_dim: 0
    implicit_function_NeRFormerImplicitFunction_args:
      n_harmonic_functions_xyz: 10
      n_harmonic_functions_dir: 4
      n_hidden_neurons_dir: 128
      input_xyz: true
      xyz_ray_dir_in_camera_coords: false
      transformer_dim_down_factor: 2.0
      n_hidden_neurons_xyz: 80
      n_layers_xyz: 2
      append_xyz:
      - 1
    implicit_function_NeuralRadianceFieldImplicitFunction_args:
      n_harmonic_functions_xyz: 10
      n_harmonic_functions_dir: 4
      n_hidden_neurons_dir: 128
      input_xyz: true
      xyz_ray_dir_in_camera_coords: false
      transformer_dim_down_factor: 1.0
      n_hidden_neurons_xyz: 256
      n_layers_xyz: 8
      append_xyz:
      - 5
    implicit_function_SRNHyperNetImplicitFunction_args:
      latent_dim_hypernet: 0
      hypernet_args:
        n_harmonic_functions: 3
        n_hidden_units: 256
        n_layers: 2
        n_hidden_units_hypernet: 256
        n_layers_hypernet: 1
        in_features: 3
        out_features: 256
        xyz_in_camera_coords: false
      pixel_generator_args:
        n_harmonic_functions: 4
        n_hidden_units: 256
        n_hidden_units_color: 128
        n_layers: 2
        in_features: 256
        out_features: 3
        ray_dir_in_camera_coords: false
    implicit_function_SRNImplicitFunction_args:
      raymarch_function_args:
        n_harmonic_functions: 3
        n_hidden_units: 256
        n_layers: 2
        in_features: 3
        out_features: 256
        xyz_in_camera_coords: false
        raymarch_function: null
      pixel_generator_args:
        n_harmonic_functions: 4
        n_hidden_units: 256
        n_hidden_units_color: 128
        n_layers: 2
        in_features: 256
        out_features: 3
        ray_dir_in_camera_coords: false
    implicit_function_VoxelGridImplicitFunction_args:
      harmonic_embedder_xyz_density_args:
        n_harmonic_functions: 6
        omega_0: 1.0
        logspace: true
        append_input: true
      harmonic_embedder_xyz_color_args:
        n_harmonic_functions: 6
        omega_0: 1.0
        logspace: true
        append_input: true
      harmonic_embedder_dir_color_args:
        n_harmonic_functions: 6
        omega_0: 1.0
        logspace: true
        append_input: true
      decoder_density_class_type: MLPDecoder
      decoder_color_class_type: MLPDecoder
      use_multiple_streams: true
      xyz_ray_dir_in_camera_coords: false
      scaffold_calculating_epochs: []
      scaffold_resolution:
      - 128
      - 128
      - 128
      scaffold_empty_space_threshold: 0.001
      scaffold_occupancy_chunk_size: -1
      scaffold_max_pool_kernel_size: 3
      scaffold_filter_points: true
      volume_cropping_epochs: []
      voxel_grid_density_args:
        voxel_grid_class_type: FullResolutionVoxelGrid
        extents:
        - 2.0
        - 2.0
        - 2.0
        translation:
        - 0.0
        - 0.0
        - 0.0
        init_std: 0.1
        init_mean: 0.0
        hold_voxel_grid_as_parameters: true
        param_groups: {}
        voxel_grid_CPFactorizedVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
          n_components: 24
          basis_matrix: true
        voxel_grid_FullResolutionVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
        voxel_grid_VMFactorizedVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
          n_components: null
          distribution_of_components: null
          basis_matrix: true
      voxel_grid_color_args:
        voxel_grid_class_type: FullResolutionVoxelGrid
        extents:
        - 2.0
        - 2.0
        - 2.0
        translation:
        - 0.0
        - 0.0
        - 0.0
        init_std: 0.1
        init_mean: 0.0
        hold_voxel_grid_as_parameters: true
        param_groups: {}
        voxel_grid_CPFactorizedVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
          n_components: 24
          basis_matrix: true
        voxel_grid_FullResolutionVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
        voxel_grid_VMFactorizedVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
          n_components: null
          distribution_of_components: null
          basis_matrix: true
      decoder_density_ElementwiseDecoder_args:
        scale: 1.0
        shift: 0.0
        operation: IDENTITY
      decoder_density_MLPDecoder_args:
        param_groups: {}
        network_args:
          n_layers: 8
          output_dim: 256
          skip_dim: 39
          hidden_dim: 256
          input_skips:
          - 5
          skip_affine_trans: false
          last_layer_bias_init: null
          last_activation: RELU
          use_xavier_init: true
      decoder_color_ElementwiseDecoder_args:
        scale: 1.0
        shift: 0.0
        operation: IDENTITY
      decoder_color_MLPDecoder_args:
        param_groups: {}
        network_args:
          n_layers: 8
          output_dim: 256
          skip_dim: 39
          hidden_dim: 256
          input_skips:
          - 5
          skip_affine_trans: false
          last_layer_bias_init: null
          last_activation: RELU
          use_xavier_init: true
    coarse_implicit_function_IdrFeatureField_args:
      d_in: 3
      d_out: 1
      dims:
      - 512
      - 512
      - 512
      - 512
      - 512
      - 512
      - 512
      - 512
      geometric_init: true
      bias: 1.0
      skip_in: []
      weight_norm: true
      n_harmonic_functions_xyz: 0
      pooled_feature_dim: 0
    coarse_implicit_function_NeRFormerImplicitFunction_args:
      n_harmonic_functions_xyz: 10
      n_harmonic_functions_dir: 4
      n_hidden_neurons_dir: 128
      input_xyz: true
      xyz_ray_dir_in_camera_coords: false
      transformer_dim_down_factor: 2.0
      n_hidden_neurons_xyz: 80
      n_layers_xyz: 2
      append_xyz:
      - 1
    coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args:
      n_harmonic_functions_xyz: 10
      n_harmonic_functions_dir: 4
      n_hidden_neurons_dir: 128
      input_xyz: true
      xyz_ray_dir_in_camera_coords: false
      transformer_dim_down_factor: 1.0
      n_hidden_neurons_xyz: 256
      n_layers_xyz: 8
      append_xyz:
      - 5
    coarse_implicit_function_SRNHyperNetImplicitFunction_args:
      latent_dim_hypernet: 0
      hypernet_args:
        n_harmonic_functions: 3
        n_hidden_units: 256
        n_layers: 2
        n_hidden_units_hypernet: 256
        n_layers_hypernet: 1
        in_features: 3
        out_features: 256
        xyz_in_camera_coords: false
      pixel_generator_args:
        n_harmonic_functions: 4
        n_hidden_units: 256
        n_hidden_units_color: 128
        n_layers: 2
        in_features: 256
        out_features: 3
        ray_dir_in_camera_coords: false
    coarse_implicit_function_SRNImplicitFunction_args:
      raymarch_function_args:
        n_harmonic_functions: 3
        n_hidden_units: 256
        n_layers: 2
        in_features: 3
        out_features: 256
        xyz_in_camera_coords: false
        raymarch_function: null
      pixel_generator_args:
        n_harmonic_functions: 4
        n_hidden_units: 256
        n_hidden_units_color: 128
        n_layers: 2
        in_features: 256
        out_features: 3
        ray_dir_in_camera_coords: false
    coarse_implicit_function_VoxelGridImplicitFunction_args:
      harmonic_embedder_xyz_density_args:
        n_harmonic_functions: 6
        omega_0: 1.0
        logspace: true
        append_input: true
      harmonic_embedder_xyz_color_args:
        n_harmonic_functions: 6
        omega_0: 1.0
        logspace: true
        append_input: true
      harmonic_embedder_dir_color_args:
        n_harmonic_functions: 6
        omega_0: 1.0
        logspace: true
        append_input: true
      decoder_density_class_type: MLPDecoder
      decoder_color_class_type: MLPDecoder
      use_multiple_streams: true
      xyz_ray_dir_in_camera_coords: false
      scaffold_calculating_epochs: []
      scaffold_resolution:
      - 128
      - 128
      - 128
      scaffold_empty_space_threshold: 0.001
      scaffold_occupancy_chunk_size: -1
      scaffold_max_pool_kernel_size: 3
      scaffold_filter_points: true
      volume_cropping_epochs: []
      voxel_grid_density_args:
        voxel_grid_class_type: FullResolutionVoxelGrid
        extents:
        - 2.0
        - 2.0
        - 2.0
        translation:
        - 0.0
        - 0.0
        - 0.0
        init_std: 0.1
        init_mean: 0.0
        hold_voxel_grid_as_parameters: true
        param_groups: {}
        voxel_grid_CPFactorizedVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
          n_components: 24
          basis_matrix: true
        voxel_grid_FullResolutionVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
        voxel_grid_VMFactorizedVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
          n_components: null
          distribution_of_components: null
          basis_matrix: true
      voxel_grid_color_args:
        voxel_grid_class_type: FullResolutionVoxelGrid
        extents:
        - 2.0
        - 2.0
        - 2.0
        translation:
        - 0.0
        - 0.0
        - 0.0
        init_std: 0.1
        init_mean: 0.0
        hold_voxel_grid_as_parameters: true
        param_groups: {}
        voxel_grid_CPFactorizedVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
          n_components: 24
          basis_matrix: true
        voxel_grid_FullResolutionVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
        voxel_grid_VMFactorizedVoxelGrid_args:
          align_corners: true
          padding: zeros
          mode: bilinear
          n_features: 1
          resolution_changes:
            0:
            - 128
            - 128
            - 128
          n_components: null
          distribution_of_components: null
          basis_matrix: true
      decoder_density_ElementwiseDecoder_args:
        scale: 1.0
        shift: 0.0
        operation: IDENTITY
      decoder_density_MLPDecoder_args:
        param_groups: {}
        network_args:
          n_layers: 8
          output_dim: 256
          skip_dim: 39
          hidden_dim: 256
          input_skips:
          - 5
          skip_affine_trans: false
          last_layer_bias_init: null
          last_activation: RELU
          use_xavier_init: true
      decoder_color_ElementwiseDecoder_args:
        scale: 1.0
        shift: 0.0
        operation: IDENTITY
      decoder_color_MLPDecoder_args:
        param_groups: {}
        network_args:
          n_layers: 8
          output_dim: 256
          skip_dim: 39
          hidden_dim: 256
          input_skips:
          - 5
          skip_affine_trans: false
          last_layer_bias_init: null
          last_activation: RELU
          use_xavier_init: true
    view_metrics_ViewMetrics_args: {}
    regularization_metrics_RegularizationMetrics_args: {}
optimizer_factory_ImplicitronOptimizerFactory_args:
  betas:
  - 0.9

@@ -141,7 +141,11 @@ class TestExperiment(unittest.TestCase):
        # Check that all the pre-prepared configs are valid.
        config_files = []

-        for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"):
+        for pattern in (
+            "repro_singleseq*.yaml",
+            "repro_multiseq*.yaml",
+            "overfit_singleseq*.yaml",
+        ):
            config_files.extend(
                [
                    f

@@ -3,3 +3,8 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+# Allows the models to be registered,
+# see: pytorch3d.implicitron.tools.config.registry:register
+from pytorch3d.implicitron.models.generic_model import GenericModel
+from pytorch3d.implicitron.models.overfit_model import OverfitModel
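These imports matter because implicitron's registry is populated at import time. A minimal sketch of the lookup they enable; the `registry.get` signature mirrors its use in `generic_model.py` below:

```
# Illustrative: importing the models package runs the registration above, so
# the factory can afterwards resolve the class by the name used in configs.
from pytorch3d.implicitron.models import OverfitModel  # triggers registration
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools.config import registry

model_cls = registry.get(ImplicitronModelBase, "OverfitModel")
assert model_cls is OverfitModel
```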
@@ -8,11 +8,11 @@ from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import torch

+from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import ReplaceableBase
from pytorch3d.renderer.cameras import CamerasBase

-from .renderer.base import EvaluationMode


@dataclass
class ImplicitronRender:

@@ -9,14 +9,11 @@
# which are part of implicitron. They ensure that the registry is prepopulated.

import logging
-import warnings
from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
import tqdm
-from omegaconf import DictConfig
-from pytorch3d.common.compat import prod

from pytorch3d.implicitron.models.base_model import (
    ImplicitronModelBase,
@@ -33,11 +30,9 @@ from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
)
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (  # noqa
    NeRFormerImplicitFunction,
    NeuralRadianceFieldImplicitFunction,
)
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (  # noqa
    SRNHyperNetImplicitFunction,
    SRNImplicitFunction,
)
from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (  # noqa
    VoxelGridImplicitFunction,
@@ -63,8 +58,16 @@ from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
from pytorch3d.implicitron.models.renderer.sdf_renderer import (  # noqa
    SignedDistanceFunctionRenderer,
)

+from pytorch3d.implicitron.models.utils import (
+    apply_chunked,
+    chunk_generator,
+    log_loss_weights,
+    preprocess_input,
+    weighted_sum_losses,
+)
from pytorch3d.implicitron.models.view_pooler.view_pooler import ViewPooler
-from pytorch3d.implicitron.tools import image_utils, vis_utils
+from pytorch3d.implicitron.tools import vis_utils
from pytorch3d.implicitron.tools.config import (
    expand_args_fields,
    registry,
@@ -72,7 +75,6 @@ from pytorch3d.implicitron.tools.config import (
)

from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle
-from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase

@@ -323,7 +325,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13

        self._implicit_functions = self._construct_implicit_functions()

-        self.log_loss_weights()
+        log_loss_weights(self.loss_weights, logger)

    def forward(
        self,
@@ -367,8 +369,14 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
            preds: A dictionary containing all outputs of the forward pass including the
                rendered images, depths, masks, losses and other metrics.
        """
-        image_rgb, fg_probability, depth_map = self._preprocess_input(
-            image_rgb, fg_probability, depth_map
+        image_rgb, fg_probability, depth_map = preprocess_input(
+            image_rgb,
+            fg_probability,
+            depth_map,
+            self.mask_images,
+            self.mask_depths,
+            self.mask_threshold,
+            self.bg_color,
        )

        # Obtain the batch size from the camera as this is the only required input.
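The masking options are now passed explicitly because `preprocess_input` is a free function. A condensed sketch of its likely shape, assembled from the removed `GenericModel._preprocess_input` at the end of this diff (warnings and the unbatched-input check trimmed; not the verbatim `utils.py` body):

```
# Condensed from the removed _preprocess_input shown below.
from typing import Optional, Tuple

import torch
from pytorch3d.implicitron.tools import image_utils


def preprocess_input(
    image_rgb: Optional[torch.Tensor],
    fg_probability: Optional[torch.Tensor],
    depth_map: Optional[torch.Tensor],
    mask_images: bool,
    mask_depths: bool,
    mask_threshold: float,
    bg_color: Tuple[float, float, float],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    fg_mask = fg_probability
    if fg_mask is not None and mask_threshold > 0.0:
        # binarize the soft foreground probability
        fg_mask = (fg_mask >= mask_threshold).type_as(fg_mask)
    if mask_images and fg_mask is not None and image_rgb is not None:
        # recolor background pixels with the configured bg_color
        image_rgb = image_utils.mask_background(
            image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(bg_color)
        )
    if mask_depths and fg_mask is not None and depth_map is not None:
        # depths are masked with the thresholded mask
        depth_map = depth_map * fg_mask
    return image_rgb, fg_mask, depth_map
```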
@@ -453,12 +461,12 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
        for func in self._implicit_functions:
            func.bind_args(**custom_args)

-        chunked_renderer_inputs = {}
+        inputs_to_be_chunked = {}
        if fg_probability is not None and self.renderer.requires_object_mask():
            sampled_fb_prob = rend_utils.ndc_grid_sample(
                fg_probability[:n_targets], ray_bundle.xys, mode="nearest"
            )
-            chunked_renderer_inputs["object_mask"] = sampled_fb_prob > 0.5
+            inputs_to_be_chunked["object_mask"] = sampled_fb_prob > 0.5

        # (5)-(6) Implicit function evaluation and Rendering
        rendered = self._render(
@@ -466,7 +474,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
            sampling_mode=sampling_mode,
            evaluation_mode=evaluation_mode,
            implicit_functions=self._implicit_functions,
-            chunked_inputs=chunked_renderer_inputs,
+            inputs_to_be_chunked=inputs_to_be_chunked,
        )

        # Unbind the custom arguments to prevent pytorch from storing
@@ -530,30 +538,18 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
            raise AssertionError("Unreachable state")

        # (7) Compute losses
        # finally get the optimization objective using self.loss_weights
        objective = self._get_objective(preds)
        if objective is not None:
            preds["objective"] = objective

        return preds

-    def _get_objective(self, preds) -> Optional[torch.Tensor]:
+    def _get_objective(self, preds: Dict[str, torch.Tensor]) -> Optional[torch.Tensor]:
        """
        A helper function to compute the overall loss as the dot product
        of individual loss functions with the corresponding weights.
        """
-        losses_weighted = [
-            preds[k] * float(w)
-            for k, w in self.loss_weights.items()
-            if (k in preds and w != 0.0)
-        ]
-        if len(losses_weighted) == 0:
-            warnings.warn("No main objective found.")
-            return None
-        loss = sum(losses_weighted)
-        assert torch.is_tensor(loss)
-        # pyre-fixme[7]: Expected `Optional[Tensor]` but got `int`.
-        return loss
+        return weighted_sum_losses(preds, self.loss_weights)

    def visualize(
        self,
@@ -585,7 +581,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
        self,
        *,
        ray_bundle: ImplicitronRayBundle,
-        chunked_inputs: Dict[str, torch.Tensor],
+        inputs_to_be_chunked: Dict[str, torch.Tensor],
        sampling_mode: RenderSamplingMode,
        **kwargs,
    ) -> RendererOutput:
@@ -593,7 +589,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
        Args:
            ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the
                sampled rendering rays.
-            chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g.
+            inputs_to_be_chunked: A collection of tensors of shape `(B, _, H, W)`. E.g.
                SignedDistanceFunctionRenderer requires "object_mask", shape
                (B, 1, H, W), the silhouette of the object in the image. When
                chunking, they are passed to the renderer as shape
@@ -605,30 +601,27 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
            An instance of RendererOutput
        """
        if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0:
-            return _apply_chunked(
+            return apply_chunked(
                self.renderer,
-                _chunk_generator(
+                chunk_generator(
                    self.chunk_size_grid,
                    ray_bundle,
-                    chunked_inputs,
+                    inputs_to_be_chunked,
                    self.tqdm_trigger_threshold,
                    **kwargs,
                ),
-                lambda batch: _tensor_collator(batch, ray_bundle.lengths.shape[:-1]),
+                lambda batch: torch.cat(batch, dim=1).reshape(
+                    *ray_bundle.lengths.shape[:-1], -1
+                ),
            )
        else:
            # pyre-fixme[29]: `BaseRenderer` is not a function.
            return self.renderer(
                ray_bundle=ray_bundle,
-                **chunked_inputs,
+                **inputs_to_be_chunked,
                **kwargs,
            )

-    def _get_global_encoder_encoding_dim(self) -> int:
-        if self.global_encoder is None:
-            return 0
-        return self.global_encoder.get_encoding_dim()

    def _get_viewpooled_feature_dim(self) -> int:
        if self.view_pooler is None:
            return 0
@@ -720,30 +713,29 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
        function(s) are initialized.
        """
        extra_args = {}
+        global_encoder_dim = (
+            0 if self.global_encoder is None else self.global_encoder.get_encoding_dim()
+        )
+        viewpooled_feature_dim = self._get_viewpooled_feature_dim()

        if self.implicit_function_class_type in (
            "NeuralRadianceFieldImplicitFunction",
            "NeRFormerImplicitFunction",
        ):
-            extra_args["latent_dim"] = (
-                self._get_viewpooled_feature_dim()
-                + self._get_global_encoder_encoding_dim()
-            )
+            extra_args["latent_dim"] = viewpooled_feature_dim + global_encoder_dim
            extra_args["color_dim"] = self.render_features_dimensions

        if self.implicit_function_class_type == "IdrFeatureField":
            extra_args["feature_vector_size"] = self.render_features_dimensions
-            extra_args["encoding_dim"] = self._get_global_encoder_encoding_dim()
+            extra_args["encoding_dim"] = global_encoder_dim

        if self.implicit_function_class_type == "SRNImplicitFunction":
-            extra_args["latent_dim"] = (
-                self._get_viewpooled_feature_dim()
-                + self._get_global_encoder_encoding_dim()
-            )
+            extra_args["latent_dim"] = viewpooled_feature_dim + global_encoder_dim

        # srn_hypernet preprocessing
        if self.implicit_function_class_type == "SRNHyperNetImplicitFunction":
-            extra_args["latent_dim"] = self._get_viewpooled_feature_dim()
-            extra_args["latent_dim_hypernet"] = self._get_global_encoder_encoding_dim()
+            extra_args["latent_dim"] = viewpooled_feature_dim
+            extra_args["latent_dim_hypernet"] = global_encoder_dim

        # check that for srn, srn_hypernet, idr we have self.num_passes=1
        implicit_function_type = registry.get(
@@ -770,147 +762,3 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
            for _ in range(self.num_passes)
        ]
        return torch.nn.ModuleList(implicit_functions_list)

-    def log_loss_weights(self) -> None:
-        """
-        Print a table of the loss weights.
-        """
-        loss_weights_message = (
-            "-------\nloss_weights:\n"
-            + "\n".join(f"{k:40s}: {w:1.2e}" for k, w in self.loss_weights.items())
-            + "-------"
-        )
-        logger.info(loss_weights_message)
-
-    def _preprocess_input(
-        self,
-        image_rgb: Optional[torch.Tensor],
-        fg_probability: Optional[torch.Tensor],
-        depth_map: Optional[torch.Tensor],
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
-        """
-        Helper function to preprocess the input images and optional depth maps
-        to apply masking if required.
-
-        Args:
-            image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images
-                corresponding to the source viewpoints from which features will be extracted
-            fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch
-                of foreground masks with values in [0, 1].
-            depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
-
-        Returns:
-            Modified image_rgb, fg_mask, depth_map
-        """
-        if image_rgb is not None and image_rgb.ndim == 3:
-            # The FrameData object is used for both frames and batches of frames,
-            # and a user might get this error if those were confused.
-            # Perhaps a user has a FrameData `fd` representing a single frame and
-            # wrote something like `model(**fd)` instead of
-            # `model(**fd.collate([fd]))`.
-            raise ValueError(
-                "Model received unbatched inputs. "
-                + "Perhaps they came from a FrameData which had not been collated."
-            )
-
-        fg_mask = fg_probability
-        if fg_mask is not None and self.mask_threshold > 0.0:
-            # threshold masks
-            warnings.warn("Thresholding masks!")
-            fg_mask = (fg_mask >= self.mask_threshold).type_as(fg_mask)
-
-        if self.mask_images and fg_mask is not None and image_rgb is not None:
-            # mask the image
-            warnings.warn("Masking images!")
-            image_rgb = image_utils.mask_background(
-                image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(self.bg_color)
-            )
-
-        if self.mask_depths and fg_mask is not None and depth_map is not None:
-            # mask the depths
-            assert (
-                self.mask_threshold > 0.0
-            ), "Depths should be masked only with thresholded masks"
-            warnings.warn("Masking depths!")
-            depth_map = depth_map * fg_mask
-
-        return image_rgb, fg_mask, depth_map
-
-
-def _apply_chunked(func, chunk_generator, tensor_collator):
-    """
-    Helper function to apply a function on a sequence of
-    chunked inputs yielded by a generator and collate
-    the result.
-    """
-    processed_chunks = [
-        func(*chunk_args, **chunk_kwargs)
-        for chunk_args, chunk_kwargs in chunk_generator
-    ]
-
-    return cat_dataclass(processed_chunks, tensor_collator)
-
-
-def _tensor_collator(batch, new_dims) -> torch.Tensor:
-    """
-    Helper function to reshape the batch to the desired shape
-    """
-    return torch.cat(batch, dim=1).reshape(*new_dims, -1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _chunk_generator(
 | 
			
		||||
    chunk_size: int,
 | 
			
		||||
    ray_bundle: ImplicitronRayBundle,
 | 
			
		||||
    chunked_inputs: Dict[str, torch.Tensor],
 | 
			
		||||
    tqdm_trigger_threshold: int,
 | 
			
		||||
    *args,
 | 
			
		||||
    **kwargs,
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Helper function which yields chunks of rays from the
 | 
			
		||||
    input ray_bundle, to be used when the number of rays is
 | 
			
		||||
    large and will not fit in memory for rendering.
 | 
			
		||||
    """
 | 
			
		||||
    (
 | 
			
		||||
        batch_size,
 | 
			
		||||
        *spatial_dim,
 | 
			
		||||
        n_pts_per_ray,
 | 
			
		||||
    ) = ray_bundle.lengths.shape  # B x ... x n_pts_per_ray
 | 
			
		||||
    if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
 | 
			
		||||
        raise ValueError(
 | 
			
		||||
            f"chunk_size_grid ({chunk_size}) should be divisible "
 | 
			
		||||
            f"by n_pts_per_ray ({n_pts_per_ray})"
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    n_rays = prod(spatial_dim)
 | 
			
		||||
    # special handling for raytracing-based methods
 | 
			
		||||
    n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
 | 
			
		||||
    chunk_size_in_rays = -(-n_rays // n_chunks)
 | 
			
		||||
 | 
			
		||||
    iter = range(0, n_rays, chunk_size_in_rays)
 | 
			
		||||
    if len(iter) >= tqdm_trigger_threshold:
 | 
			
		||||
        iter = tqdm.tqdm(iter)
 | 
			
		||||
 | 
			
		||||
    def _safe_slice(
 | 
			
		||||
        tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
 | 
			
		||||
    ) -> Any:
 | 
			
		||||
        return tensor[start_idx:end_idx] if tensor is not None else None
 | 
			
		||||
 | 
			
		||||
    for start_idx in iter:
 | 
			
		||||
        end_idx = min(start_idx + chunk_size_in_rays, n_rays)
 | 
			
		||||
        ray_bundle_chunk = ImplicitronRayBundle(
 | 
			
		||||
            origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
 | 
			
		||||
            directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
 | 
			
		||||
                :, start_idx:end_idx
 | 
			
		||||
            ],
 | 
			
		||||
            lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[
 | 
			
		||||
                :, start_idx:end_idx
 | 
			
		||||
            ],
 | 
			
		||||
            xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
 | 
			
		||||
            camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
 | 
			
		||||
            camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
 | 
			
		||||
        )
 | 
			
		||||
        extra_args = kwargs.copy()
 | 
			
		||||
        for k, v in chunked_inputs.items():
 | 
			
		||||
            extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx]
 | 
			
		||||
        yield [ray_bundle_chunk, *args], extra_args
 | 
			
		||||
 | 
			
		||||
							
								
								
									
pytorch3d/implicitron/models/overfit_model.py (new file, 639 lines)
@@ -0,0 +1,639 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Note: The #noqa comments below are for unused imports of pluggable implementations
# which are part of implicitron. They ensure that the registry is prepopulated.

import functools
import logging
from dataclasses import field
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from omegaconf import DictConfig

from pytorch3d.implicitron.models.base_model import (
    ImplicitronModelBase,
    ImplicitronRender,
)
from pytorch3d.implicitron.models.global_encoder.global_encoder import GlobalEncoderBase
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
from pytorch3d.implicitron.models.metrics import (
    RegularizationMetricsBase,
    ViewMetricsBase,
)

from pytorch3d.implicitron.models.renderer.base import (
    BaseRenderer,
    EvaluationMode,
    ImplicitronRayBundle,
    RendererOutput,
    RenderSamplingMode,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
from pytorch3d.implicitron.models.utils import (
    apply_chunked,
    chunk_generator,
    log_loss_weights,
    preprocess_input,
    weighted_sum_losses,
)
from pytorch3d.implicitron.tools import vis_utils
from pytorch3d.implicitron.tools.config import (
    expand_args_fields,
    registry,
    run_auto_creation,
)

from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle
from pytorch3d.renderer import utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase


if TYPE_CHECKING:
    from visdom import Visdom
logger = logging.getLogger(__name__)

IMPLICIT_FUNCTION_ARGS_TO_REMOVE: List[str] = [
    "feature_vector_size",
    "encoding_dim",
    "latent_dim",
    "color_dim",
]

@registry.register
class OverfitModel(ImplicitronModelBase):  # pyre-ignore: 13
    """
    OverfitModel is a wrapper for the neural implicit
    rendering and reconstruction pipeline which consists
    of the following sequence of 4 steps:


        (1) Ray Sampling
        ------------------
        Rays are sampled from an image grid based on the target view(s).
                │
                ▼
        (2) Implicit Function Evaluation
        ------------------
        Evaluate the implicit function(s) at the sampled ray points
        (also optionally pass in a global encoding from global_encoder).
                │
                ▼
        (3) Rendering
        ------------------
        Render the image into the target cameras by raymarching along
        the sampled rays and aggregating the colors and densities
        output by the implicit function in (2).
                │
                ▼
        (4) Loss Computation
        ------------------
        Compute losses based on the predicted target image(s).


    The `forward` function of OverfitModel executes
    this sequence of steps. Currently, steps 1, 2, 3
    can be customized by initializing a subclass of the appropriate
    base class and adding the newly created module to the registry.
    Please see https://github.com/facebookresearch/pytorch3d/blob/main/projects/implicitron_trainer/README.md#custom-plugins
    for more details on how to create and register a custom component.

    In the config .yaml files for experiments, the parameters below are
    contained in the
    `model_factory_ImplicitronModelFactory_args.model_OverfitModel_args`
    node. As OverfitModel derives from ReplaceableBase, the input arguments are
    parsed by the run_auto_creation function to initialize the
    necessary member modules. Please see implicitron_trainer/README.md
    for more details on this process.

    Args:
        mask_images: Whether or not to mask the RGB image background given the
            foreground mask (the `fg_probability` argument of `OverfitModel.forward`)
        mask_depths: Whether or not to mask the depth image background given the
            foreground mask (the `fg_probability` argument of `OverfitModel.forward`)
        render_image_width: Width of the output image to render
        render_image_height: Height of the output image to render
        mask_threshold: If greater than 0.0, the foreground mask is
            thresholded by this value before being applied to the RGB/Depth images
        output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
            splatting onto an image grid. Default: False.
        bg_color: RGB values for setting the background color of input image
            if mask_images=True. Defaults to (0.0, 0.0, 0.0). Each renderer has its own
            way to determine the background color of its output, unrelated to this.
        chunk_size_grid: The total number of points which can be rendered
            per chunk. This is used to compute the number of rays used
            per chunk when the chunked version of the renderer is used (in order
            to fit rendering on all rays in memory)
        render_features_dimensions: The number of output features to render.
            Defaults to 3, corresponding to RGB images.
        sampling_mode_training: The sampling method to use during training. Must be
            a value from the RenderSamplingMode Enum.
        sampling_mode_evaluation: Same as above but for evaluation.
        global_encoder_class_type: The name of the class to use for global_encoder,
            which must be available in the registry. Or `None` to disable global encoder.
        global_encoder: An instance of `GlobalEncoder`. This is used to generate an encoding
            of the image (referred to as the global_code) that can be used to model aspects of
            the scene such as multiple objects or morphing objects. It is up to the implicit
            function definition how to use it, but the most typical way is to broadcast and
            concatenate to the other inputs for the implicit function.
        raysampler_class_type: The name of the raysampler class which is available
            in the global registry.
        raysampler: An instance of RaySampler which is used to emit
            rays from the target view(s).
        renderer_class_type: The name of the renderer class which is available in the global
            registry.
        renderer: A renderer class which inherits from BaseRenderer. This is used to
            generate the images from the target view(s).
        share_implicit_function_across_passes: If set to True,
            coarse_implicit_function is automatically set to implicit_function
            (coarse_implicit_function=implicit_function). The
            implicit_functions are then run sequentially during the rendering.
        implicit_function_class_type: The type of implicit function to use which
            is available in the global registry.
        implicit_function: An instance of ImplicitFunctionBase.
        coarse_implicit_function_class_type: The type of implicit function to use which
            is available in the global registry.
        coarse_implicit_function: An instance of ImplicitFunctionBase.
            If set and `share_implicit_function_across_passes` is set to False,
            coarse_implicit_function is instantiated as its own module. It
            is then used as the second pass during the rendering.
            If set to None, we only do a single pass with implicit_function.
        view_metrics: An instance of ViewMetricsBase used to compute loss terms which
            are independent of the model's parameters.
        view_metrics_class_type: The type of view metrics to use, must be available in
            the global registry.
        regularization_metrics: An instance of RegularizationMetricsBase used to compute
            regularization terms which can depend on the model's parameters.
        regularization_metrics_class_type: The type of regularization metrics to use,
            must be available in the global registry.
        loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
            for `ViewMetrics` class for available loss functions.
        log_vars: A list of variable names which should be logged.
            The names should correspond to a subset of the keys of the
            dict `preds` output by the `forward` function.
    """  # noqa: B950

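    # In an experiment yaml, these parameters would sit under a node like the
    # following (a sketch with hypothetical values, mirroring the defaults below):
    #
    #   model_factory_ImplicitronModelFactory_args:
    #     model_class_type: OverfitModel
    #     model_OverfitModel_args:
    #       render_image_width: 400
    #       render_image_height: 400
    #       renderer_class_type: MultiPassEmissionAbsorptionRenderer
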
    mask_images: bool = True
    mask_depths: bool = True
    render_image_width: int = 400
    render_image_height: int = 400
    mask_threshold: float = 0.5
    output_rasterized_mc: bool = False
    bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
    chunk_size_grid: int = 4096
    render_features_dimensions: int = 3
    tqdm_trigger_threshold: int = 16

    n_train_target_views: int = 1
    sampling_mode_training: str = "mask_sample"
    sampling_mode_evaluation: str = "full_grid"

    # ---- global encoder settings
    global_encoder_class_type: Optional[str] = None
    global_encoder: Optional[GlobalEncoderBase]

    # ---- raysampler
    raysampler_class_type: str = "AdaptiveRaySampler"
    raysampler: RaySamplerBase

    # ---- renderer configs
    renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
    renderer: BaseRenderer

    # ---- implicit function settings
    share_implicit_function_across_passes: bool = False
    implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
    implicit_function: ImplicitFunctionBase
    coarse_implicit_function_class_type: Optional[str] = None
    coarse_implicit_function: Optional[ImplicitFunctionBase]

    # ----- metrics
    view_metrics: ViewMetricsBase
    view_metrics_class_type: str = "ViewMetrics"

    regularization_metrics: RegularizationMetricsBase
    regularization_metrics_class_type: str = "RegularizationMetrics"

    # ---- loss weights
    loss_weights: Dict[str, float] = field(
        default_factory=lambda: {
            "loss_rgb_mse": 1.0,
            "loss_prev_stage_rgb_mse": 1.0,
            "loss_mask_bce": 0.0,
            "loss_prev_stage_mask_bce": 0.0,
        }
    )

    # ---- variables to be logged (logger automatically ignores if not computed)
    log_vars: List[str] = field(
        default_factory=lambda: [
            "loss_rgb_psnr_fg",
            "loss_rgb_psnr",
            "loss_rgb_mse",
            "loss_rgb_huber",
            "loss_depth_abs",
            "loss_depth_abs_fg",
            "loss_mask_neg_iou",
            "loss_mask_bce",
            "loss_mask_beta_prior",
            "loss_eikonal",
            "loss_density_tv",
            "loss_depth_neg_penalty",
            "loss_autodecoder_norm",
            # metrics that are only logged in 2+stage renderers
            "loss_prev_stage_rgb_mse",
            "loss_prev_stage_rgb_psnr_fg",
            "loss_prev_stage_rgb_psnr",
            "loss_prev_stage_mask_bce",
            # basic metrics
            "objective",
            "epoch",
            "sec/it",
        ]
    )

    def __post_init__(self):
        # The attribute will be filled by run_auto_creation
        run_auto_creation(self)
        log_loss_weights(self.loss_weights, logger)
        # We need to set it here since run_auto_creation
        # will create coarse_implicit_function before implicit_function
        if self.share_implicit_function_across_passes:
            self.coarse_implicit_function = self.implicit_function

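    # A small sketch of the weight-sharing switch (hypothetical usage, not part
    # of this file): after construction with
    # share_implicit_function_across_passes=True, both rendering passes
    # evaluate the very same module, i.e.
    #     model.coarse_implicit_function is model.implicit_function
    # holds once __post_init__ has run.
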
    def forward(
        self,
        *,  # force keyword-only arguments
        image_rgb: Optional[torch.Tensor],
        camera: CamerasBase,
        fg_probability: Optional[torch.Tensor] = None,
        mask_crop: Optional[torch.Tensor] = None,
        depth_map: Optional[torch.Tensor] = None,
        sequence_name: Optional[List[str]] = None,
        frame_timestamp: Optional[torch.Tensor] = None,
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Args:
            image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images;
                the first `min(B, n_train_target_views)` images are considered targets and
                are used to supervise the renders; the rest correspond to the source
                viewpoints from which features will be extracted.
            camera: An instance of CamerasBase containing a batch of `B` cameras corresponding
                to the viewpoints of target images, from which the rays will be sampled,
                and source images, which will be used for intersecting with target rays.
            fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
                foreground masks.
            mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
                regions in the input images (i.e. regions that do not correspond
                to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
                "mask_sample", rays will be sampled in the non zero regions.
            depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
            sequence_name: A list of `B` strings corresponding to the sequence names
                from which images `image_rgb` were extracted. They are used to match
                target frames with relevant source frames.
            frame_timestamp: Optionally a tensor of shape `(B,)` containing a batch
                of frame timestamps.
            evaluation_mode: one of EvaluationMode.TRAINING or
                EvaluationMode.EVALUATION which determines the settings used for
                rendering.

        Returns:
            preds: A dictionary containing all outputs of the forward pass including the
                rendered images, depths, masks, losses and other metrics.
        """
        image_rgb, fg_probability, depth_map = preprocess_input(
            image_rgb,
            fg_probability,
            depth_map,
            self.mask_images,
            self.mask_depths,
            self.mask_threshold,
            self.bg_color,
        )

        # Determine the used ray sampling mode.
        sampling_mode = RenderSamplingMode(
            self.sampling_mode_training
            if evaluation_mode == EvaluationMode.TRAINING
            else self.sampling_mode_evaluation
        )

        # (1) Sample rendering rays with the ray sampler.
        # pyre-ignore[29]
        ray_bundle: ImplicitronRayBundle = self.raysampler(
            camera,
            evaluation_mode,
            mask=mask_crop
            if mask_crop is not None and sampling_mode == RenderSamplingMode.MASK_SAMPLE
            else None,
        )

        inputs_to_be_chunked = {}
        if fg_probability is not None and self.renderer.requires_object_mask():
            sampled_fb_prob = rend_utils.ndc_grid_sample(
                fg_probability, ray_bundle.xys, mode="nearest"
            )
            inputs_to_be_chunked["object_mask"] = sampled_fb_prob > 0.5

        # (2)-(3) Implicit function evaluation and Rendering
        implicit_functions: List[Union[Callable, ImplicitFunctionBase]] = [
            self.implicit_function
        ]
        if self.coarse_implicit_function is not None:
            implicit_functions += [self.coarse_implicit_function]

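        # Bind the same global code to every pass (fine and, if present, coarse):
        # `functools.partial` freezes the `global_code` keyword so the renderer
        # can later call each implicit function on ray points alone.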
        if self.global_encoder is not None:
            global_code = self.global_encoder(  # pyre-fixme[29]
                sequence_name=sequence_name,
                frame_timestamp=frame_timestamp,
            )
            implicit_functions = [
                functools.partial(implicit_function, global_code=global_code)
                if isinstance(implicit_function, Callable)
                else functools.partial(
                    implicit_function.forward, global_code=global_code
                )
                for implicit_function in implicit_functions
            ]
        rendered = self._render(
            ray_bundle=ray_bundle,
            sampling_mode=sampling_mode,
            evaluation_mode=evaluation_mode,
            implicit_functions=implicit_functions,
            inputs_to_be_chunked=inputs_to_be_chunked,
        )

        # A dict to store losses as well as rendering results.
        preds: Dict[str, Any] = self.view_metrics(
            results={},
            raymarched=rendered,
            ray_bundle=ray_bundle,
            image_rgb=image_rgb,
            depth_map=depth_map,
            fg_probability=fg_probability,
            mask_crop=mask_crop,
        )

        preds.update(
            self.regularization_metrics(
                results=preds,
                model=self,
            )
        )

        if sampling_mode == RenderSamplingMode.MASK_SAMPLE:
            if self.output_rasterized_mc:
                # Visualize the monte-carlo pixel renders by splatting onto
                # an image grid.
                (
                    preds["images_render"],
                    preds["depths_render"],
                    preds["masks_render"],
                ) = rasterize_sparse_ray_bundle(
                    ray_bundle,
                    rendered.features,
                    (self.render_image_height, self.render_image_width),
                    rendered.depths,
                    masks=rendered.masks,
                )
        elif sampling_mode == RenderSamplingMode.FULL_GRID:
            preds["images_render"] = rendered.features.permute(0, 3, 1, 2)
            preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2)
            preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2)

            preds["implicitron_render"] = ImplicitronRender(
                image_render=preds["images_render"],
                depth_render=preds["depths_render"],
                mask_render=preds["masks_render"],
            )
        else:
            raise AssertionError("Unreachable state")

        # (4) Compute losses
        # finally get the optimization objective using self.loss_weights
        objective = self._get_objective(preds)
        if objective is not None:
            preds["objective"] = objective

        return preds

    def _get_objective(self, preds: Dict[str, torch.Tensor]) -> Optional[torch.Tensor]:
        """
        A helper function to compute the overall loss as the dot product
        of individual loss functions with the corresponding weights.
        """
        return weighted_sum_losses(preds, self.loss_weights)

    def visualize(
        self,
        viz: Optional["Visdom"],
        visdom_env_imgs: str,
        preds: Dict[str, Any],
        prefix: str,
    ) -> None:
        """
        Helper function to visualize the predictions generated
        in the forward pass.

        Args:
            viz: Visdom connection object
            visdom_env_imgs: name of visdom environment for the images.
            preds: predictions dict like returned by forward()
            prefix: prepended to the names of images
        """
        if viz is None or not viz.check_connection():
            logger.info("no visdom server! -> skipping batch vis")
            return

        idx_image = 0
        title = f"{prefix}_im{idx_image}"

        vis_utils.visualize_basics(viz, preds, visdom_env_imgs, title=title)

    def _render(
        self,
        *,
        ray_bundle: ImplicitronRayBundle,
        inputs_to_be_chunked: Dict[str, torch.Tensor],
        sampling_mode: RenderSamplingMode,
        **kwargs,
    ) -> RendererOutput:
        """
        Args:
            ray_bundle: An `ImplicitronRayBundle` object containing the parametrizations of the
                sampled rendering rays.
            inputs_to_be_chunked: A collection of tensors of shape `(B, _, H, W)`. E.g.
                SignedDistanceFunctionRenderer requires "object_mask", shape
                (B, 1, H, W), the silhouette of the object in the image. When
                chunking, they are passed to the renderer as shape
                `(B, _, chunksize)`.
            sampling_mode: The sampling method to use. Must be a value from the
                RenderSamplingMode Enum.

        Returns:
            An instance of RendererOutput
        """
        if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0:
            return apply_chunked(
                self.renderer,
                chunk_generator(
                    self.chunk_size_grid,
                    ray_bundle,
                    inputs_to_be_chunked,
                    self.tqdm_trigger_threshold,
                    **kwargs,
                ),
                lambda batch: torch.cat(batch, dim=1).reshape(
                    *ray_bundle.lengths.shape[:-1], -1
                ),
            )
        else:
            # pyre-fixme[29]: `BaseRenderer` is not a function.
            return self.renderer(
                ray_bundle=ray_bundle,
                **inputs_to_be_chunked,
                **kwargs,
            )

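    # A worked example of the chunking arithmetic above (hypothetical numbers):
    # a 128 x 128 full-grid render gives 16384 rays; with 64 points per ray and
    # chunk_size_grid=65536, chunk_generator emits ceil(16384 * 64 / 65536) = 16
    # chunks of 1024 rays each, which apply_chunked concatenates back into a
    # single RendererOutput.
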
    @classmethod
    def raysampler_tweak_args(cls, type, args: DictConfig) -> None:
        """
        We don't expose certain fields of the raysampler because we want to set
        them from our own members.
        """
        del args["sampling_mode_training"]
        del args["sampling_mode_evaluation"]
        del args["image_width"]
        del args["image_height"]

    def create_raysampler(self):
        extra_args = {
            "sampling_mode_training": self.sampling_mode_training,
            "sampling_mode_evaluation": self.sampling_mode_evaluation,
            "image_width": self.render_image_width,
            "image_height": self.render_image_height,
        }
        raysampler_args = getattr(
            self, "raysampler_" + self.raysampler_class_type + "_args"
        )
        self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
            **raysampler_args, **extra_args
        )

    @classmethod
    def renderer_tweak_args(cls, type, args: DictConfig) -> None:
        """
        We don't expose certain fields of the renderer because we want to set
        them based on other inputs.
        """
        args.pop("render_features_dimensions", None)
        args.pop("object_bounding_sphere", None)

    def create_renderer(self):
        extra_args = {}

        if self.renderer_class_type == "SignedDistanceFunctionRenderer":
            extra_args["render_features_dimensions"] = self.render_features_dimensions
            if not hasattr(self.raysampler, "scene_extent"):
                raise ValueError(
                    "SignedDistanceFunctionRenderer requires"
                    + " a raysampler that defines the 'scene_extent' field"
                    + " (this field is supported by, e.g., the adaptive raysampler - "
                    + " self.raysampler_class_type='AdaptiveRaySampler')."
                )
            extra_args["object_bounding_sphere"] = self.raysampler.scene_extent

        renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
        self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
            **renderer_args, **extra_args
        )

    @classmethod
    def implicit_function_tweak_args(cls, type, args: DictConfig) -> None:
        """
        We don't expose certain implicit_function fields because we want to set
        them based on other inputs.
        """
        for arg in IMPLICIT_FUNCTION_ARGS_TO_REMOVE:
            args.pop(arg, None)

    @classmethod
    def coarse_implicit_function_tweak_args(cls, type, args: DictConfig) -> None:
        """
        We don't expose certain implicit_function fields because we want to set
        them based on other inputs.
        """
        for arg in IMPLICIT_FUNCTION_ARGS_TO_REMOVE:
            args.pop(arg, None)

    def _create_extra_args_for_implicit_function(self) -> Dict[str, Any]:
        extra_args = {}
        global_encoder_dim = (
            0 if self.global_encoder is None else self.global_encoder.get_encoding_dim()
        )
        if self.implicit_function_class_type in (
            "NeuralRadianceFieldImplicitFunction",
            "NeRFormerImplicitFunction",
        ):
            extra_args["latent_dim"] = global_encoder_dim
            extra_args["color_dim"] = self.render_features_dimensions

        if self.implicit_function_class_type == "IdrFeatureField":
            extra_args["encoding_dim"] = global_encoder_dim
            extra_args["feature_vector_size"] = self.render_features_dimensions

        if self.implicit_function_class_type == "SRNImplicitFunction":
            extra_args["latent_dim"] = global_encoder_dim
        return extra_args

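    # Dimension bookkeeping sketch (hypothetical values): with a
    # SequenceAutodecoder global encoder whose encoding_dim is 64, and no view
    # pooling (OverfitModel fits a single scene, so no source-view features are
    # concatenated), a NeuralRadianceFieldImplicitFunction is constructed with
    # latent_dim=64 and color_dim=3.
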
    def create_implicit_function(self) -> None:
        implicit_function_type = registry.get(
            ImplicitFunctionBase, self.implicit_function_class_type
        )
        expand_args_fields(implicit_function_type)

        config_name = f"implicit_function_{self.implicit_function_class_type}_args"
        config = getattr(self, config_name, None)
        if config is None:
            raise ValueError(f"{config_name} not present")

        extra_args = self._create_extra_args_for_implicit_function()
        self.implicit_function = implicit_function_type(**config, **extra_args)

    def create_coarse_implicit_function(self) -> None:
        # If coarse_implicit_function_class_type has been defined
        # then we init a module based on its arguments
        if (
            self.coarse_implicit_function_class_type is not None
            and not self.share_implicit_function_across_passes
        ):
            config_name = "coarse_implicit_function_{0}_args".format(
                self.coarse_implicit_function_class_type
            )
            config = getattr(self, config_name, {})

            implicit_function_type = registry.get(
                ImplicitFunctionBase,
                # pyre-ignore: the check above guarantees this is not None.
                self.coarse_implicit_function_class_type,
            )
            expand_args_fields(implicit_function_type)

            extra_args = self._create_extra_args_for_implicit_function()
            self.coarse_implicit_function = implicit_function_type(
                **config, **extra_args
            )
        elif self.share_implicit_function_across_passes:
            # Since coarse_implicit_function is initialised before
            # implicit_function we handle this case in the post_init.
            pass
        else:
            self.coarse_implicit_function = None
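
The new model can also be smoke-tested outside the trainer loop. A minimal sketch under stated assumptions (the argument values are hypothetical; it reuses the same camera helpers as the unit test further below):

```python
import torch
from pytorch3d.implicitron.models.overfit_model import OverfitModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras

expand_args_fields(OverfitModel)  # populate the dataclass fields of all plugins
model = OverfitModel(
    render_image_width=64,
    render_image_height=64,
    # add a second NeRF pass, mirroring the coarse/fine setup of the unit test
    coarse_implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
)

R, T = look_at_view_transform(azim=torch.rand(1) * 360)
preds = model(
    image_rgb=torch.rand(1, 3, 64, 64),
    camera=PerspectiveCameras(R=R, T=T),
    fg_probability=torch.ones(1, 1, 64, 64),
    evaluation_mode=EvaluationMode.EVALUATION,  # full-grid, chunked rendering
)
print(preds["objective"], preds["images_render"].shape)
```
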
pytorch3d/implicitron/models/utils.py (new file, 195 lines)
@@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Note: The #noqa comments below are for unused imports of pluggable implementations
# which are part of implicitron. They ensure that the registry is prepopulated.

import warnings
from logging import Logger
from typing import Any, Dict, Optional, Tuple

import torch
import tqdm
from pytorch3d.common.compat import prod

from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle

from pytorch3d.implicitron.tools import image_utils

from pytorch3d.implicitron.tools.utils import cat_dataclass


def preprocess_input(
    image_rgb: Optional[torch.Tensor],
    fg_probability: Optional[torch.Tensor],
    depth_map: Optional[torch.Tensor],
    mask_images: bool,
    mask_depths: bool,
    mask_threshold: float,
    bg_color: Tuple[float, float, float],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Helper function to preprocess the input images and optional depth maps
    to apply masking if required.

    Args:
        image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images
            corresponding to the source viewpoints from which features will be extracted
        fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch
            of foreground masks with values in [0, 1].
        depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
        mask_images: Whether or not to mask the RGB image background given the
            foreground mask (the `fg_probability` argument of `GenericModel.forward`)
        mask_depths: Whether or not to mask the depth image background given the
            foreground mask (the `fg_probability` argument of `GenericModel.forward`)
        mask_threshold: If greater than 0.0, the foreground mask is
            thresholded by this value before being applied to the RGB/Depth images
        bg_color: RGB values for setting the background color of input image
            if mask_images=True. Defaults to (0.0, 0.0, 0.0). Each renderer has its own
            way to determine the background color of its output, unrelated to this.

    Returns:
        Modified image_rgb, fg_mask, depth_map
    """
    if image_rgb is not None and image_rgb.ndim == 3:
        # The FrameData object is used for both frames and batches of frames,
        # and a user might get this error if those were confused.
        # Perhaps a user has a FrameData `fd` representing a single frame and
        # wrote something like `model(**fd)` instead of
        # `model(**fd.collate([fd]))`.
        raise ValueError(
            "Model received unbatched inputs. "
            + "Perhaps they came from a FrameData which had not been collated."
        )

    fg_mask = fg_probability
    if fg_mask is not None and mask_threshold > 0.0:
        # threshold masks
        warnings.warn("Thresholding masks!")
        fg_mask = (fg_mask >= mask_threshold).type_as(fg_mask)

    if mask_images and fg_mask is not None and image_rgb is not None:
        # mask the image
        warnings.warn("Masking images!")
        image_rgb = image_utils.mask_background(
            image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(bg_color)
        )

    if mask_depths and fg_mask is not None and depth_map is not None:
        # mask the depths
        assert (
            mask_threshold > 0.0
        ), "Depths should be masked only with thresholded masks"
        warnings.warn("Masking depths!")
        depth_map = depth_map * fg_mask

    return image_rgb, fg_mask, depth_map


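# A quick sketch of preprocess_input (hypothetical shapes): with mask_images=True
# and mask_threshold=0.5, background pixels are replaced by bg_color and the
# mask is binarized before being returned:
#
#     rgb, mask, depth = preprocess_input(
#         image_rgb=torch.rand(2, 3, 8, 8),
#         fg_probability=torch.rand(2, 1, 8, 8),
#         depth_map=None,
#         mask_images=True,
#         mask_depths=False,
#         mask_threshold=0.5,
#         bg_color=(0.0, 0.0, 0.0),
#     )

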
def log_loss_weights(loss_weights: Dict[str, float], logger: Logger) -> None:
    """
    Print a table of the loss weights.
    """
    loss_weights_message = (
        "-------\nloss_weights:\n"
        + "\n".join(f"{k:40s}: {w:1.2e}" for k, w in loss_weights.items())
        + "-------"
    )
    logger.info(loss_weights_message)


def weighted_sum_losses(
    preds: Dict[str, torch.Tensor], loss_weights: Dict[str, float]
) -> Optional[torch.Tensor]:
    """
    A helper function to compute the overall loss as the dot product
    of individual loss functions with the corresponding weights.
    """
    losses_weighted = [
        preds[k] * float(w)
        for k, w in loss_weights.items()
        if (k in preds and w != 0.0)
    ]
    if len(losses_weighted) == 0:
        warnings.warn("No main objective found.")
        return None
    loss = sum(losses_weighted)
    assert torch.is_tensor(loss)
    # pyre-fixme[7]: Expected `Optional[Tensor]` but got `int`.
    return loss


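# A numeric sketch of weighted_sum_losses (hypothetical values): zero-weight
# terms and weights without a matching prediction are skipped, the rest form a
# plain weighted sum:
#
#     preds = {"loss_rgb_mse": torch.tensor(0.02), "loss_mask_bce": torch.tensor(0.7)}
#     weights = {"loss_rgb_mse": 1.0, "loss_mask_bce": 0.0, "loss_eikonal": 0.1}
#     weighted_sum_losses(preds, weights)  # tensor(0.0200)

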
def apply_chunked(func, chunk_generator, tensor_collator):
    """
    Helper function to apply a function on a sequence of
    chunked inputs yielded by a generator and collate
    the result.
    """
    processed_chunks = [
        func(*chunk_args, **chunk_kwargs)
        for chunk_args, chunk_kwargs in chunk_generator
    ]

    return cat_dataclass(processed_chunks, tensor_collator)


def chunk_generator(
    chunk_size: int,
    ray_bundle: ImplicitronRayBundle,
    chunked_inputs: Dict[str, torch.Tensor],
    tqdm_trigger_threshold: int,
    *args,
    **kwargs,
):
    """
    Helper function which yields chunks of rays from the
    input ray_bundle, to be used when the number of rays is
    large and will not fit in memory for rendering.
    """
    (
        batch_size,
        *spatial_dim,
        n_pts_per_ray,
    ) = ray_bundle.lengths.shape  # B x ... x n_pts_per_ray
    if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
        raise ValueError(
            f"chunk_size_grid ({chunk_size}) should be divisible "
            f"by n_pts_per_ray ({n_pts_per_ray})"
        )

    n_rays = prod(spatial_dim)
    # special handling for raytracing-based methods
    n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
    chunk_size_in_rays = -(-n_rays // n_chunks)

    iter = range(0, n_rays, chunk_size_in_rays)
    if len(iter) >= tqdm_trigger_threshold:
        iter = tqdm.tqdm(iter)

    def _safe_slice(
        tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
    ) -> Any:
        return tensor[start_idx:end_idx] if tensor is not None else None

    for start_idx in iter:
        end_idx = min(start_idx + chunk_size_in_rays, n_rays)
        ray_bundle_chunk = ImplicitronRayBundle(
            origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
            directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
                :, start_idx:end_idx
            ],
            lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[
                :, start_idx:end_idx
            ],
            xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
            camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
            camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
        )
        extra_args = kwargs.copy()
        for k, v in chunked_inputs.items():
            extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx]
        yield [ray_bundle_chunk, *args], extra_args
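
With `chunk_generator` and `apply_chunked` now public, callers other than GenericModel (e.g. the new OverfitModel, or a future ManyViewModel) can stream a memory-heavy per-ray function over a large bundle. A minimal sketch, assuming `ImplicitronRayBundle`'s `camera_ids`/`camera_counts` default to `None`, with made-up sizes:

```python
import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.models.utils import chunk_generator

bundle = ImplicitronRayBundle(
    origins=torch.rand(1, 128, 128, 3),
    directions=torch.rand(1, 128, 128, 3),
    lengths=torch.rand(1, 128, 128, 64),  # 16384 rays x 64 points per ray
    xys=torch.rand(1, 128, 128, 2),
)
# 65536 points per chunk -> ceil(16384 * 64 / 65536) = 16 chunks of 1024 rays.
for chunk_args, chunk_kwargs in chunk_generator(65536, bundle, {}, 16):
    ray_bundle_chunk = chunk_args[0]
    # ... evaluate a renderer or implicit function on ray_bundle_chunk ...
```
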
pytorch3d/implicitron/tools/rasterize_mc.py
@@ -7,8 +7,9 @@
 import math
 from typing import Optional, Tuple
 
+import pytorch3d
+
 import torch
-from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
 from pytorch3d.ops import packed_to_padded
 from pytorch3d.renderer import PerspectiveCameras
 from pytorch3d.structures import Pointclouds
@@ -18,7 +19,7 @@ from .point_cloud_utils import render_point_cloud_pytorch3d
 
 @torch.no_grad()
 def rasterize_sparse_ray_bundle(
-    ray_bundle: ImplicitronRayBundle,
+    ray_bundle: "pytorch3d.implicitron.models.renderer.base.ImplicitronRayBundle",
     features: torch.Tensor,
     image_size_hw: Tuple[int, int],
     depth: torch.Tensor,
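
The annotation change above swaps a direct import for a lazily evaluated string annotation, the usual way to break an import cycle. A generic sketch (module and class names are hypothetical):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so no import cycle.
    from models import Thing  # hypothetical module and class


def consume(x: "Thing") -> None:
    # The string annotation is resolved lazily; `models` need not be
    # importable when this module is first loaded.
    ...
```
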
tests/implicitron/models/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

tests/implicitron/models/test_overfit_model.py (new file, 217 lines)
@@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Any, Dict
from unittest.mock import patch

import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.overfit_model import OverfitModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras

DEVICE = torch.device("cuda:0")


def _generate_fake_inputs(N: int, H: int, W: int) -> Dict[str, Any]:
    R, T = look_at_view_transform(azim=torch.rand(N) * 360)
    return {
        "camera": PerspectiveCameras(R=R, T=T, device=DEVICE),
        "fg_probability": torch.randint(
            high=2, size=(N, 1, H, W), device=DEVICE
        ).float(),
        "depth_map": torch.rand((N, 1, H, W), device=DEVICE) + 0.1,
        "mask_crop": torch.randint(high=2, size=(N, 1, H, W), device=DEVICE).float(),
        "sequence_name": ["sequence"] * N,
        "image_rgb": torch.rand((N, 1, H, W), device=DEVICE),
    }


def mock_safe_multinomial(input: torch.Tensor, num_samples: int) -> torch.Tensor:
 | 
			
		||||
    """Return non deterministic indexes to mock safe_multinomial
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        input: tensor of shape [B, n] containing non-negative values;
 | 
			
		||||
                rows are interpreted as unnormalized event probabilities
 | 
			
		||||
                in categorical distributions.
 | 
			
		||||
        num_samples: number of samples to take.
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        Tensor of shape [B, num_samples]
 | 
			
		||||
    """
 | 
			
		||||
    batch_size = input.shape[0]
 | 
			
		||||
    return torch.arange(num_samples).repeat(batch_size, 1).to(DEVICE)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestOverfitModel(unittest.TestCase):
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        torch.manual_seed(42)
 | 
			
		||||
 | 
			
		||||
    def test_overfit_model_vs_generic_model_with_batch_size_one(self):
 | 
			
		||||
        """In this test we compare OverfitModel to GenericModel behavior.
 | 
			
		||||
 | 
			
		||||
        We use a Nerf setup (2 rendering passes).
 | 
			
		||||
 | 
			
		||||
        OverfitModel is a specific case of GenericModel. Hence, with the same inputs,
 | 
			
		||||
        they should provide the exact same results.
 | 
			
		||||
        """
 | 
			
		||||
        expand_args_fields(OverfitModel)
 | 
			
		||||
        expand_args_fields(GenericModel)
 | 
			
		||||
        batch_size, image_height, image_width = 1, 80, 80
 | 
			
		||||
        assert batch_size == 1
 | 
			
		||||
        overfit_model = OverfitModel(
 | 
			
		||||
            render_image_height=image_height,
 | 
			
		||||
            render_image_width=image_width,
 | 
			
		||||
            coarse_implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
 | 
			
		||||
            # To avoid randomization to compare the outputs of our model
 | 
			
		||||
            # we deactivate the stratified_point_sampling_training
 | 
			
		||||
            raysampler_AdaptiveRaySampler_args={
 | 
			
		||||
                "stratified_point_sampling_training": False
 | 
			
		||||
            },
 | 
			
		||||
            global_encoder_class_type="SequenceAutodecoder",
 | 
			
		||||
            global_encoder_SequenceAutodecoder_args={
 | 
			
		||||
                "autodecoder_args": {
 | 
			
		||||
                    "n_instances": 1000,
 | 
			
		||||
                    "init_scale": 1.0,
 | 
			
		||||
                    "encoding_dim": 64,
 | 
			
		||||
                }
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
        generic_model = GenericModel(
 | 
			
		||||
            render_image_height=image_height,
 | 
			
		||||
            render_image_width=image_width,
 | 
			
		||||
            n_train_target_views=batch_size,
 | 
			
		||||
            num_passes=2,
 | 
			
		||||
            # To avoid randomization to compare the outputs of our model
 | 
			
		||||
            # we deactivate the stratified_point_sampling_training
 | 
			
		||||
            raysampler_AdaptiveRaySampler_args={
 | 
			
		||||
                "stratified_point_sampling_training": False
 | 
			
		||||
            },
 | 
			
		||||
            global_encoder_class_type="SequenceAutodecoder",
 | 
			
		||||
            global_encoder_SequenceAutodecoder_args={
 | 
			
		||||
                "autodecoder_args": {
 | 
			
		||||
                    "n_instances": 1000,
 | 
			
		||||
                    "init_scale": 1.0,
 | 
			
		||||
                    "encoding_dim": 64,
 | 
			
		||||
                }
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Check if they do share the number of parameters
 | 
			
		||||
        num_params_mvm = sum(p.numel() for p in overfit_model.parameters())
 | 
			
		||||
        num_params_gm = sum(p.numel() for p in generic_model.parameters())
 | 
			
		||||
        self.assertEqual(num_params_mvm, num_params_gm)
 | 
			
		||||
 | 
			
		||||
        # Adapt the mapping from generic model to overfit model
 | 
			
		||||
        mapping_om_from_gm = {
 | 
			
		||||
            key.replace("_implicit_functions.0._fn", "implicit_function").replace(
 | 
			
		||||
                "_implicit_functions.1._fn", "coarse_implicit_function"
 | 
			
		||||
            ): val
 | 
			
		||||
            for key, val in generic_model.state_dict().items()
 | 
			
		||||
        }
 | 
			
		||||
        # Copy parameters from generic_model to overfit_model
 | 
			
		||||
        overfit_model.load_state_dict(mapping_om_from_gm)
 | 
			
		||||
 | 
			
		||||
        overfit_model.to(DEVICE)
 | 
			
		||||
        generic_model.to(DEVICE)
 | 
			
		||||
        inputs_ = _generate_fake_inputs(batch_size, image_height, image_width)
 | 
			
		||||
 | 
			
		||||
        # training forward pass
 | 
			
		||||
        overfit_model.train()
 | 
			
		||||
        generic_model.train()
 | 
			
		||||
 | 
			
		||||
        with patch(
 | 
			
		||||
            "pytorch3d.renderer.implicit.raysampling._safe_multinomial",
 | 
			
		||||
            side_effect=mock_safe_multinomial,
 | 
			
		||||
        ):
 | 
			
		||||
            train_preds_om = overfit_model(
 | 
			
		||||
                **inputs_,
 | 
			
		||||
                evaluation_mode=EvaluationMode.TRAINING,
 | 
			
		||||
            )
 | 
			
		||||
            train_preds_gm = generic_model(
 | 
			
		||||
                **inputs_,
 | 
			
		||||
                evaluation_mode=EvaluationMode.TRAINING,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(len(train_preds_om) == len(train_preds_gm))
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(train_preds_om["objective"].isfinite().item())
 | 
			
		||||
        # We avoid all the randomization and the weights are the same
 | 
			
		||||
        # The objective should be the same
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            torch.allclose(train_preds_om["objective"], train_preds_gm["objective"])
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Test if the evaluation works
 | 
			
		||||
        overfit_model.eval()
 | 
			
		||||
        generic_model.eval()
 | 
			
		||||
        with torch.no_grad():
 | 
			
		||||
            eval_preds_om = overfit_model(
 | 
			
		||||
                **inputs_,
 | 
			
		||||
                evaluation_mode=EvaluationMode.EVALUATION,
 | 
			
		||||
            )
 | 
			
		||||
            eval_preds_gm = generic_model(
 | 
			
		||||
                **inputs_,
 | 
			
		||||
                evaluation_mode=EvaluationMode.EVALUATION,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(
 | 
			
		||||
            eval_preds_om["images_render"].shape,
 | 
			
		||||
            (batch_size, 3, image_height, image_width),
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            torch.allclose(eval_preds_om["objective"], eval_preds_gm["objective"])
 | 
			
		||||
        )
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            torch.allclose(
 | 
			
		||||
                eval_preds_om["images_render"], eval_preds_gm["images_render"]
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def test_overfit_model_check_share_weights(self):
 | 
			
		||||
        model = OverfitModel(share_implicit_function_across_passes=True)
 | 
			
		||||
        for p1, p2 in zip(
 | 
			
		||||
            model.implicit_function.parameters(),
 | 
			
		||||
            model.coarse_implicit_function.parameters(),
 | 
			
		||||
        ):
 | 
			
		||||
            self.assertEqual(id(p1), id(p2))
 | 
			
		||||
 | 
			
		||||
        model.to(DEVICE)
 | 
			
		||||
        inputs_ = _generate_fake_inputs(2, 80, 80)
 | 
			
		||||
        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)
 | 
			
		||||
 | 
			
		||||
    def test_overfit_model_check_no_share_weights(self):
 | 
			
		||||
        model = OverfitModel(
 | 
			
		||||
            share_implicit_function_across_passes=False,
 | 
			
		||||
            coarse_implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
 | 
			
		||||
            coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args={
 | 
			
		||||
                "transformer_dim_down_factor": 1.0,
 | 
			
		||||
                "n_hidden_neurons_xyz": 256,
 | 
			
		||||
                "n_layers_xyz": 8,
 | 
			
		||||
                "append_xyz": (5,),
 | 
			
		||||
            },
 | 
			
		||||
        )
 | 
			
		||||
        for p1, p2 in zip(
 | 
			
		||||
            model.implicit_function.parameters(),
 | 
			
		||||
            model.coarse_implicit_function.parameters(),
 | 
			
		||||
        ):
 | 
			
		||||
            self.assertNotEqual(id(p1), id(p2))
 | 
			
		||||
 | 
			
		||||
        model.to(DEVICE)
 | 
			
		||||
        inputs_ = _generate_fake_inputs(2, 80, 80)
 | 
			
		||||
        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)
 | 
			
		||||
 | 
			
		||||
    def test_overfit_model_coarse_implicit_function_is_none(self):
 | 
			
		||||
        model = OverfitModel(
 | 
			
		||||
            share_implicit_function_across_passes=False,
 | 
			
		||||
            coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args=None,
 | 
			
		||||
        )
 | 
			
		||||
        self.assertIsNone(model.coarse_implicit_function)
 | 
			
		||||
        model.to(DEVICE)
 | 
			
		||||
        inputs_ = _generate_fake_inputs(2, 80, 80)
 | 
			
		||||
        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)
 | 
			
		||||
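As the tests above show, the model is constructed directly, moved to a device, and called with batch tensors plus an `evaluation_mode`; the returned dict exposes an `objective` tensor. A hypothetical single-scene overfitting loop built only from those observed pieces; the optimizer choice and hyperparameters are assumptions, and `load_scene_batch` is a placeholder for real data loading:

```python
import torch
from pytorch3d.implicitron.models.overfit_model import OverfitModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields


def load_scene_batch() -> dict:
    # Placeholder: in real use, return collated frame tensors
    # (camera, image_rgb, fg_probability, ...) for the single scene.
    raise NotImplementedError


expand_args_fields(OverfitModel)
model = OverfitModel(render_image_height=80, render_image_width=80)
model.to("cuda:0")
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)  # assumed settings

for step in range(1000):
    batch = load_scene_batch()
    preds = model(**batch, evaluation_mode=EvaluationMode.TRAINING)
    optimizer.zero_grad()
    preds["objective"].backward()
    optimizer.step()
```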
							
								
								
									
tests/implicitron/models/test_utils.py (new file, 66 lines)
@ -0,0 +1,66 @@
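The three `preprocess_input` tests below jointly pin down its observable contract: unbatched (3-D) images are rejected, `fg_probability` is binarized at the given threshold, and the image and/or depth map are masked by that binary mask when the respective flag is set. A behavioral sketch consistent with those assertions; the argument names and the exact masking arithmetic are assumptions, not the library's implementation:

```python
from typing import Optional, Tuple

import torch


def preprocess_input_sketch(
    image_rgb: Optional[torch.Tensor],
    fg_probability: Optional[torch.Tensor],
    depth_map: Optional[torch.Tensor],
    mask_images: bool,
    mask_depths: bool,
    mask_threshold: float,
    bg_color: Tuple[float, float, float],
):
    if image_rgb is not None and image_rgb.ndim != 4:
        raise ValueError(
            "Model received unbatched inputs. "
            "Perhaps they came from a FrameData which had not been collated."
        )
    fg_mask = None
    if fg_probability is not None:
        fg_mask = fg_probability >= mask_threshold  # binarize at the threshold
    if mask_images and fg_mask is not None and image_rgb is not None:
        # bg_color is (0, 0, 0) in the tests, so plain multiplication suffices here.
        image_rgb = image_rgb * fg_mask
    if mask_depths and fg_mask is not None and depth_map is not None:
        depth_map = depth_map * fg_mask
    return image_rgb, fg_mask, depth_map
```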
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import torch

from pytorch3d.implicitron.models.utils import preprocess_input, weighted_sum_losses


class TestUtils(unittest.TestCase):
    def test_prepare_inputs_wrong_num_dim(self):
        img = torch.randn(3, 3, 3)
        with self.assertRaises(ValueError) as context:
            img, fg_prob, depth_map = preprocess_input(
                img, None, None, True, True, 0.5, (0.0, 0.0, 0.0)
            )
        # The message check lives outside the `with` block: statements after
        # the raising call would never execute inside it.
        self.assertEqual(
            "Model received unbatched inputs. "
            "Perhaps they came from a FrameData which had not been collated.",
            str(context.exception),
        )

    def test_prepare_inputs_mask_image_true(self):
        batch, channels, height, width = 2, 3, 10, 10
        img = torch.ones(batch, channels, height, width)
        # Create a mask covering the lower triangular part of the matrix
        fg_prob = torch.tril(torch.ones(batch, 1, height, width)) * 0.3

        out_img, out_fg_prob, out_depth_map = preprocess_input(
            img, fg_prob, None, True, False, 0.3, (0.0, 0.0, 0.0)
        )

        self.assertTrue(torch.equal(out_img, torch.tril(img)))
        self.assertTrue(torch.equal(out_fg_prob, fg_prob >= 0.3))
        self.assertIsNone(out_depth_map)

    def test_prepare_inputs_mask_depth_true(self):
        batch, channels, height, width = 2, 3, 10, 10
        img = torch.ones(batch, channels, height, width)
        depth_map = torch.randn(batch, channels, height, width)
        # Create a mask covering the lower triangular part of the matrix
        fg_prob = torch.tril(torch.ones(batch, 1, height, width)) * 0.3

        out_img, out_fg_prob, out_depth_map = preprocess_input(
            img, fg_prob, depth_map, False, True, 0.3, (0.0, 0.0, 0.0)
        )

        self.assertTrue(torch.equal(out_img, img))
        self.assertTrue(torch.equal(out_fg_prob, fg_prob >= 0.3))
        self.assertTrue(torch.equal(out_depth_map, torch.tril(depth_map)))

    def test_weighted_sum_losses(self):
        preds = {"a": torch.tensor(2), "b": torch.tensor(2)}
        weights = {"a": 2.0, "b": 0.0}
        loss = weighted_sum_losses(preds, weights)
        self.assertEqual(loss, 4.0)

    def test_weighted_sum_losses_raise_warning(self):
        preds = {"a": torch.tensor(2), "b": torch.tensor(2)}
        weights = {"c": 2.0, "d": 2.0}
        self.assertIsNone(weighted_sum_losses(preds, weights))
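The two tests above fix the behavior of `weighted_sum_losses`: it sums `weight * preds[name]` over the loss names present in both dicts and returns `None` when no weighted loss can be formed. A sketch matching that contract (the warning text is an assumption; the real function may differ in details):

```python
import warnings
from typing import Dict, Optional

import torch


def weighted_sum_losses_sketch(
    preds: Dict[str, torch.Tensor], weights: Dict[str, float]
) -> Optional[torch.Tensor]:
    # Keep only the loss names that appear in both dicts.
    losses = [
        preds[name] * weight for name, weight in weights.items() if name in preds
    ]
    if not losses:
        warnings.warn("No losses with matching weights found.")  # assumed message
        return None
    return sum(losses)
```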