diff --git a/projects/implicitron_trainer/README.md b/projects/implicitron_trainer/README.md index 232d697b..f8f875eb 100644 --- a/projects/implicitron_trainer/README.md +++ b/projects/implicitron_trainer/README.md @@ -248,7 +248,7 @@ The main object for this trainer loop is `Experiment`. It has four top-level rep * `data_source`: This is a `DataSourceBase` which defaults to `ImplicitronDataSource`. It constructs the data sets and dataloaders. * `model_factory`: This is a `ModelFactoryBase` which defaults to `ImplicitronModelFactory`. -It constructs the model, which is usually an instance of implicitron's main `GenericModel` class, and can load its weights from a checkpoint. +It constructs the model, which is usually an instance of `OverfitModel` (for NeRF-style training that overfits to a single scene) or `GenericModel` (which can generalize to multiple scenes via NeRFormer-style conditioning on other scene views), and can load its weights from a checkpoint. * `optimizer_factory`: This is an `OptimizerFactoryBase` which defaults to `ImplicitronOptimizerFactory`. It constructs the optimizer and can load its weights from a checkpoint. * `training_loop`: This is a `TrainingLoopBase` which defaults to `ImplicitronTrainingLoop` and defines the main training loop. @@ -292,6 +292,43 @@ model_GenericModel_args: GenericModel ╘== ReductionFeatureAggregator ``` +Here is the class structure of OverfitModel: + +``` +model_OverfitModel_args: OverfitModel +└-- raysampler_*_args: RaySampler + ╘== AdaptiveRaySampler + ╘== NearFarRaySampler +└-- renderer_*_args: BaseRenderer + ╘== MultiPassEmissionAbsorptionRenderer + ╘== LSTMRenderer + ╘== SignedDistanceFunctionRenderer + └-- ray_tracer_args: RayTracing + └-- ray_normal_coloring_network_args: RayNormalColoringNetwork +└-- implicit_function_*_args: ImplicitFunctionBase + ╘== NeuralRadianceFieldImplicitFunction + ╘== SRNImplicitFunction + └-- raymarch_function_args: SRNRaymarchFunction + └-- pixel_generator_args: SRNPixelGenerator + ╘== SRNHyperNetImplicitFunction + └-- hypernet_args: SRNRaymarchHyperNet + └-- pixel_generator_args: SRNPixelGenerator + ╘== IdrFeatureField +└-- coarse_implicit_function_*_args: ImplicitFunctionBase + ╘== NeuralRadianceFieldImplicitFunction + ╘== SRNImplicitFunction + └-- raymarch_function_args: SRNRaymarchFunction + └-- pixel_generator_args: SRNPixelGenerator + ╘== SRNHyperNetImplicitFunction + └-- hypernet_args: SRNRaymarchHyperNet + └-- pixel_generator_args: SRNPixelGenerator + ╘== IdrFeatureField +``` + +OverfitModel was introduced as a simpler class that disentangles NeRF-style models, which overfit a single scene, +from the more general GenericModel. + + Please look at the annotations of the respective classes or functions for the lists of hyperparameters. `tests/experiment.yaml` shows every possible option if you have no user-defined classes.
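For reference, here is a minimal sketch (not part of this diff) of constructing the newly registered `OverfitModel` directly in Python with its default plugins, assuming `expand_args_fields` applies to `OverfitModel` the same way it does to other Implicitron configurables:

```python
# Hedged sketch: build OverfitModel outside the trainer with its default plugins
# (AdaptiveRaySampler, MultiPassEmissionAbsorptionRenderer,
# NeuralRadianceFieldImplicitFunction). Field names follow the tree above.
from pytorch3d.implicitron.models.overfit_model import OverfitModel
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(OverfitModel)  # populate the *_args fields from the registry
model = OverfitModel(
    render_image_width=400,
    render_image_height=400,
    implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
)
print(type(model.implicit_function).__name__)
```

In the trainer, the equivalent selection is done through the config node `model_factory_ImplicitronModelFactory_args.model_class_type: OverfitModel`, as in the `overfit_*.yaml` configs below.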
diff --git a/projects/implicitron_trainer/configs/overfit_base.yaml b/projects/implicitron_trainer/configs/overfit_base.yaml new file mode 100644 index 00000000..d5cc0ccc --- /dev/null +++ b/projects/implicitron_trainer/configs/overfit_base.yaml @@ -0,0 +1,79 @@ +defaults: +- default_config +- _self_ +exp_dir: ./data/exps/overfit_base/ +training_loop_ImplicitronTrainingLoop_args: + visdom_port: 8097 + visualize_interval: 0 + max_epochs: 1000 +data_source_ImplicitronDataSource_args: + data_loader_map_provider_class_type: SequenceDataLoaderMapProvider + dataset_map_provider_class_type: JsonIndexDatasetMapProvider + data_loader_map_provider_SequenceDataLoaderMapProvider_args: + dataset_length_train: 1000 + dataset_length_val: 1 + num_workers: 8 + dataset_map_provider_JsonIndexDatasetMapProvider_args: + dataset_root: ${oc.env:CO3D_DATASET_ROOT} + n_frames_per_sequence: -1 + test_on_train: true + test_restrict_sequence_id: 0 + dataset_JsonIndexDataset_args: + load_point_clouds: false + mask_depths: false + mask_images: false +model_factory_ImplicitronModelFactory_args: + model_class_type: "OverfitModel" + model_OverfitModel_args: + loss_weights: + loss_mask_bce: 1.0 + loss_prev_stage_mask_bce: 1.0 + loss_autodecoder_norm: 0.01 + loss_rgb_mse: 1.0 + loss_prev_stage_rgb_mse: 1.0 + output_rasterized_mc: false + chunk_size_grid: 102400 + render_image_height: 400 + render_image_width: 400 + share_implicit_function_across_passes: false + implicit_function_class_type: "NeuralRadianceFieldImplicitFunction" + implicit_function_NeuralRadianceFieldImplicitFunction_args: + n_harmonic_functions_xyz: 10 + n_harmonic_functions_dir: 4 + n_hidden_neurons_xyz: 256 + n_hidden_neurons_dir: 128 + n_layers_xyz: 8 + append_xyz: + - 5 + coarse_implicit_function_class_type: "NeuralRadianceFieldImplicitFunction" + coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args: + n_harmonic_functions_xyz: 10 + n_harmonic_functions_dir: 4 + n_hidden_neurons_xyz: 256 + n_hidden_neurons_dir: 128 + n_layers_xyz: 8 + append_xyz: + - 5 + raysampler_AdaptiveRaySampler_args: + n_rays_per_image_sampled_from_mask: 1024 + scene_extent: 8.0 + n_pts_per_ray_training: 64 + n_pts_per_ray_evaluation: 64 + stratified_point_sampling_training: true + stratified_point_sampling_evaluation: false + renderer_MultiPassEmissionAbsorptionRenderer_args: + n_pts_per_ray_fine_training: 64 + n_pts_per_ray_fine_evaluation: 64 + append_coarse_samples_to_fine: true + density_noise_std_train: 1.0 +optimizer_factory_ImplicitronOptimizerFactory_args: + breed: Adam + weight_decay: 0.0 + lr_policy: MultiStepLR + multistep_lr_milestones: [] + lr: 0.0005 + gamma: 0.1 + momentum: 0.9 + betas: + - 0.9 + - 0.999 diff --git a/projects/implicitron_trainer/configs/overfit_singleseq_base.yaml b/projects/implicitron_trainer/configs/overfit_singleseq_base.yaml new file mode 100644 index 00000000..0349fd27 --- /dev/null +++ b/projects/implicitron_trainer/configs/overfit_singleseq_base.yaml @@ -0,0 +1,42 @@ +defaults: +- overfit_base +- _self_ +data_source_ImplicitronDataSource_args: + data_loader_map_provider_SequenceDataLoaderMapProvider_args: + batch_size: 1 + dataset_length_train: 1000 + dataset_length_val: 1 + num_workers: 8 + dataset_map_provider_JsonIndexDatasetMapProvider_args: + assert_single_seq: true + n_frames_per_sequence: -1 + test_restrict_sequence_id: 0 + test_on_train: false +model_factory_ImplicitronModelFactory_args: + model_class_type: "OverfitModel" + model_OverfitModel_args: + render_image_height: 800 + render_image_width: 800 + log_vars: + - 
loss_rgb_psnr_fg + - loss_rgb_psnr + - loss_eikonal + - loss_prev_stage_rgb_psnr + - loss_mask_bce + - loss_prev_stage_mask_bce + - loss_rgb_mse + - loss_prev_stage_rgb_mse + - loss_depth_abs + - loss_depth_abs_fg + - loss_kl + - loss_mask_neg_iou + - objective + - epoch + - sec/it +optimizer_factory_ImplicitronOptimizerFactory_args: + lr: 0.0005 + multistep_lr_milestones: + - 200 + - 300 +training_loop_ImplicitronTrainingLoop_args: + max_epochs: 400 diff --git a/projects/implicitron_trainer/configs/overfit_singleseq_nerf_blender.yaml b/projects/implicitron_trainer/configs/overfit_singleseq_nerf_blender.yaml new file mode 100644 index 00000000..c61d759f --- /dev/null +++ b/projects/implicitron_trainer/configs/overfit_singleseq_nerf_blender.yaml @@ -0,0 +1,56 @@ +defaults: +- overfit_singleseq_base +- _self_ +exp_dir: "./data/overfit_nerf_blender_repro/${oc.env:BLENDER_SINGLESEQ_CLASS}" +data_source_ImplicitronDataSource_args: + data_loader_map_provider_SequenceDataLoaderMapProvider_args: + dataset_length_train: 100 + dataset_map_provider_class_type: BlenderDatasetMapProvider + dataset_map_provider_BlenderDatasetMapProvider_args: + base_dir: ${oc.env:BLENDER_DATASET_ROOT}/${oc.env:BLENDER_SINGLESEQ_CLASS} + n_known_frames_for_test: null + object_name: ${oc.env:BLENDER_SINGLESEQ_CLASS} + path_manager_factory_class_type: PathManagerFactory + path_manager_factory_PathManagerFactory_args: + silence_logs: true + +model_factory_ImplicitronModelFactory_args: + model_class_type: "OverfitModel" + model_OverfitModel_args: + mask_images: false + raysampler_class_type: AdaptiveRaySampler + raysampler_AdaptiveRaySampler_args: + n_pts_per_ray_training: 64 + n_pts_per_ray_evaluation: 64 + n_rays_per_image_sampled_from_mask: 4096 + stratified_point_sampling_training: true + stratified_point_sampling_evaluation: false + scene_extent: 2.0 + scene_center: + - 0.0 + - 0.0 + - 0.0 + renderer_MultiPassEmissionAbsorptionRenderer_args: + density_noise_std_train: 0.0 + n_pts_per_ray_fine_training: 128 + n_pts_per_ray_fine_evaluation: 128 + raymarcher_EmissionAbsorptionRaymarcher_args: + blend_output: false + loss_weights: + loss_rgb_mse: 1.0 + loss_prev_stage_rgb_mse: 1.0 + loss_mask_bce: 0.0 + loss_prev_stage_mask_bce: 0.0 + loss_autodecoder_norm: 0.00 + +optimizer_factory_ImplicitronOptimizerFactory_args: + exponential_lr_step_size: 3001 + lr_policy: LinearExponential + linear_exponential_lr_milestone: 200 + +training_loop_ImplicitronTrainingLoop_args: + max_epochs: 6000 + metric_print_interval: 10 + store_checkpoints_purge: 3 + test_when_finished: true + validation_interval: 100 diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py index cede59a3..797660c8 100755 --- a/projects/implicitron_trainer/experiment.py +++ b/projects/implicitron_trainer/experiment.py @@ -59,7 +59,7 @@ from pytorch3d.implicitron.dataset.data_source import ( DataSourceBase, ImplicitronDataSource, ) -from pytorch3d.implicitron.models.generic_model import ImplicitronModelBase +from pytorch3d.implicitron.models.base_model import ImplicitronModelBase from pytorch3d.implicitron.models.renderer.multipass_ea import ( MultiPassEmissionAbsorptionRenderer, diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index e8cdba05..f2df83e5 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -561,6 +561,623 @@ model_factory_ImplicitronModelFactory_args: use_xavier_init: true 
view_metrics_ViewMetrics_args: {} regularization_metrics_RegularizationMetrics_args: {} + model_OverfitModel_args: + log_vars: + - loss_rgb_psnr_fg + - loss_rgb_psnr + - loss_rgb_mse + - loss_rgb_huber + - loss_depth_abs + - loss_depth_abs_fg + - loss_mask_neg_iou + - loss_mask_bce + - loss_mask_beta_prior + - loss_eikonal + - loss_density_tv + - loss_depth_neg_penalty + - loss_autodecoder_norm + - loss_prev_stage_rgb_mse + - loss_prev_stage_rgb_psnr_fg + - loss_prev_stage_rgb_psnr + - loss_prev_stage_mask_bce + - objective + - epoch + - sec/it + mask_images: true + mask_depths: true + render_image_width: 400 + render_image_height: 400 + mask_threshold: 0.5 + output_rasterized_mc: false + bg_color: + - 0.0 + - 0.0 + - 0.0 + chunk_size_grid: 4096 + render_features_dimensions: 3 + tqdm_trigger_threshold: 16 + n_train_target_views: 1 + sampling_mode_training: mask_sample + sampling_mode_evaluation: full_grid + global_encoder_class_type: null + raysampler_class_type: AdaptiveRaySampler + renderer_class_type: MultiPassEmissionAbsorptionRenderer + share_implicit_function_across_passes: false + implicit_function_class_type: NeuralRadianceFieldImplicitFunction + coarse_implicit_function_class_type: null + view_metrics_class_type: ViewMetrics + regularization_metrics_class_type: RegularizationMetrics + loss_weights: + loss_rgb_mse: 1.0 + loss_prev_stage_rgb_mse: 1.0 + loss_mask_bce: 0.0 + loss_prev_stage_mask_bce: 0.0 + global_encoder_HarmonicTimeEncoder_args: + n_harmonic_functions: 10 + append_input: true + time_divisor: 1.0 + global_encoder_SequenceAutodecoder_args: + autodecoder_args: + encoding_dim: 0 + n_instances: 1 + init_scale: 1.0 + ignore_input: false + raysampler_AdaptiveRaySampler_args: + n_pts_per_ray_training: 64 + n_pts_per_ray_evaluation: 64 + n_rays_per_image_sampled_from_mask: 1024 + n_rays_total_training: null + stratified_point_sampling_training: true + stratified_point_sampling_evaluation: false + scene_extent: 8.0 + scene_center: + - 0.0 + - 0.0 + - 0.0 + raysampler_NearFarRaySampler_args: + n_pts_per_ray_training: 64 + n_pts_per_ray_evaluation: 64 + n_rays_per_image_sampled_from_mask: 1024 + n_rays_total_training: null + stratified_point_sampling_training: true + stratified_point_sampling_evaluation: false + min_depth: 0.1 + max_depth: 8.0 + renderer_LSTMRenderer_args: + num_raymarch_steps: 10 + init_depth: 17.0 + init_depth_noise_std: 0.0005 + hidden_size: 16 + n_feature_channels: 256 + bg_color: null + verbose: false + renderer_MultiPassEmissionAbsorptionRenderer_args: + raymarcher_class_type: EmissionAbsorptionRaymarcher + n_pts_per_ray_fine_training: 64 + n_pts_per_ray_fine_evaluation: 64 + stratified_sampling_coarse_training: true + stratified_sampling_coarse_evaluation: false + append_coarse_samples_to_fine: true + density_noise_std_train: 0.0 + return_weights: false + raymarcher_CumsumRaymarcher_args: + surface_thickness: 1 + bg_color: + - 0.0 + replicate_last_interval: false + background_opacity: 0.0 + density_relu: true + blend_output: false + raymarcher_EmissionAbsorptionRaymarcher_args: + surface_thickness: 1 + bg_color: + - 0.0 + replicate_last_interval: false + background_opacity: 10000000000.0 + density_relu: true + blend_output: false + renderer_SignedDistanceFunctionRenderer_args: + ray_normal_coloring_network_args: + feature_vector_size: 3 + mode: idr + d_in: 9 + d_out: 3 + dims: + - 512 + - 512 + - 512 + - 512 + weight_norm: true + n_harmonic_functions_dir: 0 + pooled_feature_dim: 0 + bg_color: + - 0.0 + soft_mask_alpha: 50.0 + ray_tracer_args: + 
sdf_threshold: 5.0e-05 + line_search_step: 0.5 + line_step_iters: 1 + sphere_tracing_iters: 10 + n_steps: 100 + n_secant_steps: 8 + implicit_function_IdrFeatureField_args: + d_in: 3 + d_out: 1 + dims: + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + geometric_init: true + bias: 1.0 + skip_in: [] + weight_norm: true + n_harmonic_functions_xyz: 0 + pooled_feature_dim: 0 + implicit_function_NeRFormerImplicitFunction_args: + n_harmonic_functions_xyz: 10 + n_harmonic_functions_dir: 4 + n_hidden_neurons_dir: 128 + input_xyz: true + xyz_ray_dir_in_camera_coords: false + transformer_dim_down_factor: 2.0 + n_hidden_neurons_xyz: 80 + n_layers_xyz: 2 + append_xyz: + - 1 + implicit_function_NeuralRadianceFieldImplicitFunction_args: + n_harmonic_functions_xyz: 10 + n_harmonic_functions_dir: 4 + n_hidden_neurons_dir: 128 + input_xyz: true + xyz_ray_dir_in_camera_coords: false + transformer_dim_down_factor: 1.0 + n_hidden_neurons_xyz: 256 + n_layers_xyz: 8 + append_xyz: + - 5 + implicit_function_SRNHyperNetImplicitFunction_args: + latent_dim_hypernet: 0 + hypernet_args: + n_harmonic_functions: 3 + n_hidden_units: 256 + n_layers: 2 + n_hidden_units_hypernet: 256 + n_layers_hypernet: 1 + in_features: 3 + out_features: 256 + xyz_in_camera_coords: false + pixel_generator_args: + n_harmonic_functions: 4 + n_hidden_units: 256 + n_hidden_units_color: 128 + n_layers: 2 + in_features: 256 + out_features: 3 + ray_dir_in_camera_coords: false + implicit_function_SRNImplicitFunction_args: + raymarch_function_args: + n_harmonic_functions: 3 + n_hidden_units: 256 + n_layers: 2 + in_features: 3 + out_features: 256 + xyz_in_camera_coords: false + raymarch_function: null + pixel_generator_args: + n_harmonic_functions: 4 + n_hidden_units: 256 + n_hidden_units_color: 128 + n_layers: 2 + in_features: 256 + out_features: 3 + ray_dir_in_camera_coords: false + implicit_function_VoxelGridImplicitFunction_args: + harmonic_embedder_xyz_density_args: + n_harmonic_functions: 6 + omega_0: 1.0 + logspace: true + append_input: true + harmonic_embedder_xyz_color_args: + n_harmonic_functions: 6 + omega_0: 1.0 + logspace: true + append_input: true + harmonic_embedder_dir_color_args: + n_harmonic_functions: 6 + omega_0: 1.0 + logspace: true + append_input: true + decoder_density_class_type: MLPDecoder + decoder_color_class_type: MLPDecoder + use_multiple_streams: true + xyz_ray_dir_in_camera_coords: false + scaffold_calculating_epochs: [] + scaffold_resolution: + - 128 + - 128 + - 128 + scaffold_empty_space_threshold: 0.001 + scaffold_occupancy_chunk_size: -1 + scaffold_max_pool_kernel_size: 3 + scaffold_filter_points: true + volume_cropping_epochs: [] + voxel_grid_density_args: + voxel_grid_class_type: FullResolutionVoxelGrid + extents: + - 2.0 + - 2.0 + - 2.0 + translation: + - 0.0 + - 0.0 + - 0.0 + init_std: 0.1 + init_mean: 0.0 + hold_voxel_grid_as_parameters: true + param_groups: {} + voxel_grid_CPFactorizedVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + n_components: 24 + basis_matrix: true + voxel_grid_FullResolutionVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + voxel_grid_VMFactorizedVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + n_components: null + distribution_of_components: null + basis_matrix: true + voxel_grid_color_args: + 
voxel_grid_class_type: FullResolutionVoxelGrid + extents: + - 2.0 + - 2.0 + - 2.0 + translation: + - 0.0 + - 0.0 + - 0.0 + init_std: 0.1 + init_mean: 0.0 + hold_voxel_grid_as_parameters: true + param_groups: {} + voxel_grid_CPFactorizedVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + n_components: 24 + basis_matrix: true + voxel_grid_FullResolutionVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + voxel_grid_VMFactorizedVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + n_components: null + distribution_of_components: null + basis_matrix: true + decoder_density_ElementwiseDecoder_args: + scale: 1.0 + shift: 0.0 + operation: IDENTITY + decoder_density_MLPDecoder_args: + param_groups: {} + network_args: + n_layers: 8 + output_dim: 256 + skip_dim: 39 + hidden_dim: 256 + input_skips: + - 5 + skip_affine_trans: false + last_layer_bias_init: null + last_activation: RELU + use_xavier_init: true + decoder_color_ElementwiseDecoder_args: + scale: 1.0 + shift: 0.0 + operation: IDENTITY + decoder_color_MLPDecoder_args: + param_groups: {} + network_args: + n_layers: 8 + output_dim: 256 + skip_dim: 39 + hidden_dim: 256 + input_skips: + - 5 + skip_affine_trans: false + last_layer_bias_init: null + last_activation: RELU + use_xavier_init: true + coarse_implicit_function_IdrFeatureField_args: + d_in: 3 + d_out: 1 + dims: + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + geometric_init: true + bias: 1.0 + skip_in: [] + weight_norm: true + n_harmonic_functions_xyz: 0 + pooled_feature_dim: 0 + coarse_implicit_function_NeRFormerImplicitFunction_args: + n_harmonic_functions_xyz: 10 + n_harmonic_functions_dir: 4 + n_hidden_neurons_dir: 128 + input_xyz: true + xyz_ray_dir_in_camera_coords: false + transformer_dim_down_factor: 2.0 + n_hidden_neurons_xyz: 80 + n_layers_xyz: 2 + append_xyz: + - 1 + coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args: + n_harmonic_functions_xyz: 10 + n_harmonic_functions_dir: 4 + n_hidden_neurons_dir: 128 + input_xyz: true + xyz_ray_dir_in_camera_coords: false + transformer_dim_down_factor: 1.0 + n_hidden_neurons_xyz: 256 + n_layers_xyz: 8 + append_xyz: + - 5 + coarse_implicit_function_SRNHyperNetImplicitFunction_args: + latent_dim_hypernet: 0 + hypernet_args: + n_harmonic_functions: 3 + n_hidden_units: 256 + n_layers: 2 + n_hidden_units_hypernet: 256 + n_layers_hypernet: 1 + in_features: 3 + out_features: 256 + xyz_in_camera_coords: false + pixel_generator_args: + n_harmonic_functions: 4 + n_hidden_units: 256 + n_hidden_units_color: 128 + n_layers: 2 + in_features: 256 + out_features: 3 + ray_dir_in_camera_coords: false + coarse_implicit_function_SRNImplicitFunction_args: + raymarch_function_args: + n_harmonic_functions: 3 + n_hidden_units: 256 + n_layers: 2 + in_features: 3 + out_features: 256 + xyz_in_camera_coords: false + raymarch_function: null + pixel_generator_args: + n_harmonic_functions: 4 + n_hidden_units: 256 + n_hidden_units_color: 128 + n_layers: 2 + in_features: 256 + out_features: 3 + ray_dir_in_camera_coords: false + coarse_implicit_function_VoxelGridImplicitFunction_args: + harmonic_embedder_xyz_density_args: + n_harmonic_functions: 6 + omega_0: 1.0 + logspace: true + append_input: true + harmonic_embedder_xyz_color_args: + n_harmonic_functions: 6 + omega_0: 1.0 + 
logspace: true + append_input: true + harmonic_embedder_dir_color_args: + n_harmonic_functions: 6 + omega_0: 1.0 + logspace: true + append_input: true + decoder_density_class_type: MLPDecoder + decoder_color_class_type: MLPDecoder + use_multiple_streams: true + xyz_ray_dir_in_camera_coords: false + scaffold_calculating_epochs: [] + scaffold_resolution: + - 128 + - 128 + - 128 + scaffold_empty_space_threshold: 0.001 + scaffold_occupancy_chunk_size: -1 + scaffold_max_pool_kernel_size: 3 + scaffold_filter_points: true + volume_cropping_epochs: [] + voxel_grid_density_args: + voxel_grid_class_type: FullResolutionVoxelGrid + extents: + - 2.0 + - 2.0 + - 2.0 + translation: + - 0.0 + - 0.0 + - 0.0 + init_std: 0.1 + init_mean: 0.0 + hold_voxel_grid_as_parameters: true + param_groups: {} + voxel_grid_CPFactorizedVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + n_components: 24 + basis_matrix: true + voxel_grid_FullResolutionVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + voxel_grid_VMFactorizedVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + n_components: null + distribution_of_components: null + basis_matrix: true + voxel_grid_color_args: + voxel_grid_class_type: FullResolutionVoxelGrid + extents: + - 2.0 + - 2.0 + - 2.0 + translation: + - 0.0 + - 0.0 + - 0.0 + init_std: 0.1 + init_mean: 0.0 + hold_voxel_grid_as_parameters: true + param_groups: {} + voxel_grid_CPFactorizedVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + n_components: 24 + basis_matrix: true + voxel_grid_FullResolutionVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + voxel_grid_VMFactorizedVoxelGrid_args: + align_corners: true + padding: zeros + mode: bilinear + n_features: 1 + resolution_changes: + 0: + - 128 + - 128 + - 128 + n_components: null + distribution_of_components: null + basis_matrix: true + decoder_density_ElementwiseDecoder_args: + scale: 1.0 + shift: 0.0 + operation: IDENTITY + decoder_density_MLPDecoder_args: + param_groups: {} + network_args: + n_layers: 8 + output_dim: 256 + skip_dim: 39 + hidden_dim: 256 + input_skips: + - 5 + skip_affine_trans: false + last_layer_bias_init: null + last_activation: RELU + use_xavier_init: true + decoder_color_ElementwiseDecoder_args: + scale: 1.0 + shift: 0.0 + operation: IDENTITY + decoder_color_MLPDecoder_args: + param_groups: {} + network_args: + n_layers: 8 + output_dim: 256 + skip_dim: 39 + hidden_dim: 256 + input_skips: + - 5 + skip_affine_trans: false + last_layer_bias_init: null + last_activation: RELU + use_xavier_init: true + view_metrics_ViewMetrics_args: {} + regularization_metrics_RegularizationMetrics_args: {} optimizer_factory_ImplicitronOptimizerFactory_args: betas: - 0.9 diff --git a/projects/implicitron_trainer/tests/test_experiment.py b/projects/implicitron_trainer/tests/test_experiment.py index d16c788d..a07eb8be 100644 --- a/projects/implicitron_trainer/tests/test_experiment.py +++ b/projects/implicitron_trainer/tests/test_experiment.py @@ -141,7 +141,11 @@ class TestExperiment(unittest.TestCase): # Check that all the pre-prepared configs are valid. 
config_files = [] - for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"): + for pattern in ( + "repro_singleseq*.yaml", + "repro_multiseq*.yaml", + "overfit_singleseq*.yaml", + ): config_files.extend( [ f diff --git a/pytorch3d/implicitron/models/__init__.py b/pytorch3d/implicitron/models/__init__.py index 2e41cd71..5a3ab83f 100644 --- a/pytorch3d/implicitron/models/__init__.py +++ b/pytorch3d/implicitron/models/__init__.py @@ -3,3 +3,8 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +# Allows the models to be registered +# see: pytorch3d.implicitron.tools.config.registry:register +from pytorch3d.implicitron.models.generic_model import GenericModel +from pytorch3d.implicitron.models.overfit_model import OverfitModel diff --git a/pytorch3d/implicitron/models/base_model.py b/pytorch3d/implicitron/models/base_model.py index 56efa69c..bd48bf7f 100644 --- a/pytorch3d/implicitron/models/base_model.py +++ b/pytorch3d/implicitron/models/base_model.py @@ -8,11 +8,11 @@ from dataclasses import dataclass, field from typing import Any, Dict, List, Optional import torch + +from pytorch3d.implicitron.models.renderer.base import EvaluationMode from pytorch3d.implicitron.tools.config import ReplaceableBase from pytorch3d.renderer.cameras import CamerasBase -from .renderer.base import EvaluationMode - @dataclass class ImplicitronRender: diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index d20be5bf..b903814f 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -9,14 +9,11 @@ # which are part of implicitron. They ensure that the registry is prepopulated.
import logging -import warnings from dataclasses import field from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch -import tqdm from omegaconf import DictConfig -from pytorch3d.common.compat import prod from pytorch3d.implicitron.models.base_model import ( ImplicitronModelBase, @@ -33,11 +30,9 @@ from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( ) from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( # noqa NeRFormerImplicitFunction, - NeuralRadianceFieldImplicitFunction, ) from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( # noqa SRNHyperNetImplicitFunction, - SRNImplicitFunction, ) from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import ( # noqa VoxelGridImplicitFunction, @@ -63,8 +58,16 @@ from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase from pytorch3d.implicitron.models.renderer.sdf_renderer import ( # noqa SignedDistanceFunctionRenderer, ) + +from pytorch3d.implicitron.models.utils import ( + apply_chunked, + chunk_generator, + log_loss_weights, + preprocess_input, + weighted_sum_losses, +) from pytorch3d.implicitron.models.view_pooler.view_pooler import ViewPooler -from pytorch3d.implicitron.tools import image_utils, vis_utils +from pytorch3d.implicitron.tools import vis_utils from pytorch3d.implicitron.tools.config import ( expand_args_fields, registry, @@ -72,7 +75,6 @@ from pytorch3d.implicitron.tools.config import ( ) from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle -from pytorch3d.implicitron.tools.utils import cat_dataclass from pytorch3d.renderer import utils as rend_utils from pytorch3d.renderer.cameras import CamerasBase @@ -323,7 +325,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 self._implicit_functions = self._construct_implicit_functions() - self.log_loss_weights() + log_loss_weights(self.loss_weights, logger) def forward( self, @@ -367,8 +369,14 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 preds: A dictionary containing all outputs of the forward pass including the rendered images, depths, masks, losses and other metrics. """ - image_rgb, fg_probability, depth_map = self._preprocess_input( - image_rgb, fg_probability, depth_map + image_rgb, fg_probability, depth_map = preprocess_input( + image_rgb, + fg_probability, + depth_map, + self.mask_images, + self.mask_depths, + self.mask_threshold, + self.bg_color, ) # Obtain the batch size from the camera as this is the only required input. 
@@ -453,12 +461,12 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 for func in self._implicit_functions: func.bind_args(**custom_args) - chunked_renderer_inputs = {} + inputs_to_be_chunked = {} if fg_probability is not None and self.renderer.requires_object_mask(): sampled_fb_prob = rend_utils.ndc_grid_sample( fg_probability[:n_targets], ray_bundle.xys, mode="nearest" ) - chunked_renderer_inputs["object_mask"] = sampled_fb_prob > 0.5 + inputs_to_be_chunked["object_mask"] = sampled_fb_prob > 0.5 # (5)-(6) Implicit function evaluation and Rendering rendered = self._render( @@ -466,7 +474,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 sampling_mode=sampling_mode, evaluation_mode=evaluation_mode, implicit_functions=self._implicit_functions, - chunked_inputs=chunked_renderer_inputs, + inputs_to_be_chunked=inputs_to_be_chunked, ) # Unbind the custom arguments to prevent pytorch from storing @@ -530,30 +538,18 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 raise AssertionError("Unreachable state") # (7) Compute losses - # finally get the optimization objective using self.loss_weights objective = self._get_objective(preds) if objective is not None: preds["objective"] = objective return preds - def _get_objective(self, preds) -> Optional[torch.Tensor]: + def _get_objective(self, preds: Dict[str, torch.Tensor]) -> Optional[torch.Tensor]: """ A helper function to compute the overall loss as the dot product of individual loss functions with the corresponding weights. """ - losses_weighted = [ - preds[k] * float(w) - for k, w in self.loss_weights.items() - if (k in preds and w != 0.0) - ] - if len(losses_weighted) == 0: - warnings.warn("No main objective found.") - return None - loss = sum(losses_weighted) - assert torch.is_tensor(loss) - # pyre-fixme[7]: Expected `Optional[Tensor]` but got `int`. - return loss + return weighted_sum_losses(preds, self.loss_weights) def visualize( self, @@ -585,7 +581,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 self, *, ray_bundle: ImplicitronRayBundle, - chunked_inputs: Dict[str, torch.Tensor], + inputs_to_be_chunked: Dict[str, torch.Tensor], sampling_mode: RenderSamplingMode, **kwargs, ) -> RendererOutput: @@ -593,7 +589,7 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 Args: ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the sampled rendering rays. - chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g. + inputs_to_be_chunked: A collection of tensor of shape `(B, _, H, W)`. E.g. SignedDistanceFunctionRenderer requires "object_mask", shape (B, 1, H, W), the silhouette of the object in the image. When chunking, they are passed to the renderer as shape @@ -605,30 +601,27 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 An instance of RendererOutput """ if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0: - return _apply_chunked( + return apply_chunked( self.renderer, - _chunk_generator( + chunk_generator( self.chunk_size_grid, ray_bundle, - chunked_inputs, + inputs_to_be_chunked, self.tqdm_trigger_threshold, **kwargs, ), - lambda batch: _tensor_collator(batch, ray_bundle.lengths.shape[:-1]), + lambda batch: torch.cat(batch, dim=1).reshape( + *ray_bundle.lengths.shape[:-1], -1 + ), ) else: # pyre-fixme[29]: `BaseRenderer` is not a function. 
return self.renderer( ray_bundle=ray_bundle, - **chunked_inputs, + **inputs_to_be_chunked, **kwargs, ) - def _get_global_encoder_encoding_dim(self) -> int: - if self.global_encoder is None: - return 0 - return self.global_encoder.get_encoding_dim() - def _get_viewpooled_feature_dim(self) -> int: if self.view_pooler is None: return 0 @@ -720,30 +713,29 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 function(s) are initialized. """ extra_args = {} + global_encoder_dim = ( + 0 if self.global_encoder is None else self.global_encoder.get_encoding_dim() + ) + viewpooled_feature_dim = self._get_viewpooled_feature_dim() + if self.implicit_function_class_type in ( "NeuralRadianceFieldImplicitFunction", "NeRFormerImplicitFunction", ): - extra_args["latent_dim"] = ( - self._get_viewpooled_feature_dim() - + self._get_global_encoder_encoding_dim() - ) + extra_args["latent_dim"] = viewpooled_feature_dim + global_encoder_dim extra_args["color_dim"] = self.render_features_dimensions if self.implicit_function_class_type == "IdrFeatureField": extra_args["feature_vector_size"] = self.render_features_dimensions - extra_args["encoding_dim"] = self._get_global_encoder_encoding_dim() + extra_args["encoding_dim"] = global_encoder_dim if self.implicit_function_class_type == "SRNImplicitFunction": - extra_args["latent_dim"] = ( - self._get_viewpooled_feature_dim() - + self._get_global_encoder_encoding_dim() - ) + extra_args["latent_dim"] = viewpooled_feature_dim + global_encoder_dim # srn_hypernet preprocessing if self.implicit_function_class_type == "SRNHyperNetImplicitFunction": - extra_args["latent_dim"] = self._get_viewpooled_feature_dim() - extra_args["latent_dim_hypernet"] = self._get_global_encoder_encoding_dim() + extra_args["latent_dim"] = viewpooled_feature_dim + extra_args["latent_dim_hypernet"] = global_encoder_dim # check that for srn, srn_hypernet, idr we have self.num_passes=1 implicit_function_type = registry.get( @@ -770,147 +762,3 @@ class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 for _ in range(self.num_passes) ] return torch.nn.ModuleList(implicit_functions_list) - - def log_loss_weights(self) -> None: - """ - Print a table of the loss weights. - """ - loss_weights_message = ( - "-------\nloss_weights:\n" - + "\n".join(f"{k:40s}: {w:1.2e}" for k, w in self.loss_weights.items()) - + "-------" - ) - logger.info(loss_weights_message) - - def _preprocess_input( - self, - image_rgb: Optional[torch.Tensor], - fg_probability: Optional[torch.Tensor], - depth_map: Optional[torch.Tensor], - ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: - """ - Helper function to preprocess the input images and optional depth maps - to apply masking if required. - - Args: - image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images - corresponding to the source viewpoints from which features will be extracted - fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch - of foreground masks with values in [0, 1]. - depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps. - - Returns: - Modified image_rgb, fg_mask, depth_map - """ - if image_rgb is not None and image_rgb.ndim == 3: - # The FrameData object is used for both frames and batches of frames, - # and a user might get this error if those were confused. - # Perhaps a user has a FrameData `fd` representing a single frame and - # wrote something like `model(**fd)` instead of - # `model(**fd.collate([fd]))`. 
- raise ValueError( - "Model received unbatched inputs. " - + "Perhaps they came from a FrameData which had not been collated." - ) - - fg_mask = fg_probability - if fg_mask is not None and self.mask_threshold > 0.0: - # threshold masks - warnings.warn("Thresholding masks!") - fg_mask = (fg_mask >= self.mask_threshold).type_as(fg_mask) - - if self.mask_images and fg_mask is not None and image_rgb is not None: - # mask the image - warnings.warn("Masking images!") - image_rgb = image_utils.mask_background( - image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(self.bg_color) - ) - - if self.mask_depths and fg_mask is not None and depth_map is not None: - # mask the depths - assert ( - self.mask_threshold > 0.0 - ), "Depths should be masked only with thresholded masks" - warnings.warn("Masking depths!") - depth_map = depth_map * fg_mask - - return image_rgb, fg_mask, depth_map - - -def _apply_chunked(func, chunk_generator, tensor_collator): - """ - Helper function to apply a function on a sequence of - chunked inputs yielded by a generator and collate - the result. - """ - processed_chunks = [ - func(*chunk_args, **chunk_kwargs) - for chunk_args, chunk_kwargs in chunk_generator - ] - - return cat_dataclass(processed_chunks, tensor_collator) - - -def _tensor_collator(batch, new_dims) -> torch.Tensor: - """ - Helper function to reshape the batch to the desired shape - """ - return torch.cat(batch, dim=1).reshape(*new_dims, -1) - - -def _chunk_generator( - chunk_size: int, - ray_bundle: ImplicitronRayBundle, - chunked_inputs: Dict[str, torch.Tensor], - tqdm_trigger_threshold: int, - *args, - **kwargs, -): - """ - Helper function which yields chunks of rays from the - input ray_bundle, to be used when the number of rays is - large and will not fit in memory for rendering. - """ - ( - batch_size, - *spatial_dim, - n_pts_per_ray, - ) = ray_bundle.lengths.shape # B x ... 
x n_pts_per_ray - if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0: - raise ValueError( - f"chunk_size_grid ({chunk_size}) should be divisible " - f"by n_pts_per_ray ({n_pts_per_ray})" - ) - - n_rays = prod(spatial_dim) - # special handling for raytracing-based methods - n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size) - chunk_size_in_rays = -(-n_rays // n_chunks) - - iter = range(0, n_rays, chunk_size_in_rays) - if len(iter) >= tqdm_trigger_threshold: - iter = tqdm.tqdm(iter) - - def _safe_slice( - tensor: Optional[torch.Tensor], start_idx: int, end_idx: int - ) -> Any: - return tensor[start_idx:end_idx] if tensor is not None else None - - for start_idx in iter: - end_idx = min(start_idx + chunk_size_in_rays, n_rays) - ray_bundle_chunk = ImplicitronRayBundle( - origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx], - directions=ray_bundle.directions.reshape(batch_size, -1, 3)[ - :, start_idx:end_idx - ], - lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[ - :, start_idx:end_idx - ], - xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx], - camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx), - camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx), - ) - extra_args = kwargs.copy() - for k, v in chunked_inputs.items(): - extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx] - yield [ray_bundle_chunk, *args], extra_args diff --git a/pytorch3d/implicitron/models/overfit_model.py b/pytorch3d/implicitron/models/overfit_model.py new file mode 100644 index 00000000..b773f437 --- /dev/null +++ b/pytorch3d/implicitron/models/overfit_model.py @@ -0,0 +1,639 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +# Note: The #noqa comments below are for unused imports of pluggable implementations +# which are part of implicitron. They ensure that the registry is prepopulated. 
+ +import functools +import logging +from dataclasses import field +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union + +import torch +from omegaconf import DictConfig + +from pytorch3d.implicitron.models.base_model import ( + ImplicitronModelBase, + ImplicitronRender, +) +from pytorch3d.implicitron.models.global_encoder.global_encoder import GlobalEncoderBase +from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase +from pytorch3d.implicitron.models.metrics import ( + RegularizationMetricsBase, + ViewMetricsBase, +) + +from pytorch3d.implicitron.models.renderer.base import ( + BaseRenderer, + EvaluationMode, + ImplicitronRayBundle, + RendererOutput, + RenderSamplingMode, +) +from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase +from pytorch3d.implicitron.models.utils import ( + apply_chunked, + chunk_generator, + log_loss_weights, + preprocess_input, + weighted_sum_losses, +) +from pytorch3d.implicitron.tools import vis_utils +from pytorch3d.implicitron.tools.config import ( + expand_args_fields, + registry, + run_auto_creation, +) + +from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle +from pytorch3d.renderer import utils as rend_utils +from pytorch3d.renderer.cameras import CamerasBase + + +if TYPE_CHECKING: + from visdom import Visdom +logger = logging.getLogger(__name__) + +IMPLICIT_FUNCTION_ARGS_TO_REMOVE: List[str] = [ + "feature_vector_size", + "encoding_dim", + "latent_dim", + "color_dim", +] + + +@registry.register +class OverfitModel(ImplicitronModelBase): # pyre-ignore: 13 + """ + OverfitModel is a wrapper for the neural implicit + rendering and reconstruction pipeline which consists + of the following sequence of 4 steps: + + + (1) Ray Sampling + ------------------ + Rays are sampled from an image grid based on the target view(s). + │ + ▼ + (2) Implicit Function Evaluation + ------------------ + Evaluate the implicit function(s) at the sampled ray points + (also optionally pass in a global encoding from global_encoder). + │ + ▼ + (3) Rendering + ------------------ + Render the image into the target cameras by raymarching along + the sampled rays and aggregating the colors and densities + output by the implicit function in (2). + │ + ▼ + (4) Loss Computation + ------------------ + Compute losses based on the predicted target image(s). + + + The `forward` function of OverfitModel executes + this sequence of steps. Currently, steps 1, 2, 3 + can be customized by initializing a subclass of the appropriate + base class and adding the newly created module to the registry. + Please see https://github.com/facebookresearch/pytorch3d/blob/main/projects/implicitron_trainer/README.md#custom-plugins + for more details on how to create and register a custom component. + + In the config .yaml files for experiments, the parameters below are + contained in the + `model_factory_ImplicitronModelFactory_args.model_OverfitModel_args` + node. As OverfitModel derives from ReplaceableBase, the input arguments are + parsed by the run_auto_creation function to initialize the + necessary member modules. Please see implicitron_trainer/README.md + for more details on this process.
+ + Args: + mask_images: Whether or not to mask the RGB image background given the + foreground mask (the `fg_probability` argument of `GenericModel.forward`) + mask_depths: Whether or not to mask the depth image background given the + foreground mask (the `fg_probability` argument of `GenericModel.forward`) + render_image_width: Width of the output image to render + render_image_height: Height of the output image to render + mask_threshold: If greater than 0.0, the foreground mask is + thresholded by this value before being applied to the RGB/Depth images + output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by + splatting onto an image grid. Default: False. + bg_color: RGB values for setting the background color of input image + if mask_images=True. Defaults to (0.0, 0.0, 0.0). Each renderer has its own + way to determine the background color of its output, unrelated to this. + chunk_size_grid: The total number of points which can be rendered + per chunk. This is used to compute the number of rays used + per chunk when the chunked version of the renderer is used (in order + to fit rendering on all rays in memory) + render_features_dimensions: The number of output features to render. + Defaults to 3, corresponding to RGB images. + sampling_mode_training: The sampling method to use during training. Must be + a value from the RenderSamplingMode Enum. + sampling_mode_evaluation: Same as above but for evaluation. + global_encoder_class_type: The name of the class to use for global_encoder, + which must be available in the registry. Or `None` to disable global encoder. + global_encoder: An instance of `GlobalEncoder`. This is used to generate an encoding + of the image (referred to as the global_code) that can be used to model aspects of + the scene such as multiple objects or morphing objects. It is up to the implicit + function definition how to use it, but the most typical way is to broadcast and + concatenate to the other inputs for the implicit function. + raysampler_class_type: The name of the raysampler class which is available + in the global registry. + raysampler: An instance of RaySampler which is used to emit + rays from the target view(s). + renderer_class_type: The name of the renderer class which is available in the global + registry. + renderer: A renderer class which inherits from BaseRenderer. This is used to + generate the images from the target view(s). + share_implicit_function_across_passes: If set to True, + coarse_implicit_function is automatically set to implicit_function + (coarse_implicit_function=implicit_function). The + implicit functions are then run sequentially during rendering. + implicit_function_class_type: The type of implicit function to use which + is available in the global registry. + implicit_function: An instance of ImplicitFunctionBase. + coarse_implicit_function_class_type: The type of implicit function to use which + is available in the global registry. + coarse_implicit_function: An instance of ImplicitFunctionBase. + If set, and `share_implicit_function_across_passes` is False, + coarse_implicit_function is instantiated from its own arguments. It + is then used as the second pass during rendering. + If set to None, we only do a single pass with implicit_function. + view_metrics: An instance of ViewMetricsBase used to compute loss terms which + are independent of the model's parameters. + view_metrics_class_type: The type of view metrics to use, must be available in + the global registry.
+ regularization_metrics: An instance of RegularizationMetricsBase used to compute + regularization terms which can depend on the model's parameters. + regularization_metrics_class_type: The type of regularization metrics to use, + must be available in the global registry. + loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation + for `ViewMetrics` class for available loss functions. + log_vars: A list of variable names which should be logged. + The names should correspond to a subset of the keys of the + dict `preds` output by the `forward` function. + """ # noqa: B950 + + mask_images: bool = True + mask_depths: bool = True + render_image_width: int = 400 + render_image_height: int = 400 + mask_threshold: float = 0.5 + output_rasterized_mc: bool = False + bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0) + chunk_size_grid: int = 4096 + render_features_dimensions: int = 3 + tqdm_trigger_threshold: int = 16 + + n_train_target_views: int = 1 + sampling_mode_training: str = "mask_sample" + sampling_mode_evaluation: str = "full_grid" + + # ---- global encoder settings + global_encoder_class_type: Optional[str] = None + global_encoder: Optional[GlobalEncoderBase] + + # ---- raysampler + raysampler_class_type: str = "AdaptiveRaySampler" + raysampler: RaySamplerBase + + # ---- renderer configs + renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer" + renderer: BaseRenderer + + # ---- implicit function settings + share_implicit_function_across_passes: bool = False + implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction" + implicit_function: ImplicitFunctionBase + coarse_implicit_function_class_type: Optional[str] = None + coarse_implicit_function: Optional[ImplicitFunctionBase] + + # ----- metrics + view_metrics: ViewMetricsBase + view_metrics_class_type: str = "ViewMetrics" + + regularization_metrics: RegularizationMetricsBase + regularization_metrics_class_type: str = "RegularizationMetrics" + + # ---- loss weights + loss_weights: Dict[str, float] = field( + default_factory=lambda: { + "loss_rgb_mse": 1.0, + "loss_prev_stage_rgb_mse": 1.0, + "loss_mask_bce": 0.0, + "loss_prev_stage_mask_bce": 0.0, + } + ) + + # ---- variables to be logged (logger automatically ignores if not computed) + log_vars: List[str] = field( + default_factory=lambda: [ + "loss_rgb_psnr_fg", + "loss_rgb_psnr", + "loss_rgb_mse", + "loss_rgb_huber", + "loss_depth_abs", + "loss_depth_abs_fg", + "loss_mask_neg_iou", + "loss_mask_bce", + "loss_mask_beta_prior", + "loss_eikonal", + "loss_density_tv", + "loss_depth_neg_penalty", + "loss_autodecoder_norm", + # metrics that are only logged in 2+stage renderers + "loss_prev_stage_rgb_mse", + "loss_prev_stage_rgb_psnr_fg", + "loss_prev_stage_rgb_psnr", + "loss_prev_stage_mask_bce", + # basic metrics + "objective", + "epoch", + "sec/it", + ] + ) + + def __post_init__(self): + # The attribute will be filled by run_auto_creation + run_auto_creation(self) + log_loss_weights(self.loss_weights, logger) + # We need to set it here since run_auto_creation + # will create coarse_implicit_function before implicit_function + if self.share_implicit_function_across_passes: + self.coarse_implicit_function = self.implicit_function + + def forward( + self, + *, # force keyword-only arguments + image_rgb: Optional[torch.Tensor], + camera: CamerasBase, + fg_probability: Optional[torch.Tensor] = None, + mask_crop: Optional[torch.Tensor] = None, + depth_map: Optional[torch.Tensor] = None, + sequence_name: Optional[List[str]] = None, +
frame_timestamp: Optional[torch.Tensor] = None, + evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, + **kwargs, + ) -> Dict[str, Any]: + """ + Args: + image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images; + the first `min(B, n_train_target_views)` images are considered targets and + are used to supervise the renders; the rest correspond to the source + viewpoints from which features will be extracted. + camera: An instance of CamerasBase containing a batch of `B` cameras corresponding + to the viewpoints of target images, from which the rays will be sampled, + and source images, which will be used for intersecting with target rays. + fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of + foreground masks. + mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid + regions in the input images (i.e. regions that do not correspond + to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to + "mask_sample", rays will be sampled in the non-zero regions. + depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps. + sequence_name: A list of `B` strings corresponding to the sequence names + from which images `image_rgb` were extracted. They are used to match + target frames with relevant source frames. + frame_timestamp: Optionally a tensor of shape `(B,)` containing a batch + of frame timestamps. + evaluation_mode: one of EvaluationMode.TRAINING or + EvaluationMode.EVALUATION which determines the settings used for + rendering. + + Returns: + preds: A dictionary containing all outputs of the forward pass including the + rendered images, depths, masks, losses and other metrics. + """ + image_rgb, fg_probability, depth_map = preprocess_input( + image_rgb, + fg_probability, + depth_map, + self.mask_images, + self.mask_depths, + self.mask_threshold, + self.bg_color, + ) + + # Determine the used ray sampling mode. + sampling_mode = RenderSamplingMode( + self.sampling_mode_training + if evaluation_mode == EvaluationMode.TRAINING + else self.sampling_mode_evaluation + ) + + # (1) Sample rendering rays with the ray sampler.
+ # pyre-ignore[29] + ray_bundle: ImplicitronRayBundle = self.raysampler( + camera, + evaluation_mode, + mask=mask_crop + if mask_crop is not None and sampling_mode == RenderSamplingMode.MASK_SAMPLE + else None, + ) + + inputs_to_be_chunked = {} + if fg_probability is not None and self.renderer.requires_object_mask(): + sampled_fb_prob = rend_utils.ndc_grid_sample( + fg_probability, ray_bundle.xys, mode="nearest" + ) + inputs_to_be_chunked["object_mask"] = sampled_fb_prob > 0.5 + + # (2)-(3) Implicit function evaluation and Rendering + implicit_functions: List[Union[Callable, ImplicitFunctionBase]] = [ + self.implicit_function + ] + if self.coarse_implicit_function is not None: + implicit_functions += [self.coarse_implicit_function] + + if self.global_encoder is not None: + global_code = self.global_encoder( # pyre-fixme[29] + sequence_name=sequence_name, + frame_timestamp=frame_timestamp, + ) + implicit_functions = [ + functools.partial(implicit_function, global_code=global_code) + if isinstance(implicit_function, Callable) + else functools.partial( + implicit_function.forward, global_code=global_code + ) + for implicit_function in implicit_functions + ] + rendered = self._render( + ray_bundle=ray_bundle, + sampling_mode=sampling_mode, + evaluation_mode=evaluation_mode, + implicit_functions=implicit_functions, + inputs_to_be_chunked=inputs_to_be_chunked, + ) + + # A dict to store losses as well as rendering results. + preds: Dict[str, Any] = self.view_metrics( + results={}, + raymarched=rendered, + ray_bundle=ray_bundle, + image_rgb=image_rgb, + depth_map=depth_map, + fg_probability=fg_probability, + mask_crop=mask_crop, + ) + + preds.update( + self.regularization_metrics( + results=preds, + model=self, + ) + ) + + if sampling_mode == RenderSamplingMode.MASK_SAMPLE: + if self.output_rasterized_mc: + # Visualize the monte-carlo pixel renders by splatting onto + # an image grid. + ( + preds["images_render"], + preds["depths_render"], + preds["masks_render"], + ) = rasterize_sparse_ray_bundle( + ray_bundle, + rendered.features, + (self.render_image_height, self.render_image_width), + rendered.depths, + masks=rendered.masks, + ) + elif sampling_mode == RenderSamplingMode.FULL_GRID: + preds["images_render"] = rendered.features.permute(0, 3, 1, 2) + preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2) + preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2) + + preds["implicitron_render"] = ImplicitronRender( + image_render=preds["images_render"], + depth_render=preds["depths_render"], + mask_render=preds["masks_render"], + ) + else: + raise AssertionError("Unreachable state") + + # (4) Compute losses + # finally get the optimization objective using self.loss_weights + objective = self._get_objective(preds) + if objective is not None: + preds["objective"] = objective + + return preds + + def _get_objective(self, preds: Dict[str, torch.Tensor]) -> Optional[torch.Tensor]: + """ + A helper function to compute the overall loss as the dot product + of individual loss functions with the corresponding weights. + """ + return weighted_sum_losses(preds, self.loss_weights) + + def visualize( + self, + viz: Optional["Visdom"], + visdom_env_imgs: str, + preds: Dict[str, Any], + prefix: str, + ) -> None: + """ + Helper function to visualize the predictions generated + in the forward pass. + + Args: + viz: Visdom connection object + visdom_env_imgs: name of visdom environment for the images. 
+ preds: predictions dict like returned by forward() + prefix: prepended to the names of images + """ + if viz is None or not viz.check_connection(): + logger.info("no visdom server! -> skipping batch vis") + return + + idx_image = 0 + title = f"{prefix}_im{idx_image}" + + vis_utils.visualize_basics(viz, preds, visdom_env_imgs, title=title) + + def _render( + self, + *, + ray_bundle: ImplicitronRayBundle, + inputs_to_be_chunked: Dict[str, torch.Tensor], + sampling_mode: RenderSamplingMode, + **kwargs, + ) -> RendererOutput: + """ + Args: + ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the + sampled rendering rays. + inputs_to_be_chunked: A collection of tensor of shape `(B, _, H, W)`. E.g. + SignedDistanceFunctionRenderer requires "object_mask", shape + (B, 1, H, W), the silhouette of the object in the image. When + chunking, they are passed to the renderer as shape + `(B, _, chunksize)`. + sampling_mode: The sampling method to use. Must be a value from the + RenderSamplingMode Enum. + + Returns: + An instance of RendererOutput + """ + if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0: + return apply_chunked( + self.renderer, + chunk_generator( + self.chunk_size_grid, + ray_bundle, + inputs_to_be_chunked, + self.tqdm_trigger_threshold, + **kwargs, + ), + lambda batch: torch.cat(batch, dim=1).reshape( + *ray_bundle.lengths.shape[:-1], -1 + ), + ) + else: + # pyre-fixme[29]: `BaseRenderer` is not a function. + return self.renderer( + ray_bundle=ray_bundle, + **inputs_to_be_chunked, + **kwargs, + ) + + @classmethod + def raysampler_tweak_args(cls, type, args: DictConfig) -> None: + """ + We don't expose certain fields of the raysampler because we want to set + them from our own members. + """ + del args["sampling_mode_training"] + del args["sampling_mode_evaluation"] + del args["image_width"] + del args["image_height"] + + def create_raysampler(self): + extra_args = { + "sampling_mode_training": self.sampling_mode_training, + "sampling_mode_evaluation": self.sampling_mode_evaluation, + "image_width": self.render_image_width, + "image_height": self.render_image_height, + } + raysampler_args = getattr( + self, "raysampler_" + self.raysampler_class_type + "_args" + ) + self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)( + **raysampler_args, **extra_args + ) + + @classmethod + def renderer_tweak_args(cls, type, args: DictConfig) -> None: + """ + We don't expose certain fields of the renderer because we want to set + them based on other inputs. + """ + args.pop("render_features_dimensions", None) + args.pop("object_bounding_sphere", None) + + def create_renderer(self): + extra_args = {} + + if self.renderer_class_type == "SignedDistanceFunctionRenderer": + extra_args["render_features_dimensions"] = self.render_features_dimensions + if not hasattr(self.raysampler, "scene_extent"): + raise ValueError( + "SignedDistanceFunctionRenderer requires" + + " a raysampler that defines the 'scene_extent' field" + + " (this field is supported by, e.g., the adaptive raysampler - " + + " self.raysampler_class_type='AdaptiveRaySampler')." 
+                )
+            extra_args["object_bounding_sphere"] = self.raysampler.scene_extent
+
+        renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
+        self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
+            **renderer_args, **extra_args
+        )
+
+    @classmethod
+    def implicit_function_tweak_args(cls, type, args: DictConfig) -> None:
+        """
+        We don't expose certain implicit_function fields because we want to set
+        them based on other inputs.
+        """
+        for arg in IMPLICIT_FUNCTION_ARGS_TO_REMOVE:
+            args.pop(arg, None)
+
+    @classmethod
+    def coarse_implicit_function_tweak_args(cls, type, args: DictConfig) -> None:
+        """
+        We don't expose certain implicit_function fields because we want to set
+        them based on other inputs.
+        """
+        for arg in IMPLICIT_FUNCTION_ARGS_TO_REMOVE:
+            args.pop(arg, None)
+
+    def _create_extra_args_for_implicit_function(self) -> Dict[str, Any]:
+        extra_args = {}
+        global_encoder_dim = (
+            0 if self.global_encoder is None else self.global_encoder.get_encoding_dim()
+        )
+        if self.implicit_function_class_type in (
+            "NeuralRadianceFieldImplicitFunction",
+            "NeRFormerImplicitFunction",
+        ):
+            extra_args["latent_dim"] = global_encoder_dim
+            extra_args["color_dim"] = self.render_features_dimensions
+
+        if self.implicit_function_class_type == "IdrFeatureField":
+            extra_args["feature_work_size"] = global_encoder_dim
+            extra_args["feature_vector_size"] = self.render_features_dimensions
+
+        if self.implicit_function_class_type == "SRNImplicitFunction":
+            extra_args["latent_dim"] = global_encoder_dim
+        return extra_args
+
+    def create_implicit_function(self) -> None:
+        implicit_function_type = registry.get(
+            ImplicitFunctionBase, self.implicit_function_class_type
+        )
+        expand_args_fields(implicit_function_type)
+
+        config_name = f"implicit_function_{self.implicit_function_class_type}_args"
+        config = getattr(self, config_name, None)
+        if config is None:
+            raise ValueError(f"{config_name} not present")
+
+        extra_args = self._create_extra_args_for_implicit_function()
+        self.implicit_function = implicit_function_type(**config, **extra_args)
+
+    def create_coarse_implicit_function(self) -> None:
+        # If coarse_implicit_function_class_type has been defined,
+        # we instantiate a module from its arguments.
+        if (
+            self.coarse_implicit_function_class_type is not None
+            and not self.share_implicit_function_across_passes
+        ):
+            config_name = "coarse_implicit_function_{0}_args".format(
+                self.coarse_implicit_function_class_type
+            )
+            config = getattr(self, config_name, {})
+
+            implicit_function_type = registry.get(
+                ImplicitFunctionBase,
+                # pyre-ignore: checked above that the class type is not None.
+                self.coarse_implicit_function_class_type,
+            )
+            expand_args_fields(implicit_function_type)
+
+            extra_args = self._create_extra_args_for_implicit_function()
+            self.coarse_implicit_function = implicit_function_type(
+                **config, **extra_args
+            )
+        elif self.share_implicit_function_across_passes:
+            # Since coarse_implicit_function is initialised before
+            # implicit_function we handle this case in the post_init.
+            pass
+        else:
+            self.coarse_implicit_function = None
diff --git a/pytorch3d/implicitron/models/utils.py b/pytorch3d/implicitron/models/utils.py
new file mode 100644
index 00000000..94480cd6
--- /dev/null
+++ b/pytorch3d/implicitron/models/utils.py
@@ -0,0 +1,195 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +# Note: The #noqa comments below are for unused imports of pluggable implementations +# which are part of implicitron. They ensure that the registry is prepopulated. + +import warnings +from logging import Logger +from typing import Any, Dict, Optional, Tuple + +import torch +import tqdm +from pytorch3d.common.compat import prod + +from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle + +from pytorch3d.implicitron.tools import image_utils + +from pytorch3d.implicitron.tools.utils import cat_dataclass + + +def preprocess_input( + image_rgb: Optional[torch.Tensor], + fg_probability: Optional[torch.Tensor], + depth_map: Optional[torch.Tensor], + mask_images: bool, + mask_depths: bool, + mask_threshold: float, + bg_color: Tuple[float, float, float], +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Helper function to preprocess the input images and optional depth maps + to apply masking if required. + + Args: + image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images + corresponding to the source viewpoints from which features will be extracted + fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch + of foreground masks with values in [0, 1]. + depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps. + mask_images: Whether or not to mask the RGB image background given the + foreground mask (the `fg_probability` argument of `GenericModel.forward`) + mask_depths: Whether or not to mask the depth image background given the + foreground mask (the `fg_probability` argument of `GenericModel.forward`) + mask_threshold: If greater than 0.0, the foreground mask is + thresholded by this value before being applied to the RGB/Depth images + bg_color: RGB values for setting the background color of input image + if mask_images=True. Defaults to (0.0, 0.0, 0.0). Each renderer has its own + way to determine the background color of its output, unrelated to this. + + Returns: + Modified image_rgb, fg_mask, depth_map + """ + if image_rgb is not None and image_rgb.ndim == 3: + # The FrameData object is used for both frames and batches of frames, + # and a user might get this error if those were confused. + # Perhaps a user has a FrameData `fd` representing a single frame and + # wrote something like `model(**fd)` instead of + # `model(**fd.collate([fd]))`. + raise ValueError( + "Model received unbatched inputs. " + + "Perhaps they came from a FrameData which had not been collated." + ) + + fg_mask = fg_probability + if fg_mask is not None and mask_threshold > 0.0: + # threshold masks + warnings.warn("Thresholding masks!") + fg_mask = (fg_mask >= mask_threshold).type_as(fg_mask) + + if mask_images and fg_mask is not None and image_rgb is not None: + # mask the image + warnings.warn("Masking images!") + image_rgb = image_utils.mask_background( + image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(bg_color) + ) + + if mask_depths and fg_mask is not None and depth_map is not None: + # mask the depths + assert ( + mask_threshold > 0.0 + ), "Depths should be masked only with thresholded masks" + warnings.warn("Masking depths!") + depth_map = depth_map * fg_mask + + return image_rgb, fg_mask, depth_map + + +def log_loss_weights(loss_weights: Dict[str, float], logger: Logger) -> None: + """ + Print a table of the loss weights. 
+ """ + loss_weights_message = ( + "-------\nloss_weights:\n" + + "\n".join(f"{k:40s}: {w:1.2e}" for k, w in loss_weights.items()) + + "-------" + ) + logger.info(loss_weights_message) + + +def weighted_sum_losses( + preds: Dict[str, torch.Tensor], loss_weights: Dict[str, float] +) -> Optional[torch.Tensor]: + """ + A helper function to compute the overall loss as the dot product + of individual loss functions with the corresponding weights. + """ + losses_weighted = [ + preds[k] * float(w) + for k, w in loss_weights.items() + if (k in preds and w != 0.0) + ] + if len(losses_weighted) == 0: + warnings.warn("No main objective found.") + return None + loss = sum(losses_weighted) + assert torch.is_tensor(loss) + # pyre-fixme[7]: Expected `Optional[Tensor]` but got `int`. + return loss + + +def apply_chunked(func, chunk_generator, tensor_collator): + """ + Helper function to apply a function on a sequence of + chunked inputs yielded by a generator and collate + the result. + """ + processed_chunks = [ + func(*chunk_args, **chunk_kwargs) + for chunk_args, chunk_kwargs in chunk_generator + ] + + return cat_dataclass(processed_chunks, tensor_collator) + + +def chunk_generator( + chunk_size: int, + ray_bundle: ImplicitronRayBundle, + chunked_inputs: Dict[str, torch.Tensor], + tqdm_trigger_threshold: int, + *args, + **kwargs, +): + """ + Helper function which yields chunks of rays from the + input ray_bundle, to be used when the number of rays is + large and will not fit in memory for rendering. + """ + ( + batch_size, + *spatial_dim, + n_pts_per_ray, + ) = ray_bundle.lengths.shape # B x ... x n_pts_per_ray + if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0: + raise ValueError( + f"chunk_size_grid ({chunk_size}) should be divisible " + f"by n_pts_per_ray ({n_pts_per_ray})" + ) + + n_rays = prod(spatial_dim) + # special handling for raytracing-based methods + n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size) + chunk_size_in_rays = -(-n_rays // n_chunks) + + iter = range(0, n_rays, chunk_size_in_rays) + if len(iter) >= tqdm_trigger_threshold: + iter = tqdm.tqdm(iter) + + def _safe_slice( + tensor: Optional[torch.Tensor], start_idx: int, end_idx: int + ) -> Any: + return tensor[start_idx:end_idx] if tensor is not None else None + + for start_idx in iter: + end_idx = min(start_idx + chunk_size_in_rays, n_rays) + ray_bundle_chunk = ImplicitronRayBundle( + origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx], + directions=ray_bundle.directions.reshape(batch_size, -1, 3)[ + :, start_idx:end_idx + ], + lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[ + :, start_idx:end_idx + ], + xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx], + camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx), + camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx), + ) + extra_args = kwargs.copy() + for k, v in chunked_inputs.items(): + extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx] + yield [ray_bundle_chunk, *args], extra_args diff --git a/pytorch3d/implicitron/tools/rasterize_mc.py b/pytorch3d/implicitron/tools/rasterize_mc.py index 645a5bac..3fbf4b8d 100644 --- a/pytorch3d/implicitron/tools/rasterize_mc.py +++ b/pytorch3d/implicitron/tools/rasterize_mc.py @@ -7,8 +7,9 @@ import math from typing import Optional, Tuple +import pytorch3d + import torch -from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle from pytorch3d.ops import packed_to_padded from pytorch3d.renderer import 
PerspectiveCameras
 from pytorch3d.structures import Pointclouds
@@ -18,7 +19,7 @@ from .point_cloud_utils import render_point_cloud_pytorch3d
 
 @torch.no_grad()
 def rasterize_sparse_ray_bundle(
-    ray_bundle: ImplicitronRayBundle,
+    ray_bundle: "pytorch3d.implicitron.models.renderer.base.ImplicitronRayBundle",
     features: torch.Tensor,
     image_size_hw: Tuple[int, int],
     depth: torch.Tensor,
diff --git a/tests/implicitron/models/__init__.py b/tests/implicitron/models/__init__.py
new file mode 100644
index 00000000..2e41cd71
--- /dev/null
+++ b/tests/implicitron/models/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/tests/implicitron/models/test_overfit_model.py b/tests/implicitron/models/test_overfit_model.py
new file mode 100644
index 00000000..15a6a6c6
--- /dev/null
+++ b/tests/implicitron/models/test_overfit_model.py
@@ -0,0 +1,217 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+from typing import Any, Dict
+from unittest.mock import patch
+
+import torch
+from pytorch3d.implicitron.models.generic_model import GenericModel
+from pytorch3d.implicitron.models.overfit_model import OverfitModel
+from pytorch3d.implicitron.models.renderer.base import EvaluationMode
+from pytorch3d.implicitron.tools.config import expand_args_fields
+from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
+
+DEVICE = torch.device("cuda:0")
+
+
+def _generate_fake_inputs(N: int, H: int, W: int) -> Dict[str, Any]:
+    R, T = look_at_view_transform(azim=torch.rand(N) * 360)
+    return {
+        "camera": PerspectiveCameras(R=R, T=T, device=DEVICE),
+        "fg_probability": torch.randint(
+            high=2, size=(N, 1, H, W), device=DEVICE
+        ).float(),
+        "depth_map": torch.rand((N, 1, H, W), device=DEVICE) + 0.1,
+        "mask_crop": torch.randint(high=2, size=(N, 1, H, W), device=DEVICE).float(),
+        "sequence_name": ["sequence"] * N,
+        "image_rgb": torch.rand((N, 1, H, W), device=DEVICE),
+    }
+
+
+def mock_safe_multinomial(input: torch.Tensor, num_samples: int) -> torch.Tensor:
+    """Return deterministic indices to mock safe_multinomial
+
+    Args:
+        input: tensor of shape [B, n] containing non-negative values;
+            rows are interpreted as unnormalized event probabilities
+            in categorical distributions.
+        num_samples: number of samples to take.
+
+    Returns:
+        Tensor of shape [B, num_samples]
+    """
+    batch_size = input.shape[0]
+    return torch.arange(num_samples).repeat(batch_size, 1).to(DEVICE)
+
+
+class TestOverfitModel(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(42)
+
+    def test_overfit_model_vs_generic_model_with_batch_size_one(self):
+        """In this test we compare OverfitModel to GenericModel behavior.
+
+        We use a NeRF setup (2 rendering passes).
+
+        OverfitModel is a specific case of GenericModel. Hence, with the same inputs,
+        they should provide the exact same results.
+        """
+        expand_args_fields(OverfitModel)
+        expand_args_fields(GenericModel)
+        batch_size, image_height, image_width = 1, 80, 80
+        assert batch_size == 1
+        overfit_model = OverfitModel(
+            render_image_height=image_height,
+            render_image_width=image_width,
+            coarse_implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
+            # Disable stratified point sampling so that sampling is deterministic
+            # and the outputs of the two models can be compared
+            raysampler_AdaptiveRaySampler_args={
+                "stratified_point_sampling_training": False
+            },
+            global_encoder_class_type="SequenceAutodecoder",
+            global_encoder_SequenceAutodecoder_args={
+                "autodecoder_args": {
+                    "n_instances": 1000,
+                    "init_scale": 1.0,
+                    "encoding_dim": 64,
+                }
+            },
+        )
+        generic_model = GenericModel(
+            render_image_height=image_height,
+            render_image_width=image_width,
+            n_train_target_views=batch_size,
+            num_passes=2,
+            # Disable stratified point sampling so that sampling is deterministic
+            # and the outputs of the two models can be compared
+            raysampler_AdaptiveRaySampler_args={
+                "stratified_point_sampling_training": False
+            },
+            global_encoder_class_type="SequenceAutodecoder",
+            global_encoder_SequenceAutodecoder_args={
+                "autodecoder_args": {
+                    "n_instances": 1000,
+                    "init_scale": 1.0,
+                    "encoding_dim": 64,
+                }
+            },
+        )
+
+        # Check that both models have the same number of parameters
+        num_params_mvm = sum(p.numel() for p in overfit_model.parameters())
+        num_params_gm = sum(p.numel() for p in generic_model.parameters())
+        self.assertEqual(num_params_mvm, num_params_gm)
+
+        # Map parameter names from GenericModel to OverfitModel
+        mapping_om_from_gm = {
+            key.replace("_implicit_functions.0._fn", "implicit_function").replace(
+                "_implicit_functions.1._fn", "coarse_implicit_function"
+            ): val
+            for key, val in generic_model.state_dict().items()
+        }
+        # Copy parameters from generic_model to overfit_model
+        overfit_model.load_state_dict(mapping_om_from_gm)
+
+        overfit_model.to(DEVICE)
+        generic_model.to(DEVICE)
+        inputs_ = _generate_fake_inputs(batch_size, image_height, image_width)
+
+        # training forward pass
+        overfit_model.train()
+        generic_model.train()
+
+        with patch(
+            "pytorch3d.renderer.implicit.raysampling._safe_multinomial",
+            side_effect=mock_safe_multinomial,
+        ):
+            train_preds_om = overfit_model(
+                **inputs_,
+                evaluation_mode=EvaluationMode.TRAINING,
+            )
+            train_preds_gm = generic_model(
+                **inputs_,
+                evaluation_mode=EvaluationMode.TRAINING,
+            )
+
+        self.assertTrue(len(train_preds_om) == len(train_preds_gm))
+
+        self.assertTrue(train_preds_om["objective"].isfinite().item())
+        # With randomization disabled and identical weights,
+        # the objectives of the two models should match
+        self.assertTrue(
+            torch.allclose(train_preds_om["objective"], train_preds_gm["objective"])
+        )
+
+        # Test if the evaluation works
+        overfit_model.eval()
+        generic_model.eval()
+        with torch.no_grad():
+            eval_preds_om = overfit_model(
+                **inputs_,
+                evaluation_mode=EvaluationMode.EVALUATION,
+            )
+            eval_preds_gm = generic_model(
+                **inputs_,
+                evaluation_mode=EvaluationMode.EVALUATION,
+            )
+
+        self.assertEqual(
+            eval_preds_om["images_render"].shape,
+            (batch_size, 3, image_height, image_width),
+        )
+        self.assertTrue(
+            torch.allclose(eval_preds_om["objective"], eval_preds_gm["objective"])
+        )
+        self.assertTrue(
+            torch.allclose(
+                eval_preds_om["images_render"], eval_preds_gm["images_render"]
+            )
+        )
+
+    def test_overfit_model_check_share_weights(self):
+        model = OverfitModel(share_implicit_function_across_passes=True)
+        for p1, p2 in zip(
+            model.implicit_function.parameters(),
+            model.coarse_implicit_function.parameters(),
+        ):
+            self.assertEqual(id(p1), id(p2))
+
+        model.to(DEVICE)
+        inputs_ = _generate_fake_inputs(2, 80, 80)
+        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)
+
+    def test_overfit_model_check_no_share_weights(self):
+        model = OverfitModel(
+            share_implicit_function_across_passes=False,
+            coarse_implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
+            coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args={
+                "transformer_dim_down_factor": 1.0,
+                "n_hidden_neurons_xyz": 256,
+                "n_layers_xyz": 8,
+                "append_xyz": (5,),
+            },
+        )
+        for p1, p2 in zip(
+            model.implicit_function.parameters(),
+            model.coarse_implicit_function.parameters(),
+        ):
+            self.assertNotEqual(id(p1), id(p2))
+
+        model.to(DEVICE)
+        inputs_ = _generate_fake_inputs(2, 80, 80)
+        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)
+
+    def test_overfit_model_coarse_implicit_function_is_none(self):
+        model = OverfitModel(
+            share_implicit_function_across_passes=False,
+            coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args=None,
+        )
+        self.assertIsNone(model.coarse_implicit_function)
+        model.to(DEVICE)
+        inputs_ = _generate_fake_inputs(2, 80, 80)
+        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)
diff --git a/tests/implicitron/models/test_utils.py b/tests/implicitron/models/test_utils.py
new file mode 100644
index 00000000..b86f7fd4
--- /dev/null
+++ b/tests/implicitron/models/test_utils.py
@@ -0,0 +1,66 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import unittest
+
+import torch
+
+from pytorch3d.implicitron.models.utils import preprocess_input, weighted_sum_losses
+
+
+class TestUtils(unittest.TestCase):
+    def test_prepare_inputs_wrong_num_dim(self):
+        img = torch.randn(3, 3, 3)
+        with self.assertRaises(ValueError) as context:
+            img, fg_prob, depth_map = preprocess_input(
+                img, None, None, True, True, 0.5, (0.0, 0.0, 0.0)
+            )
+        self.assertEqual(
+            "Model received unbatched inputs. "
+            + "Perhaps they came from a FrameData which had not been collated.",
+            str(context.exception),
+        )
+
+    def test_prepare_inputs_mask_image_true(self):
+        batch, channels, height, width = 2, 3, 10, 10
+        img = torch.ones(batch, channels, height, width)
+        # Create a mask on the lower triangular matrix
+        fg_prob = torch.tril(torch.ones(batch, 1, height, width)) * 0.3
+
+        out_img, out_fg_prob, out_depth_map = preprocess_input(
+            img, fg_prob, None, True, False, 0.3, (0.0, 0.0, 0.0)
+        )
+
+        self.assertTrue(torch.equal(out_img, torch.tril(img)))
+        self.assertTrue(torch.equal(out_fg_prob, (fg_prob >= 0.3).float()))
+        self.assertIsNone(out_depth_map)
+
+    def test_prepare_inputs_mask_depth_true(self):
+        batch, channels, height, width = 2, 3, 10, 10
+        img = torch.ones(batch, channels, height, width)
+        depth_map = torch.randn(batch, channels, height, width)
+        # Create a mask on the lower triangular matrix
+        fg_prob = torch.tril(torch.ones(batch, 1, height, width)) * 0.3
+
+        out_img, out_fg_prob, out_depth_map = preprocess_input(
+            img, fg_prob, depth_map, False, True, 0.3, (0.0, 0.0, 0.0)
+        )
+
+        self.assertTrue(torch.equal(out_img, img))
+        self.assertTrue(torch.equal(out_fg_prob, (fg_prob >= 0.3).float()))
+        self.assertTrue(torch.equal(out_depth_map, torch.tril(depth_map)))
+
+    def test_weighted_sum_losses(self):
+        preds = {"a": torch.tensor(2), "b": torch.tensor(2)}
+        weights = {"a": 2.0, "b": 0.0}
+        loss = weighted_sum_losses(preds, weights)
+        self.assertEqual(loss, 4.0)
+
+    def test_weighted_sum_losses_raise_warning(self):
+        preds = {"a": torch.tensor(2), "b": torch.tensor(2)}
+        weights = {"c": 2.0, "d": 2.0}
+        self.assertIsNone(weighted_sum_losses(preds, weights))