From 1b0584f7bd2bbf0d6a2e5563a8c530d62f2338ba Mon Sep 17 00:00:00 2001 From: Krzysztof Chalupka Date: Fri, 29 Jul 2022 17:32:51 -0700 Subject: [PATCH] Replace pluggable components to create a proper Configurable hierarchy. Summary: This large diff rewrites a significant portion of Implicitron's config hierarchy. The new hierarchy, and some of the default implementation classes, are as follows: ``` Experiment data_source: ImplicitronDataSource dataset_map_provider data_loader_map_provider model_factory: ImplicitronModelFactory model: GenericModel optimizer_factory: ImplicitronOptimizerFactory training_loop: ImplicitronTrainingLoop evaluator: ImplicitronEvaluator ``` 1) Experiment (used to be ExperimentConfig) is now a top-level Configurable and contains as members mainly (mostly new) high-level factory Configurables. 2) Experiment's job is to run factories, do some accelerate setup and then pass the results to the main training loop. 3) ImplicitronOptimizerFactory and ImplicitronModelFactory are new high-level factories that create the optimizer, scheduler, model, and stats objects. 4) TrainingLoop is a new configurable that runs the main training loop and the inner train-validate step. 5) Evaluator is a new configurable that TrainingLoop uses to run validation/test steps. 6) GenericModel is not the only model choice anymore. Instead, ImplicitronModelBase (by default instantiated with GenericModel) is a member of Experiment's model factory and can be easily replaced by a custom implementation by the user. All the new Configurables are children of ReplaceableBase, and can be easily replaced with custom implementations. In addition, I added support for the exponential LR schedule, updated the config files and the test, and added a config file that reproduces NeRF results together with a test that runs the repro experiment.
Reviewed By: bottler Differential Revision: D37723227 fbshipit-source-id: b36bee880d6aa53efdd2abfaae4489d8ab1e8a27 --- .../configs/repro_base.yaml | 116 +-- .../configs/repro_feat_extractor_normed.yaml | 35 +- .../repro_feat_extractor_transformer.yaml | 35 +- .../repro_feat_extractor_unnormed.yaml | 37 +- .../configs/repro_multiseq_base.yaml | 16 +- .../configs/repro_multiseq_idr_ad.yaml | 111 +-- .../configs/repro_multiseq_nerf_ad.yaml | 17 +- .../configs/repro_multiseq_nerf_wce.yaml | 14 +- .../configs/repro_multiseq_nerformer.yaml | 27 +- .../repro_multiseq_nerformer_angle_w.yaml | 7 +- .../repro_multiseq_srn_ad_hypernet.yaml | 61 +- ...repro_multiseq_srn_ad_hypernet_noharm.yaml | 15 +- .../configs/repro_multiseq_srn_wce.yaml | 51 +- .../repro_multiseq_srn_wce_noharm.yaml | 15 +- .../configs/repro_singleseq_base.yaml | 48 +- .../configs/repro_singleseq_idr.yaml | 95 +-- .../configs/repro_singleseq_nerf_blender.yaml | 38 + .../configs/repro_singleseq_nerf_wce.yaml | 11 +- .../configs/repro_singleseq_nerformer.yaml | 27 +- .../configs/repro_singleseq_srn.yaml | 49 +- .../configs/repro_singleseq_srn_noharm.yaml | 15 +- .../configs/repro_singleseq_srn_wce.yaml | 49 +- .../repro_singleseq_srn_wce_noharm.yaml | 15 +- .../configs/repro_singleseq_wce_base.yaml | 2 +- projects/implicitron_trainer/experiment.py | 712 ++++-------------- .../impl/experiment_config.py | 49 -- .../implicitron_trainer/impl/model_factory.py | 199 +++++ .../implicitron_trainer/impl/optimization.py | 109 --- .../impl/optimizer_factory.py | 197 +++++ .../implicitron_trainer/impl/training_loop.py | 365 +++++++++ .../implicitron_trainer/tests/experiment.yaml | 650 ++++++++-------- .../tests/test_experiment.py | 99 ++- projects/implicitron_trainer/tests/utils.py | 31 + .../visualize_reconstruction.py | 10 +- pytorch3d/implicitron/dataset/data_source.py | 3 + pytorch3d/implicitron/evaluation/evaluator.py | 161 ++++ pytorch3d/implicitron/models/base_model.py | 6 +- 
pytorch3d/implicitron/models/generic_model.py | 6 +- pytorch3d/implicitron/models/model_dbir.py | 2 +- .../models/renderer/ray_sampler.py | 2 +- pytorch3d/implicitron/tools/stats.py | 2 + pytorch3d/implicitron/tools/vis_utils.py | 14 +- 42 files changed, 2045 insertions(+), 1478 deletions(-) create mode 100644 projects/implicitron_trainer/configs/repro_singleseq_nerf_blender.yaml delete mode 100644 projects/implicitron_trainer/impl/experiment_config.py create mode 100644 projects/implicitron_trainer/impl/model_factory.py delete mode 100644 projects/implicitron_trainer/impl/optimization.py create mode 100644 projects/implicitron_trainer/impl/optimizer_factory.py create mode 100644 projects/implicitron_trainer/impl/training_loop.py create mode 100644 projects/implicitron_trainer/tests/utils.py create mode 100644 pytorch3d/implicitron/evaluation/evaluator.py diff --git a/projects/implicitron_trainer/configs/repro_base.yaml b/projects/implicitron_trainer/configs/repro_base.yaml index 8c2c4066..11ebdc9e 100644 --- a/projects/implicitron_trainer/configs/repro_base.yaml +++ b/projects/implicitron_trainer/configs/repro_base.yaml @@ -2,10 +2,10 @@ defaults: - default_config - _self_ exp_dir: ./data/exps/base/ -architecture: generic -visualize_interval: 0 -visdom_port: 8097 -data_source_args: +training_loop_ImplicitronTrainingLoop_args: + visualize_interval: 0 + max_epochs: 1000 +data_source_ImplicitronDataSource_args: data_loader_map_provider_class_type: SequenceDataLoaderMapProvider dataset_map_provider_class_type: JsonIndexDatasetMapProvider data_loader_map_provider_SequenceDataLoaderMapProvider_args: @@ -21,55 +21,61 @@ data_source_args: load_point_clouds: false mask_depths: false mask_images: false -generic_model_args: - loss_weights: - loss_mask_bce: 1.0 - loss_prev_stage_mask_bce: 1.0 - loss_autodecoder_norm: 0.01 - loss_rgb_mse: 1.0 - loss_prev_stage_rgb_mse: 1.0 - output_rasterized_mc: false - chunk_size_grid: 102400 - render_image_height: 400 - render_image_width: 
400 - num_passes: 2 - implicit_function_NeuralRadianceFieldImplicitFunction_args: - n_harmonic_functions_xyz: 10 - n_harmonic_functions_dir: 4 - n_hidden_neurons_xyz: 256 - n_hidden_neurons_dir: 128 - n_layers_xyz: 8 - append_xyz: - - 5 - latent_dim: 0 - raysampler_AdaptiveRaySampler_args: - n_rays_per_image_sampled_from_mask: 1024 - scene_extent: 8.0 - n_pts_per_ray_training: 64 - n_pts_per_ray_evaluation: 64 - stratified_point_sampling_training: true - stratified_point_sampling_evaluation: false - renderer_MultiPassEmissionAbsorptionRenderer_args: - n_pts_per_ray_fine_training: 64 - n_pts_per_ray_fine_evaluation: 64 - append_coarse_samples_to_fine: true - density_noise_std_train: 1.0 - view_pooler_args: - view_sampler_args: - masked_sampling: false - image_feature_extractor_ResNetFeatureExtractor_args: - stages: - - 1 - - 2 - - 3 - - 4 - proj_dim: 16 - image_rescale: 0.32 - first_max_pool: false -solver_args: - breed: adam - lr: 0.0005 - lr_policy: multistep - max_epochs: 2000 - momentum: 0.9 +model_factory_ImplicitronModelFactory_args: + visdom_port: 8097 + model_GenericModel_args: + loss_weights: + loss_mask_bce: 1.0 + loss_prev_stage_mask_bce: 1.0 + loss_autodecoder_norm: 0.01 + loss_rgb_mse: 1.0 + loss_prev_stage_rgb_mse: 1.0 + output_rasterized_mc: false + chunk_size_grid: 102400 + render_image_height: 400 + render_image_width: 400 + num_passes: 2 + implicit_function_NeuralRadianceFieldImplicitFunction_args: + n_harmonic_functions_xyz: 10 + n_harmonic_functions_dir: 4 + n_hidden_neurons_xyz: 256 + n_hidden_neurons_dir: 128 + n_layers_xyz: 8 + append_xyz: + - 5 + latent_dim: 0 + raysampler_AdaptiveRaySampler_args: + n_rays_per_image_sampled_from_mask: 1024 + scene_extent: 8.0 + n_pts_per_ray_training: 64 + n_pts_per_ray_evaluation: 64 + stratified_point_sampling_training: true + stratified_point_sampling_evaluation: false + renderer_MultiPassEmissionAbsorptionRenderer_args: + n_pts_per_ray_fine_training: 64 + n_pts_per_ray_fine_evaluation: 64 + 
append_coarse_samples_to_fine: true + density_noise_std_train: 1.0 + view_pooler_args: + view_sampler_args: + masked_sampling: false + image_feature_extractor_ResNetFeatureExtractor_args: + stages: + - 1 + - 2 + - 3 + - 4 + proj_dim: 16 + image_rescale: 0.32 + first_max_pool: false +optimizer_factory_ImplicitronOptimizerFactory_args: + breed: Adam weight_decay: 0.0 + lr_policy: MultiStepLR + multistep_lr_milestones: [] + lr: 0.0005 + gamma: 0.1 + momentum: 0.9 + betas: + - 0.9 + - 0.999 diff --git a/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml b/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml index 1ea74b23..b2154c8b 100644 --- a/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml +++ b/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml @@ -1,17 +1,18 @@ -generic_model_args: - image_feature_extractor_class_type: ResNetFeatureExtractor - image_feature_extractor_ResNetFeatureExtractor_args: - add_images: true - add_masks: true - first_max_pool: true - image_rescale: 0.375 - l2_norm: true - name: resnet34 - normalize_image: true - pretrained: true - stages: - - 1 - - 2 - - 3 - - 4 - proj_dim: 32 +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + image_feature_extractor_class_type: ResNetFeatureExtractor + image_feature_extractor_ResNetFeatureExtractor_args: + add_images: true + add_masks: true + first_max_pool: true + image_rescale: 0.375 + l2_norm: true + name: resnet34 + normalize_image: true + pretrained: true + stages: + - 1 + - 2 + - 3 + - 4 + proj_dim: 32 diff --git a/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml b/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml index 734ab43e..8d24495b 100644 --- a/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml +++ b/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml @@ -1,17 +1,18 @@ -generic_model_args: - 
image_feature_extractor_class_type: ResNetFeatureExtractor - image_feature_extractor_ResNetFeatureExtractor_args: - add_images: true - add_masks: true - first_max_pool: false - image_rescale: 0.375 - l2_norm: true - name: resnet34 - normalize_image: true - pretrained: true - stages: - - 1 - - 2 - - 3 - - 4 - proj_dim: 16 +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + image_feature_extractor_class_type: ResNetFeatureExtractor + image_feature_extractor_ResNetFeatureExtractor_args: + add_images: true + add_masks: true + first_max_pool: false + image_rescale: 0.375 + l2_norm: true + name: resnet34 + normalize_image: true + pretrained: true + stages: + - 1 + - 2 + - 3 + - 4 + proj_dim: 16 diff --git a/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml b/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml index bc7bc37e..2d4eb3f8 100644 --- a/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml +++ b/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml @@ -1,18 +1,19 @@ -generic_model_args: - image_feature_extractor_class_type: ResNetFeatureExtractor - image_feature_extractor_ResNetFeatureExtractor_args: - stages: - - 1 - - 2 - - 3 - first_max_pool: false - proj_dim: -1 - l2_norm: false - image_rescale: 0.375 - name: resnet34 - normalize_image: true - pretrained: true - view_pooler_args: - feature_aggregator_AngleWeightedReductionFeatureAggregator_args: - reduction_functions: - - AVG +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + image_feature_extractor_class_type: ResNetFeatureExtractor + image_feature_extractor_ResNetFeatureExtractor_args: + stages: + - 1 + - 2 + - 3 + first_max_pool: false + proj_dim: -1 + l2_norm: false + image_rescale: 0.375 + name: resnet34 + normalize_image: true + pretrained: true + view_pooler_args: + feature_aggregator_AngleWeightedReductionFeatureAggregator_args: + reduction_functions: + - AVG diff --git 
a/projects/implicitron_trainer/configs/repro_multiseq_base.yaml b/projects/implicitron_trainer/configs/repro_multiseq_base.yaml index 7b6dc093..a63cbedf 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_base.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_base.yaml @@ -1,7 +1,7 @@ defaults: - repro_base.yaml - _self_ -data_source_args: +data_source_ImplicitronDataSource_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 10 dataset_length_train: 1000 @@ -26,10 +26,12 @@ data_source_args: n_frames_per_sequence: -1 test_on_train: true test_restrict_sequence_id: 0 -solver_args: - max_epochs: 3000 - milestones: +optimizer_factory_ImplicitronOptimizerFactory_args: + multistep_lr_milestones: - 1000 -camera_difficulty_bin_breaks: - - 0.666667 - - 0.833334 +training_loop_ImplicitronTrainingLoop_args: + max_epochs: 3000 + evaluator_ImplicitronEvaluator_args: + camera_difficulty_bin_breaks: + - 0.666667 + - 0.833334 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml b/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml index 183abae8..56684f6f 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml @@ -1,65 +1,66 @@ defaults: - repro_multiseq_base.yaml - _self_ -generic_model_args: - loss_weights: - loss_mask_bce: 100.0 - loss_kl: 0.0 - loss_rgb_mse: 1.0 - loss_eikonal: 0.1 - chunk_size_grid: 65536 - num_passes: 1 - output_rasterized_mc: true - sampling_mode_training: mask_sample - global_encoder_class_type: SequenceAutodecoder - global_encoder_SequenceAutodecoder_args: - autodecoder_args: - n_instances: 20000 - init_scale: 1.0 - encoding_dim: 256 - implicit_function_IdrFeatureField_args: - n_harmonic_functions_xyz: 6 - bias: 0.6 - d_in: 3 - d_out: 1 - dims: - - 512 - - 512 - - 512 - - 512 - - 512 - - 512 - - 512 - - 512 - geometric_init: true - pooled_feature_dim: 0 - skip_in: - - 6 - 
weight_norm: true - renderer_SignedDistanceFunctionRenderer_args: - ray_tracer_args: - line_search_step: 0.5 - line_step_iters: 3 - n_secant_steps: 8 - n_steps: 100 - object_bounding_sphere: 8.0 - sdf_threshold: 5.0e-05 - ray_normal_coloring_network_args: - d_in: 9 - d_out: 3 +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + loss_weights: + loss_mask_bce: 100.0 + loss_kl: 0.0 + loss_rgb_mse: 1.0 + loss_eikonal: 0.1 + chunk_size_grid: 65536 + num_passes: 1 + output_rasterized_mc: true + sampling_mode_training: mask_sample + global_encoder_class_type: SequenceAutodecoder + global_encoder_SequenceAutodecoder_args: + autodecoder_args: + n_instances: 20000 + init_scale: 1.0 + encoding_dim: 256 + implicit_function_IdrFeatureField_args: + n_harmonic_functions_xyz: 6 + bias: 0.6 + d_in: 3 + d_out: 1 dims: - 512 - 512 - 512 - 512 - mode: idr - n_harmonic_functions_dir: 4 + - 512 + - 512 + - 512 + - 512 + geometric_init: true pooled_feature_dim: 0 + skip_in: + - 6 weight_norm: true - raysampler_AdaptiveRaySampler_args: - n_rays_per_image_sampled_from_mask: 1024 - n_pts_per_ray_training: 0 - n_pts_per_ray_evaluation: 0 - scene_extent: 8.0 - renderer_class_type: SignedDistanceFunctionRenderer - implicit_function_class_type: IdrFeatureField + renderer_SignedDistanceFunctionRenderer_args: + ray_tracer_args: + line_search_step: 0.5 + line_step_iters: 3 + n_secant_steps: 8 + n_steps: 100 + object_bounding_sphere: 8.0 + sdf_threshold: 5.0e-05 + ray_normal_coloring_network_args: + d_in: 9 + d_out: 3 + dims: + - 512 + - 512 + - 512 + - 512 + mode: idr + n_harmonic_functions_dir: 4 + pooled_feature_dim: 0 + weight_norm: true + raysampler_AdaptiveRaySampler_args: + n_rays_per_image_sampled_from_mask: 1024 + n_pts_per_ray_training: 0 + n_pts_per_ray_evaluation: 0 + scene_extent: 8.0 + renderer_class_type: SignedDistanceFunctionRenderer + implicit_function_class_type: IdrFeatureField diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml 
b/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml index a8b99df8..aa4291d3 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml @@ -1,11 +1,12 @@ defaults: - repro_multiseq_base.yaml - _self_ -generic_model_args: - chunk_size_grid: 16000 - view_pooler_enabled: false - global_encoder_class_type: SequenceAutodecoder - global_encoder_SequenceAutodecoder_args: - autodecoder_args: - n_instances: 20000 - encoding_dim: 256 +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + chunk_size_grid: 16000 + view_pooler_enabled: false + global_encoder_class_type: SequenceAutodecoder + global_encoder_SequenceAutodecoder_args: + autodecoder_args: + n_instances: 20000 + encoding_dim: 256 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml index 00140db6..fa366d46 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml @@ -2,9 +2,11 @@ defaults: - repro_multiseq_base.yaml - repro_feat_extractor_unnormed.yaml - _self_ -clip_grad: 1.0 -generic_model_args: - chunk_size_grid: 16000 - view_pooler_enabled: true - raysampler_AdaptiveRaySampler_args: - n_rays_per_image_sampled_from_mask: 850 +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + chunk_size_grid: 16000 + view_pooler_enabled: true + raysampler_AdaptiveRaySampler_args: + n_rays_per_image_sampled_from_mask: 850 +training_loop_ImplicitronTrainingLoop_args: + clip_grad: 1.0 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml index c4a20f6e..9aa9f4c5 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml 
@@ -2,16 +2,17 @@ defaults: - repro_multiseq_base.yaml - repro_feat_extractor_transformer.yaml - _self_ -generic_model_args: - chunk_size_grid: 16000 - raysampler_AdaptiveRaySampler_args: - n_rays_per_image_sampled_from_mask: 800 - n_pts_per_ray_training: 32 - n_pts_per_ray_evaluation: 32 - renderer_MultiPassEmissionAbsorptionRenderer_args: - n_pts_per_ray_fine_training: 16 - n_pts_per_ray_fine_evaluation: 16 - implicit_function_class_type: NeRFormerImplicitFunction - view_pooler_enabled: true - view_pooler_args: - feature_aggregator_class_type: IdentityFeatureAggregator +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + chunk_size_grid: 16000 + raysampler_AdaptiveRaySampler_args: + n_rays_per_image_sampled_from_mask: 800 + n_pts_per_ray_training: 32 + n_pts_per_ray_evaluation: 32 + renderer_MultiPassEmissionAbsorptionRenderer_args: + n_pts_per_ray_fine_training: 16 + n_pts_per_ray_fine_evaluation: 16 + implicit_function_class_type: NeRFormerImplicitFunction + view_pooler_enabled: true + view_pooler_args: + feature_aggregator_class_type: IdentityFeatureAggregator diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml index 61f6ebb4..9c9a30fe 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml @@ -1,6 +1,7 @@ defaults: - repro_multiseq_nerformer.yaml - _self_ -generic_model_args: - view_pooler_args: - feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + view_pooler_args: + feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator diff --git a/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml index 8d88f736..1b4a2ef2 100644 
--- a/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml @@ -1,34 +1,35 @@ defaults: - repro_multiseq_base.yaml - _self_ -generic_model_args: - chunk_size_grid: 16000 - view_pooler_enabled: false - n_train_target_views: -1 - num_passes: 1 - loss_weights: - loss_rgb_mse: 200.0 - loss_prev_stage_rgb_mse: 0.0 - loss_mask_bce: 1.0 - loss_prev_stage_mask_bce: 0.0 - loss_autodecoder_norm: 0.001 - depth_neg_penalty: 10000.0 - global_encoder_class_type: SequenceAutodecoder - global_encoder_SequenceAutodecoder_args: - autodecoder_args: - encoding_dim: 256 - n_instances: 20000 - raysampler_class_type: NearFarRaySampler - raysampler_NearFarRaySampler_args: - n_rays_per_image_sampled_from_mask: 2048 - min_depth: 0.05 - max_depth: 0.05 - n_pts_per_ray_training: 1 - n_pts_per_ray_evaluation: 1 - stratified_point_sampling_training: false - stratified_point_sampling_evaluation: false - renderer_class_type: LSTMRenderer - implicit_function_class_type: SRNHyperNetImplicitFunction -solver_args: - breed: adam +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + chunk_size_grid: 16000 + view_pooler_enabled: false + n_train_target_views: -1 + num_passes: 1 + loss_weights: + loss_rgb_mse: 200.0 + loss_prev_stage_rgb_mse: 0.0 + loss_mask_bce: 1.0 + loss_prev_stage_mask_bce: 0.0 + loss_autodecoder_norm: 0.001 + depth_neg_penalty: 10000.0 + global_encoder_class_type: SequenceAutodecoder + global_encoder_SequenceAutodecoder_args: + autodecoder_args: + encoding_dim: 256 + n_instances: 20000 + raysampler_class_type: NearFarRaySampler + raysampler_NearFarRaySampler_args: + n_rays_per_image_sampled_from_mask: 2048 + min_depth: 0.05 + max_depth: 0.05 + n_pts_per_ray_training: 1 + n_pts_per_ray_evaluation: 1 + stratified_point_sampling_training: false + stratified_point_sampling_evaluation: false + renderer_class_type: LSTMRenderer + implicit_function_class_type: 
SRNHyperNetImplicitFunction +optimizer_factory_ImplicitronOptimizerFactory_args: + breed: Adam lr: 5.0e-05 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet_noharm.yaml b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet_noharm.yaml index 42355955..9f29cbbe 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet_noharm.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet_noharm.yaml @@ -1,10 +1,11 @@ defaults: - repro_multiseq_srn_ad_hypernet.yaml - _self_ -generic_model_args: - num_passes: 1 - implicit_function_SRNHyperNetImplicitFunction_args: - pixel_generator_args: - n_harmonic_functions: 0 - hypernet_args: - n_harmonic_functions: 0 +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + num_passes: 1 + implicit_function_SRNHyperNetImplicitFunction_args: + pixel_generator_args: + n_harmonic_functions: 0 + hypernet_args: + n_harmonic_functions: 0 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml b/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml index d340c18a..4a72c326 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml @@ -2,29 +2,30 @@ defaults: - repro_multiseq_base.yaml - repro_feat_extractor_normed.yaml - _self_ -generic_model_args: - chunk_size_grid: 32000 - num_passes: 1 - n_train_target_views: -1 - loss_weights: - loss_rgb_mse: 200.0 - loss_prev_stage_rgb_mse: 0.0 - loss_mask_bce: 1.0 - loss_prev_stage_mask_bce: 0.0 - loss_autodecoder_norm: 0.0 - depth_neg_penalty: 10000.0 - raysampler_class_type: NearFarRaySampler - raysampler_NearFarRaySampler_args: - n_rays_per_image_sampled_from_mask: 2048 - min_depth: 0.05 - max_depth: 0.05 - n_pts_per_ray_training: 1 - n_pts_per_ray_evaluation: 1 - stratified_point_sampling_training: false - stratified_point_sampling_evaluation: false - renderer_class_type: 
LSTMRenderer - implicit_function_class_type: SRNImplicitFunction - view_pooler_enabled: true -solver_args: - breed: adam +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + chunk_size_grid: 32000 + num_passes: 1 + n_train_target_views: -1 + loss_weights: + loss_rgb_mse: 200.0 + loss_prev_stage_rgb_mse: 0.0 + loss_mask_bce: 1.0 + loss_prev_stage_mask_bce: 0.0 + loss_autodecoder_norm: 0.0 + depth_neg_penalty: 10000.0 + raysampler_class_type: NearFarRaySampler + raysampler_NearFarRaySampler_args: + n_rays_per_image_sampled_from_mask: 2048 + min_depth: 0.05 + max_depth: 0.05 + n_pts_per_ray_training: 1 + n_pts_per_ray_evaluation: 1 + stratified_point_sampling_training: false + stratified_point_sampling_evaluation: false + renderer_class_type: LSTMRenderer + implicit_function_class_type: SRNImplicitFunction + view_pooler_enabled: true +optimizer_factory_ImplicitronOptimizerFactory_args: + breed: Adam lr: 5.0e-05 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_srn_wce_noharm.yaml b/projects/implicitron_trainer/configs/repro_multiseq_srn_wce_noharm.yaml index e80d1cb9..d2ea11e3 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_srn_wce_noharm.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_srn_wce_noharm.yaml @@ -1,10 +1,11 @@ defaults: - repro_multiseq_srn_wce.yaml - _self_ -generic_model_args: - num_passes: 1 - implicit_function_SRNImplicitFunction_args: - pixel_generator_args: - n_harmonic_functions: 0 - raymarch_function_args: - n_harmonic_functions: 0 +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + num_passes: 1 + implicit_function_SRNImplicitFunction_args: + pixel_generator_args: + n_harmonic_functions: 0 + raymarch_function_args: + n_harmonic_functions: 0 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_base.yaml b/projects/implicitron_trainer/configs/repro_singleseq_base.yaml index 977d5dd9..572fc7d5 100644 --- 
a/projects/implicitron_trainer/configs/repro_singleseq_base.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_base.yaml @@ -1,7 +1,7 @@ defaults: - repro_base - _self_ -data_source_args: +data_source_ImplicitronDataSource_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 1 dataset_length_train: 1000 @@ -12,28 +12,30 @@ data_source_args: n_frames_per_sequence: -1 test_restrict_sequence_id: 0 test_on_train: false -generic_model_args: - render_image_height: 800 - render_image_width: 800 - log_vars: - - loss_rgb_psnr_fg - - loss_rgb_psnr - - loss_eikonal - - loss_prev_stage_rgb_psnr - - loss_mask_bce - - loss_prev_stage_mask_bce - - loss_rgb_mse - - loss_prev_stage_rgb_mse - - loss_depth_abs - - loss_depth_abs_fg - - loss_kl - - loss_mask_neg_iou - - objective - - epoch - - sec/it -solver_args: +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + render_image_height: 800 + render_image_width: 800 + log_vars: + - loss_rgb_psnr_fg + - loss_rgb_psnr + - loss_eikonal + - loss_prev_stage_rgb_psnr + - loss_mask_bce + - loss_prev_stage_mask_bce + - loss_rgb_mse + - loss_prev_stage_rgb_mse + - loss_depth_abs + - loss_depth_abs_fg + - loss_kl + - loss_mask_neg_iou + - objective + - epoch + - sec/it +optimizer_factory_ImplicitronOptimizerFactory_args: lr: 0.0005 - max_epochs: 400 - milestones: + multistep_lr_milestones: - 200 - 300 +training_loop_ImplicitronTrainingLoop_args: + max_epochs: 400 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml b/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml index bb587056..c936d092 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml @@ -1,57 +1,58 @@ defaults: - repro_singleseq_base - _self_ -generic_model_args: - loss_weights: - loss_mask_bce: 100.0 - loss_kl: 0.0 - loss_rgb_mse: 1.0 - loss_eikonal: 0.1 - chunk_size_grid: 65536 - num_passes: 1 - 
view_pooler_enabled: false - implicit_function_IdrFeatureField_args: - n_harmonic_functions_xyz: 6 - bias: 0.6 - d_in: 3 - d_out: 1 - dims: - - 512 - - 512 - - 512 - - 512 - - 512 - - 512 - - 512 - - 512 - geometric_init: true - pooled_feature_dim: 0 - skip_in: - - 6 - weight_norm: true - renderer_SignedDistanceFunctionRenderer_args: - ray_tracer_args: - line_search_step: 0.5 - line_step_iters: 3 - n_secant_steps: 8 - n_steps: 100 - object_bounding_sphere: 8.0 - sdf_threshold: 5.0e-05 - ray_normal_coloring_network_args: - d_in: 9 - d_out: 3 +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + loss_weights: + loss_mask_bce: 100.0 + loss_kl: 0.0 + loss_rgb_mse: 1.0 + loss_eikonal: 0.1 + chunk_size_grid: 65536 + num_passes: 1 + view_pooler_enabled: false + implicit_function_IdrFeatureField_args: + n_harmonic_functions_xyz: 6 + bias: 0.6 + d_in: 3 + d_out: 1 dims: - 512 - 512 - 512 - 512 - mode: idr - n_harmonic_functions_dir: 4 + - 512 + - 512 + - 512 + - 512 + geometric_init: true pooled_feature_dim: 0 + skip_in: + - 6 weight_norm: true - raysampler_AdaptiveRaySampler_args: - n_rays_per_image_sampled_from_mask: 1024 - n_pts_per_ray_training: 0 - n_pts_per_ray_evaluation: 0 - renderer_class_type: SignedDistanceFunctionRenderer - implicit_function_class_type: IdrFeatureField + renderer_SignedDistanceFunctionRenderer_args: + ray_tracer_args: + line_search_step: 0.5 + line_step_iters: 3 + n_secant_steps: 8 + n_steps: 100 + object_bounding_sphere: 8.0 + sdf_threshold: 5.0e-05 + ray_normal_coloring_network_args: + d_in: 9 + d_out: 3 + dims: + - 512 + - 512 + - 512 + - 512 + mode: idr + n_harmonic_functions_dir: 4 + pooled_feature_dim: 0 + weight_norm: true + raysampler_AdaptiveRaySampler_args: + n_rays_per_image_sampled_from_mask: 1024 + n_pts_per_ray_training: 0 + n_pts_per_ray_evaluation: 0 + renderer_class_type: SignedDistanceFunctionRenderer + implicit_function_class_type: IdrFeatureField diff --git 
a/projects/implicitron_trainer/configs/repro_singleseq_nerf_blender.yaml b/projects/implicitron_trainer/configs/repro_singleseq_nerf_blender.yaml new file mode 100644 index 00000000..6f6e2aaa --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_singleseq_nerf_blender.yaml @@ -0,0 +1,38 @@ +defaults: +- repro_singleseq_base +- _self_ +exp_dir: "./data/nerf_blender_publ/${oc.env:BLENDER_SINGLESEQ_CLASS}" +data_source_ImplicitronDataSource_args: + data_loader_map_provider_SequenceDataLoaderMapProvider_args: + dataset_length_train: 100 + dataset_map_provider_class_type: BlenderDatasetMapProvider + dataset_map_provider_BlenderDatasetMapProvider_args: + base_dir: ${oc.env:BLENDER_DATASET_ROOT} + object_name: ${oc.env:BLENDER_SINGLESEQ_CLASS} + path_manager_factory_class_type: PathManagerFactory + n_known_frames_for_test: null + path_manager_factory_PathManagerFactory_args: + silence_logs: true + +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + raysampler_AdaptiveRaySampler_args: + n_rays_per_image_sampled_from_mask: 4096 + scene_extent: 2.0 + renderer_MultiPassEmissionAbsorptionRenderer_args: + n_pts_per_ray_fine_training: 128 + n_pts_per_ray_fine_evaluation: 128 + loss_weights: + loss_rgb_mse: 1.0 + loss_prev_stage_rgb_mse: 1.0 + loss_mask_bce: 0.0 + loss_prev_stage_mask_bce: 0.0 + loss_autodecoder_norm: 0.00 + +optimizer_factory_ImplicitronOptimizerFactory_args: + exponential_lr_step_size: 2000 + +training_loop_ImplicitronTrainingLoop_args: + max_epochs: 2000 + visualize_interval: 0 + validation_interval: 30 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml b/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml index 5a587dbe..38212e35 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml @@ -2,8 +2,9 @@ defaults: - repro_singleseq_wce_base.yaml - repro_feat_extractor_unnormed.yaml - _self_ 
-generic_model_args: - chunk_size_grid: 16000 - view_pooler_enabled: true - raysampler_AdaptiveRaySampler_args: - n_rays_per_image_sampled_from_mask: 850 +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + chunk_size_grid: 16000 + view_pooler_enabled: true + raysampler_AdaptiveRaySampler_args: + n_rays_per_image_sampled_from_mask: 850 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml b/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml index 37b08dfa..8983c26f 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml @@ -2,16 +2,17 @@ defaults: - repro_singleseq_wce_base.yaml - repro_feat_extractor_transformer.yaml - _self_ -generic_model_args: - chunk_size_grid: 16000 - view_pooler_enabled: true - implicit_function_class_type: NeRFormerImplicitFunction - raysampler_AdaptiveRaySampler_args: - n_rays_per_image_sampled_from_mask: 800 - n_pts_per_ray_training: 32 - n_pts_per_ray_evaluation: 32 - renderer_MultiPassEmissionAbsorptionRenderer_args: - n_pts_per_ray_fine_training: 16 - n_pts_per_ray_fine_evaluation: 16 - view_pooler_args: - feature_aggregator_class_type: IdentityFeatureAggregator +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + chunk_size_grid: 16000 + view_pooler_enabled: true + implicit_function_class_type: NeRFormerImplicitFunction + raysampler_AdaptiveRaySampler_args: + n_rays_per_image_sampled_from_mask: 800 + n_pts_per_ray_training: 32 + n_pts_per_ray_evaluation: 32 + renderer_MultiPassEmissionAbsorptionRenderer_args: + n_pts_per_ray_fine_training: 16 + n_pts_per_ray_fine_evaluation: 16 + view_pooler_args: + feature_aggregator_class_type: IdentityFeatureAggregator diff --git a/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml b/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml index 6500f56a..1f60f0b9 100644 --- 
a/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml @@ -1,28 +1,29 @@ defaults: - repro_singleseq_base.yaml - _self_ -generic_model_args: - num_passes: 1 - chunk_size_grid: 32000 - view_pooler_enabled: false - loss_weights: - loss_rgb_mse: 200.0 - loss_prev_stage_rgb_mse: 0.0 - loss_mask_bce: 1.0 - loss_prev_stage_mask_bce: 0.0 - loss_autodecoder_norm: 0.0 - depth_neg_penalty: 10000.0 - raysampler_class_type: NearFarRaySampler - raysampler_NearFarRaySampler_args: - n_rays_per_image_sampled_from_mask: 2048 - min_depth: 0.05 - max_depth: 0.05 - n_pts_per_ray_training: 1 - n_pts_per_ray_evaluation: 1 - stratified_point_sampling_training: false - stratified_point_sampling_evaluation: false - renderer_class_type: LSTMRenderer - implicit_function_class_type: SRNImplicitFunction -solver_args: - breed: adam +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + num_passes: 1 + chunk_size_grid: 32000 + view_pooler_enabled: false + loss_weights: + loss_rgb_mse: 200.0 + loss_prev_stage_rgb_mse: 0.0 + loss_mask_bce: 1.0 + loss_prev_stage_mask_bce: 0.0 + loss_autodecoder_norm: 0.0 + depth_neg_penalty: 10000.0 + raysampler_class_type: NearFarRaySampler + raysampler_NearFarRaySampler_args: + n_rays_per_image_sampled_from_mask: 2048 + min_depth: 0.05 + max_depth: 0.05 + n_pts_per_ray_training: 1 + n_pts_per_ray_evaluation: 1 + stratified_point_sampling_training: false + stratified_point_sampling_evaluation: false + renderer_class_type: LSTMRenderer + implicit_function_class_type: SRNImplicitFunction +optimizer_factory_ImplicitronOptimizerFactory_args: + breed: Adam lr: 5.0e-05 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_srn_noharm.yaml b/projects/implicitron_trainer/configs/repro_singleseq_srn_noharm.yaml index dd81241c..28b7570c 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_srn_noharm.yaml +++ 
b/projects/implicitron_trainer/configs/repro_singleseq_srn_noharm.yaml @@ -1,10 +1,11 @@ defaults: - repro_singleseq_srn.yaml - _self_ -generic_model_args: - num_passes: 1 - implicit_function_SRNImplicitFunction_args: - pixel_generator_args: - n_harmonic_functions: 0 - raymarch_function_args: - n_harmonic_functions: 0 +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + num_passes: 1 + implicit_function_SRNImplicitFunction_args: + pixel_generator_args: + n_harmonic_functions: 0 + raymarch_function_args: + n_harmonic_functions: 0 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml b/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml index 3da29f06..d190c280 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml @@ -2,28 +2,29 @@ defaults: - repro_singleseq_wce_base - repro_feat_extractor_normed.yaml - _self_ -generic_model_args: - num_passes: 1 - chunk_size_grid: 32000 - view_pooler_enabled: true - loss_weights: - loss_rgb_mse: 200.0 - loss_prev_stage_rgb_mse: 0.0 - loss_mask_bce: 1.0 - loss_prev_stage_mask_bce: 0.0 - loss_autodecoder_norm: 0.0 - depth_neg_penalty: 10000.0 - raysampler_class_type: NearFarRaySampler - raysampler_NearFarRaySampler_args: - n_rays_per_image_sampled_from_mask: 2048 - min_depth: 0.05 - max_depth: 0.05 - n_pts_per_ray_training: 1 - n_pts_per_ray_evaluation: 1 - stratified_point_sampling_training: false - stratified_point_sampling_evaluation: false - renderer_class_type: LSTMRenderer - implicit_function_class_type: SRNImplicitFunction -solver_args: - breed: adam +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + num_passes: 1 + chunk_size_grid: 32000 + view_pooler_enabled: true + loss_weights: + loss_rgb_mse: 200.0 + loss_prev_stage_rgb_mse: 0.0 + loss_mask_bce: 1.0 + loss_prev_stage_mask_bce: 0.0 + loss_autodecoder_norm: 0.0 + depth_neg_penalty: 10000.0 + 
raysampler_class_type: NearFarRaySampler + raysampler_NearFarRaySampler_args: + n_rays_per_image_sampled_from_mask: 2048 + min_depth: 0.05 + max_depth: 0.05 + n_pts_per_ray_training: 1 + n_pts_per_ray_evaluation: 1 + stratified_point_sampling_training: false + stratified_point_sampling_evaluation: false + renderer_class_type: LSTMRenderer + implicit_function_class_type: SRNImplicitFunction +optimizer_factory_ImplicitronOptimizerFactory_args: + breed: Adam lr: 5.0e-05 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_srn_wce_noharm.yaml b/projects/implicitron_trainer/configs/repro_singleseq_srn_wce_noharm.yaml index 2a0c3fd9..3fc1254b 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_srn_wce_noharm.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_srn_wce_noharm.yaml @@ -1,10 +1,11 @@ defaults: - repro_singleseq_srn_wce.yaml - _self_ -generic_model_args: - num_passes: 1 - implicit_function_SRNImplicitFunction_args: - pixel_generator_args: - n_harmonic_functions: 0 - raymarch_function_args: - n_harmonic_functions: 0 +model_factory_ImplicitronModelFactory_args: + model_GenericModel_args: + num_passes: 1 + implicit_function_SRNImplicitFunction_args: + pixel_generator_args: + n_harmonic_functions: 0 + raymarch_function_args: + n_harmonic_functions: 0 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml b/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml index 85a5cdfc..f5b174c0 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml @@ -1,7 +1,7 @@ defaults: - repro_singleseq_base - _self_ -data_source_args: +data_source_ImplicitronDataSource_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 10 dataset_length_train: 1000 diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py index 16bea062..02566288 100755 --- 
a/projects/implicitron_trainer/experiment.py +++ b/projects/implicitron_trainer/experiment.py @@ -8,27 +8,28 @@ """" This file is the entry point for launching experiments with Implicitron. -Main functions ---------------- -- `run_training` is the wrapper for the train, val, test loops - and checkpointing -- `trainvalidate` is the inner loop which runs the model forward/backward - pass, visualizations and metric printing - Launch Training --------------- Experiment config .yaml files are located in the -`projects/implicitron_trainer/configs` folder. To launch -an experiment, specify the name of the file. Specific config values can -also be overridden from the command line, for example: +`projects/implicitron_trainer/configs` folder. To launch an experiment, +specify the name of the file. Specific config values can also be overridden +from the command line, for example: ``` ./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84 ``` -To run an experiment on a specific GPU, specify the `gpu_idx` key -in the config file / CLI. To run on a different device, specify the -device in `run_training`. +To run an experiment on a specific GPU, specify the `gpu_idx` key in the +config file / CLI. To run on a different device, specify the device in +`Experiment.run`. + +Main functions +--------------- +- The Experiment class defines `run` which creates the model, optimizer, and other + objects used in training, then starts TrainingLoop's `run` function. +- TrainingLoop takes care of the actual training logic: forward and backward passes, + evaluation and testing, as well as model checkpointing, visualization, and metric + printing. Outputs -------- @@ -45,43 +46,38 @@ The outputs of the experiment are saved and logged in multiple ways: config file. 
""" -import copy -import json import logging import os -import random -import time import warnings -from typing import Any, Dict, Optional, Tuple + +from dataclasses import field import hydra -import lpips -import numpy as np import torch -import tqdm from accelerate import Accelerator from omegaconf import DictConfig, OmegaConf from packaging import version -from pytorch3d.implicitron.dataset import utils as ds_utils -from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap -from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task -from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap -from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate -from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel + +from pytorch3d.implicitron.dataset.data_source import ( + DataSourceBase, + ImplicitronDataSource, +) +from pytorch3d.implicitron.models.generic_model import ImplicitronModelBase + from pytorch3d.implicitron.models.renderer.multipass_ea import ( MultiPassEmissionAbsorptionRenderer, ) from pytorch3d.implicitron.models.renderer.ray_sampler import AdaptiveRaySampler -from pytorch3d.implicitron.tools import model_io, vis_utils from pytorch3d.implicitron.tools.config import ( + Configurable, expand_args_fields, remove_unused_components, + run_auto_creation, ) -from pytorch3d.implicitron.tools.stats import Stats -from pytorch3d.renderer.cameras import CamerasBase -from .impl.experiment_config import ExperimentConfig -from .impl.optimization import init_optimizer +from .impl.model_factory import ModelFactoryBase +from .impl.optimizer_factory import OptimizerFactoryBase +from .impl.training_loop import TrainingLoopBase logger = logging.getLogger(__name__) @@ -100,551 +96,146 @@ except ModuleNotFoundError: no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None -def init_model( - *, - cfg: DictConfig, - accelerator: Optional[Accelerator] = 
None, - force_load: bool = False, - clear_stats: bool = False, - load_model_only: bool = False, -) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]: +class Experiment(Configurable): # pyre-ignore: 13 """ - Returns an instance of `GenericModel`. + This class is at the top level of Implicitron's config hierarchy. Its + members are high-level components necessary for training an implicit + rendering network. - If `cfg.resume` is set or `force_load` is true, - attempts to load the last checkpoint from `cfg.exp_dir`. Failure to do so - will return the model with initial weights, unless `force_load` is passed, - in which case a FileNotFoundError is raised. - - Args: - force_load: If true, force load model from checkpoint even if - cfg.resume is false. - clear_stats: If true, clear the stats object loaded from checkpoint - load_model_only: If true, load only the model weights from checkpoint - and do not load the state of the optimizer and stats. - - Returns: - model: The model with optionally loaded weights from checkpoint - stats: The stats structure (optionally loaded from checkpoint) - optimizer_state: The optimizer state dict containing - `state` and `param_groups` keys (optionally loaded from checkpoint) - - Raise: - FileNotFoundError if `force_load` is passed but checkpoint is not found. + Members: + data_source: An object that produces datasets and dataloaders. + model_factory: An object that produces an implicit rendering model as + well as its corresponding Stats object. + optimizer_factory: An object that produces the optimizer and lr + scheduler. + training_loop: An object that runs training given the outputs produced + by the data_source, model_factory and optimizer_factory. + detect_anomaly: Whether torch.autograd should detect anomalies. Useful + for debugging, but might slow down the training. + exp_dir: Root experimentation directory. Checkpoints and training stats + will be saved here. 
""" - # Initialize the model - if cfg.architecture == "generic": - model = GenericModel(**cfg.generic_model_args) - else: - raise ValueError(f"No such arch {cfg.architecture}.") + data_source: DataSourceBase + data_source_class_type: str = "ImplicitronDataSource" + model_factory: ModelFactoryBase + model_factory_class_type: str = "ImplicitronModelFactory" + optimizer_factory: OptimizerFactoryBase + optimizer_factory_class_type: str = "ImplicitronOptimizerFactory" + training_loop: TrainingLoopBase + training_loop_class_type: str = "ImplicitronTrainingLoop" - # Determine the network outputs that should be logged - if hasattr(model, "log_vars"): - log_vars = copy.deepcopy(list(model.log_vars)) - else: - log_vars = ["objective"] + detect_anomaly: bool = False + exp_dir: str = "./data/default_experiment/" - visdom_env_charts = vis_utils.get_visdom_env(cfg) + "_charts" - - # Init the stats struct - stats = Stats( - log_vars, - visdom_env=visdom_env_charts, - verbose=False, - visdom_server=cfg.visdom_server, - visdom_port=cfg.visdom_port, + hydra: dict = field( + default_factory=lambda: { + "run": {"dir": "."}, # Make hydra not change the working dir. + "output_subdir": None, # disable storing the .hydra logs + } ) - # Retrieve the last checkpoint - if cfg.resume_epoch > 0: - model_path = model_io.get_checkpoint(cfg.exp_dir, cfg.resume_epoch) - else: - model_path = model_io.find_last_checkpoint(cfg.exp_dir) + def __post_init__(self): + run_auto_creation(self) - optimizer_state = None - if model_path is not None: - logger.info("found previous model %s" % model_path) - if force_load or cfg.resume: - logger.info(" -> resuming") + def run(self) -> None: + # Make sure the config settings are self-consistent. 
+ self._check_config_consistent() - map_location = None - if accelerator is not None and not accelerator.is_local_main_process: - map_location = { - "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index - } - if load_model_only: - model_state_dict = torch.load( - model_io.get_model_path(model_path), map_location=map_location - ) - stats_load, optimizer_state = None, None - else: - model_state_dict, stats_load, optimizer_state = model_io.load_model( - model_path, map_location=map_location - ) - - # Determine if stats should be reset - if not clear_stats: - if stats_load is None: - logger.info("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n") - last_epoch = model_io.parse_epoch_from_model_path(model_path) - logger.info(f"Estimated resume epoch = {last_epoch}") - - # Reset the stats struct - for _ in range(last_epoch + 1): - stats.new_epoch() - assert last_epoch == stats.epoch - else: - stats = stats_load - - # Update stats properties incase it was reset on load - stats.visdom_env = visdom_env_charts - stats.visdom_server = cfg.visdom_server - stats.visdom_port = cfg.visdom_port - stats.plot_file = os.path.join(cfg.exp_dir, "train_stats.pdf") - stats.synchronize_logged_vars(log_vars) - else: - logger.info(" -> clearing stats") - - try: - # TODO: fix on creation of the buffers - # after the hack above, this will not pass in most cases - # ... but this is fine for now - model.load_state_dict(model_state_dict, strict=True) - except RuntimeError as e: - logger.error(e) - logger.info("Cant load state dict in strict mode! -> trying non-strict") - model.load_state_dict(model_state_dict, strict=False) - model.log_vars = log_vars + # Initialize the accelerator if desired. 
+ if no_accelerate: + accelerator = None + device = torch.device("cuda:0") else: - logger.info(" -> but not resuming -> starting from scratch") - elif force_load: - raise FileNotFoundError(f"Cannot find a checkpoint in {cfg.exp_dir}!") + accelerator = Accelerator(device_placement=False) + logger.info(accelerator.state) + device = accelerator.device - return model, stats, optimizer_state + logger.info(f"Running experiment on device: {device}") + os.makedirs(self.exp_dir, exist_ok=True) + # set the debug mode + if self.detect_anomaly: + logger.info("Anomaly detection!") + torch.autograd.set_detect_anomaly(self.detect_anomaly) -def trainvalidate( - model, - stats, - epoch, - loader, - optimizer, - validation: bool, - *, - accelerator: Optional[Accelerator], - device: torch.device, - bp_var: str = "objective", - metric_print_interval: int = 5, - visualize_interval: int = 100, - visdom_env_root: str = "trainvalidate", - clip_grad: float = 0.0, - **kwargs, -) -> None: - """ - This is the main loop for training and evaluation including: - model forward pass, loss computation, backward pass and visualization. + # Initialize the datasets and dataloaders. + datasets, dataloaders = self.data_source.get_datasets_and_dataloaders() - Args: - model: The model module optionally loaded from checkpoint - stats: The stats struct, also optionally loaded from checkpoint - epoch: The index of the current epoch - loader: The dataloader to use for the loop - optimizer: The optimizer module optionally loaded from checkpoint - validation: If true, run the loop with the model in eval mode - and skip the backward pass - bp_var: The name of the key in the model output `preds` dict which - should be used as the loss for the backward pass. - metric_print_interval: The batch interval at which the stats should be - logged. 
- visualize_interval: The batch interval at which the visualizations - should be plotted - visdom_env_root: The name of the visdom environment to use for plotting - clip_grad: Optionally clip the gradient norms. - If set to a value <=0.0, no clipping - device: The device on which to run the model. - - Returns: - None - """ - - if validation: - model.eval() - trainmode = "val" - else: - model.train() - trainmode = "train" - - t_start = time.time() - - # get the visdom env name - visdom_env_imgs = visdom_env_root + "_images_" + trainmode - viz = vis_utils.get_visdom_connection( - server=stats.visdom_server, - port=stats.visdom_port, - ) - - # Iterate through the batches - n_batches = len(loader) - for it, net_input in enumerate(loader): - last_iter = it == n_batches - 1 - - # move to gpu where possible (in place) - net_input = net_input.to(device) - - # run the forward pass - if not validation: - optimizer.zero_grad() - preds = model(**{**net_input, "evaluation_mode": EvaluationMode.TRAINING}) - else: - with torch.no_grad(): - preds = model( - **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION} - ) - - # make sure we dont overwrite something - assert all(k not in preds for k in net_input.keys()) - # merge everything into one big dict - preds.update(net_input) - - # update the stats logger - stats.update(preds, time_start=t_start, stat_set=trainmode) - assert stats.it[trainmode] == it, "inconsistent stat iteration number!" - - # print textual status update - if it % metric_print_interval == 0 or last_iter: - stats.print(stat_set=trainmode, max_it=n_batches) - - # visualize results - if ( - (accelerator is None or accelerator.is_local_main_process) - and visualize_interval > 0 - and it % visualize_interval == 0 - ): - prefix = f"e{stats.epoch}_it{stats.it[trainmode]}" - - model.visualize( - viz, - visdom_env_imgs, - preds, - prefix, - ) - - # optimizer step - if not validation: - loss = preds[bp_var] - assert torch.isfinite(loss).all(), "Non-finite loss!" 
- # backprop - if accelerator is None: - loss.backward() - else: - accelerator.backward(loss) - if clip_grad > 0.0: - # Optionally clip the gradient norms. - total_norm = torch.nn.utils.clip_grad_norm( - model.parameters(), clip_grad - ) - if total_norm > clip_grad: - logger.info( - f"Clipping gradient: {total_norm}" - + f" with coef {clip_grad / float(total_norm)}." - ) - - optimizer.step() - - -def run_training(cfg: DictConfig) -> None: - """ - Entry point to run the training and validation loops - based on the specified config file. - """ - - # Initialize the accelerator - if no_accelerate: - accelerator = None - device = torch.device("cuda:0") - else: - accelerator = Accelerator(device_placement=False) - logger.info(accelerator.state) - device = accelerator.device - - logger.info(f"Running experiment on device: {device}") - - # set the debug mode - if cfg.detect_anomaly: - logger.info("Anomaly detection!") - torch.autograd.set_detect_anomaly(cfg.detect_anomaly) - - # create the output folder - os.makedirs(cfg.exp_dir, exist_ok=True) - _seed_all_random_engines(cfg.seed) - remove_unused_components(cfg) - - # dump the exp config to the exp dir - try: - cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml") - OmegaConf.save(config=cfg, f=cfg_filename) - except PermissionError: - warnings.warn("Cant dump config due to insufficient permissions!") - - # setup datasets - datasource = ImplicitronDataSource(**cfg.data_source_args) - datasets, dataloaders = datasource.get_datasets_and_dataloaders() - task = datasource.get_task() - - # init the model - model, stats, optimizer_state = init_model(cfg=cfg, accelerator=accelerator) - start_epoch = stats.epoch + 1 - - # move model to gpu - model.to(device) - - # only run evaluation on the test dataloader - if cfg.eval_only: - _eval_and_dump( - cfg, - task, - datasource.all_train_cameras, - datasets, - dataloaders, - model, - stats, - device=device, + # Init the model and the corresponding Stats object. 
+ model = self.model_factory( + accelerator=accelerator, + exp_dir=self.exp_dir, ) - return - # init the optimizer - optimizer, scheduler = init_optimizer( - model, - optimizer_state=optimizer_state, - last_epoch=start_epoch, - **cfg.solver_args, - ) + stats = self.model_factory.load_stats( + exp_dir=self.exp_dir, + log_vars=model.log_vars, + ) + start_epoch = stats.epoch + 1 - # check the scheduler and stats have been initialized correctly - assert scheduler.last_epoch == stats.epoch + 1 - assert scheduler.last_epoch == start_epoch + model.to(device) - # Wrap all modules in the distributed library - # Note: we don't pass the scheduler to prepare as it - # doesn't need to be stepped at each optimizer step - train_loader = dataloaders.train - val_loader = dataloaders.val - if accelerator is not None: - ( - model, - optimizer, - train_loader, - val_loader, - ) = accelerator.prepare(model, optimizer, train_loader, val_loader) + # Init the optimizer and LR scheduler. + optimizer, scheduler = self.optimizer_factory( + accelerator=accelerator, + exp_dir=self.exp_dir, + last_epoch=start_epoch, + model=model, + ) - past_scheduler_lrs = [] - # loop through epochs - for epoch in range(start_epoch, cfg.solver_args.max_epochs): - # automatic new_epoch and plotting of stats at every epoch start - with stats: - - # Make sure to re-seed random generators to ensure reproducibility - # even after restart. 
- _seed_all_random_engines(cfg.seed + epoch) - - cur_lr = float(scheduler.get_last_lr()[-1]) - logger.info(f"scheduler lr = {cur_lr:1.2e}") - past_scheduler_lrs.append(cur_lr) - - # train loop - trainvalidate( + # Wrap all modules in the distributed library + # Note: we don't pass the scheduler to prepare as it + # doesn't need to be stepped at each optimizer step + train_loader = dataloaders.train + val_loader = dataloaders.val + test_loader = dataloaders.test + if accelerator is not None: + ( model, - stats, - epoch, - train_loader, optimizer, - False, - visdom_env_root=vis_utils.get_visdom_env(cfg), - device=device, - accelerator=accelerator, - **cfg, - ) + train_loader, + val_loader, + ) = accelerator.prepare(model, optimizer, train_loader, val_loader) - # val loop (optional) - if val_loader is not None and epoch % cfg.validation_interval == 0: - trainvalidate( - model, - stats, - epoch, - val_loader, - optimizer, - True, - visdom_env_root=vis_utils.get_visdom_env(cfg), - device=device, - accelerator=accelerator, - **cfg, - ) + task = self.data_source.get_task() + all_train_cameras = self.data_source.all_train_cameras - # eval loop (optional) - if ( - dataloaders.test is not None - and cfg.test_interval > 0 - and epoch % cfg.test_interval == 0 - ): - _run_eval( - model, - datasource.all_train_cameras, - dataloaders.test, - task, - camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks, - device=device, - ) - - assert stats.epoch == epoch, "inconsistent stats!" 
- - # delete previous models if required - # save model only on the main process - if cfg.store_checkpoints and ( - accelerator is None or accelerator.is_local_main_process - ): - if cfg.store_checkpoints_purge > 0: - for prev_epoch in range(epoch - cfg.store_checkpoints_purge): - model_io.purge_epoch(cfg.exp_dir, prev_epoch) - outfile = model_io.get_checkpoint(cfg.exp_dir, epoch) - unwrapped_model = ( - model if accelerator is None else accelerator.unwrap_model(model) - ) - model_io.safe_save_model( - unwrapped_model, stats, outfile, optimizer=optimizer - ) - - scheduler.step() - - new_lr = float(scheduler.get_last_lr()[-1]) - if new_lr != cur_lr: - logger.info(f"LR change! {cur_lr} -> {new_lr}") - - if cfg.test_when_finished: - _eval_and_dump( - cfg, - task, - datasource.all_train_cameras, - datasets, - dataloaders, - model, - stats, + # Enter the main training loop. + self.training_loop.run( + train_loader=train_loader, + val_loader=val_loader, + test_loader=test_loader, + model=model, + optimizer=optimizer, + scheduler=scheduler, + all_train_cameras=all_train_cameras, + accelerator=accelerator, device=device, + exp_dir=self.exp_dir, + stats=stats, + task=task, ) - -def _eval_and_dump( - cfg, - task: Task, - all_train_cameras: Optional[CamerasBase], - datasets: DatasetMap, - dataloaders: DataLoaderMap, - model, - stats, - device, -) -> None: - """ - Run the evaluation loop with the test data loader and - save the predictions to the `exp_dir`. 
- """ - - dataloader = dataloaders.test - - if dataloader is None: - raise ValueError('DataLoaderMap have to contain the "test" entry for eval!') - - results = _run_eval( - model, - all_train_cameras, - dataloader, - task, - camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks, - device=device, - ) - - # add the evaluation epoch to the results - for r in results: - r["eval_epoch"] = int(stats.epoch) - - logger.info("Evaluation results") - evaluate.pretty_print_nvs_metrics(results) - - with open(os.path.join(cfg.exp_dir, "results_test.json"), "w") as f: - json.dump(results, f) - - -def _get_eval_frame_data(frame_data): - """ - Masks the unknown image data to make sure we cannot use it at model evaluation time. - """ - frame_data_for_eval = copy.deepcopy(frame_data) - is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as( - frame_data.image_rgb - )[:, None, None, None] - for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"): - value_masked = getattr(frame_data_for_eval, k).clone() * is_known - setattr(frame_data_for_eval, k, value_masked) - return frame_data_for_eval - - -def _run_eval( - model, - all_train_cameras, - loader, - task: Task, - camera_difficulty_bin_breaks: Tuple[float, float], - device, -): - """ - Run the evaluation loop on the test dataloader - """ - lpips_model = lpips.LPIPS(net="vgg") - lpips_model = lpips_model.to(device) - - model.eval() - - per_batch_eval_results = [] - logger.info("Evaluating model ...") - for frame_data in tqdm.tqdm(loader): - frame_data = frame_data.to(device) - - # mask out the unknown images so that the model does not see them - frame_data_for_eval = _get_eval_frame_data(frame_data) - - with torch.no_grad(): - preds = model( - **{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION} - ) - - # TODO: Cannot use accelerate gather for two reasons:. 
- # (1) TypeError: Can't apply _gpu_gather_one on object of type - # , - # only of nested list/tuple/dicts of objects that satisfy is_torch_tensor. - # (2) Same error above but for frame_data which contains Cameras. - - implicitron_render = copy.deepcopy(preds["implicitron_render"]) - - per_batch_eval_results.append( - evaluate.eval_batch( - frame_data, - implicitron_render, - bg_color="black", - lpips_model=lpips_model, - source_cameras=all_train_cameras, - ) - ) - - _, category_result = evaluate.summarize_nvs_eval_results( - per_batch_eval_results, task, camera_difficulty_bin_breaks - ) - - return category_result["results"] - - -def _seed_all_random_engines(seed: int) -> None: - np.random.seed(seed) - torch.manual_seed(seed) - random.seed(seed) + def _check_config_consistent(self) -> None: + if hasattr(self.optimizer_factory, "resume") and hasattr( + self.model_factory, "resume" + ): + assert ( + # pyre-ignore [16] + not self.optimizer_factory.resume + # pyre-ignore [16] + or self.model_factory.resume + ), "Cannot resume the optimizer without resuming the model." + if hasattr(self.optimizer_factory, "resume_epoch") and hasattr( + self.model_factory, "resume_epoch" + ): + assert ( + # pyre-ignore [16] + self.optimizer_factory.resume_epoch + # pyre-ignore [16] + == self.model_factory.resume_epoch + ), "Optimizer and model must resume from the same epoch." 
def _setup_envvars_for_cluster() -> bool: @@ -678,9 +269,20 @@ def _setup_envvars_for_cluster() -> bool: return True -expand_args_fields(ExperimentConfig) +def dump_cfg(cfg: DictConfig) -> None: + remove_unused_components(cfg) + # dump the exp config to the exp dir + os.makedirs(cfg.exp_dir, exist_ok=True) + try: + cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml") + OmegaConf.save(config=cfg, f=cfg_filename) + except PermissionError: + warnings.warn("Can't dump config due to insufficient permissions!") + + +expand_args_fields(Experiment) cs = hydra.core.config_store.ConfigStore.instance() -cs.store(name="default_config", node=ExperimentConfig) +cs.store(name="default_config", node=Experiment) @hydra.main(config_path="./configs/", config_name="default_config") @@ -694,12 +296,14 @@ def experiment(cfg: DictConfig) -> None: logger.info("Running locally") # TODO: The following may be needed for hydra/submitit it to work - expand_args_fields(GenericModel) + expand_args_fields(ImplicitronModelBase) expand_args_fields(AdaptiveRaySampler) expand_args_fields(MultiPassEmissionAbsorptionRenderer) expand_args_fields(ImplicitronDataSource) - run_training(cfg) + experiment = Experiment(**cfg) + dump_cfg(cfg) + experiment.run() if __name__ == "__main__": diff --git a/projects/implicitron_trainer/impl/experiment_config.py b/projects/implicitron_trainer/impl/experiment_config.py deleted file mode 100644 index 30b36934..00000000 --- a/projects/implicitron_trainer/impl/experiment_config.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -from dataclasses import field -from typing import Any, Dict, Tuple - -from omegaconf import DictConfig -from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource -from pytorch3d.implicitron.models.generic_model import GenericModel -from pytorch3d.implicitron.tools.config import Configurable, get_default_args_field - -from .optimization import init_optimizer - - -class ExperimentConfig(Configurable): - generic_model_args: DictConfig = get_default_args_field(GenericModel) - solver_args: DictConfig = get_default_args_field(init_optimizer) - data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource) - architecture: str = "generic" - detect_anomaly: bool = False - eval_only: bool = False - exp_dir: str = "./data/default_experiment/" - exp_idx: int = 0 - gpu_idx: int = 0 - metric_print_interval: int = 5 - resume: bool = True - resume_epoch: int = -1 - seed: int = 0 - store_checkpoints: bool = True - store_checkpoints_purge: int = 1 - test_interval: int = -1 - test_when_finished: bool = False - validation_interval: int = 1 - visdom_env: str = "" - visdom_port: int = 8097 - visdom_server: str = "http://127.0.0.1" - visualize_interval: int = 1000 - clip_grad: float = 0.0 - camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98 - - hydra: Dict[str, Any] = field( - default_factory=lambda: { - "run": {"dir": "."}, # Make hydra not change the working dir. - "output_subdir": None, # disable storing the .hydra logs - } - ) diff --git a/projects/implicitron_trainer/impl/model_factory.py b/projects/implicitron_trainer/impl/model_factory.py new file mode 100644 index 00000000..9e9e7f01 --- /dev/null +++ b/projects/implicitron_trainer/impl/model_factory.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import os +from typing import List, Optional + +import torch.optim + +from accelerate import Accelerator +from pytorch3d.implicitron.models.base_model import ImplicitronModelBase +from pytorch3d.implicitron.tools import model_io, vis_utils +from pytorch3d.implicitron.tools.config import ( + registry, + ReplaceableBase, + run_auto_creation, +) +from pytorch3d.implicitron.tools.stats import Stats + +logger = logging.getLogger(__name__) + + +class ModelFactoryBase(ReplaceableBase): + def __call__(self, **kwargs) -> ImplicitronModelBase: + """ + Initialize the model (possibly from a previously saved state). + + Returns: An instance of ImplicitronModelBase. + """ + raise NotImplementedError() + + def load_stats(self, **kwargs) -> Stats: + """ + Initialize or load a Stats object. + """ + raise NotImplementedError() + + +@registry.register +class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13] + """ + A factory class that initializes an implicit rendering model. + + Members: + force_load: If True, throw a FileNotFoundError if `resume` is True but + a model checkpoint cannot be found. + model: An ImplicitronModelBase object. + resume: If True, attempt to load the last checkpoint from `exp_dir` + passed to __call__. Failure to do so will return a model with ini- + tial weights unless `force_load` is True. + resume_epoch: If `resume` is True: Resume a model at this epoch, or if + `resume_epoch` <= 0, then resume from the latest checkpoint. + visdom_env: The name of the Visdom environment to use for plotting. + visdom_port: The Visdom port. + visdom_server: Address of the Visdom server. 
+ """ + + force_load: bool = False + model: ImplicitronModelBase + model_class_type: str = "GenericModel" + resume: bool = False + resume_epoch: int = -1 + visdom_env: str = "" + visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097)) + visdom_server: str = "http://127.0.0.1" + + def __post_init__(self): + run_auto_creation(self) + + def __call__( + self, + exp_dir: str, + accelerator: Optional[Accelerator] = None, + ) -> ImplicitronModelBase: + """ + Returns an instance of `ImplicitronModelBase`, possibly loaded from a + checkpoint (if self.resume, self.resume_epoch specify so). + + Args: + exp_dir: Root experiment directory. + accelerator: An Accelerator object. + + Returns: + model: The model with optionally loaded weights from checkpoint + + Raise: + FileNotFoundError if `force_load` is True but checkpoint not found. + """ + # Determine the network outputs that should be logged + if hasattr(self.model, "log_vars"): + log_vars = list(self.model.log_vars) # pyre-ignore [6] + else: + log_vars = ["objective"] + + # Retrieve the last checkpoint + if self.resume_epoch > 0: + model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch) + else: + model_path = model_io.find_last_checkpoint(exp_dir) + + if model_path is not None: + logger.info("found previous model %s" % model_path) + if self.force_load or self.resume: + logger.info(" -> resuming") + + map_location = None + if accelerator is not None and not accelerator.is_local_main_process: + map_location = { + "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index + } + model_state_dict = torch.load( + model_io.get_model_path(model_path), map_location=map_location + ) + + try: + self.model.load_state_dict(model_state_dict, strict=True) + except RuntimeError as e: + logger.error(e) + logger.info( + "Cant load state dict in strict mode! 
-> trying non-strict" + ) + self.model.load_state_dict(model_state_dict, strict=False) + self.model.log_vars = log_vars # pyre-ignore [16] + else: + logger.info(" -> but not resuming -> starting from scratch") + elif self.force_load: + raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!") + + return self.model + + def load_stats( + self, + log_vars: List[str], + exp_dir: str, + clear_stats: bool = False, + **kwargs, + ) -> Stats: + """ + Load Stats that correspond to the model's log_vars. + + Args: + log_vars: A list of variable names to log. Should be a subset of the + `preds` returned by the forward function of the corresponding + ImplicitronModelBase instance. + exp_dir: Root experiment directory. + clear_stats: If True, do not load stats from the checkpoint speci- + fied by self.resume and self.resume_epoch; instead, create a + fresh stats object. + + stats: The stats structure (optionally loaded from checkpoint) + """ + # Init the stats struct + visdom_env_charts = ( + vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts" + ) + stats = Stats( + # log_vars should be a list, but OmegaConf might load them as ListConfig + list(log_vars), + visdom_env=visdom_env_charts, + verbose=False, + visdom_server=self.visdom_server, + visdom_port=self.visdom_port, + ) + if self.resume_epoch > 0: + model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch) + else: + model_path = model_io.find_last_checkpoint(exp_dir) + + if model_path is not None: + stats_path = model_io.get_stats_path(model_path) + stats_load = model_io.load_stats(stats_path) + + # Determine if stats should be reset + if not clear_stats: + if stats_load is None: + logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n") + last_epoch = model_io.parse_epoch_from_model_path(model_path) + logger.info(f"Estimated resume epoch = {last_epoch}") + + # Reset the stats struct + for _ in range(last_epoch + 1): + stats.new_epoch() + assert last_epoch == stats.epoch + else: + stats = 
stats_load + + # Update stats properties incase it was reset on load + stats.visdom_env = visdom_env_charts + stats.visdom_server = self.visdom_server + stats.visdom_port = self.visdom_port + stats.plot_file = os.path.join(exp_dir, "train_stats.pdf") + stats.synchronize_logged_vars(log_vars) + else: + logger.info(" -> clearing stats") + + return stats diff --git a/projects/implicitron_trainer/impl/optimization.py b/projects/implicitron_trainer/impl/optimization.py deleted file mode 100644 index 1c6a589c..00000000 --- a/projects/implicitron_trainer/impl/optimization.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from typing import Any, Dict, Optional, Tuple - -import torch -from pytorch3d.implicitron.models.generic_model import GenericModel -from pytorch3d.implicitron.tools.config import enable_get_default_args - -logger = logging.getLogger(__name__) - - -def init_optimizer( - model: GenericModel, - optimizer_state: Optional[Dict[str, Any]], - last_epoch: int, - breed: str = "adam", - weight_decay: float = 0.0, - lr_policy: str = "multistep", - lr: float = 0.0005, - gamma: float = 0.1, - momentum: float = 0.9, - betas: Tuple[float, ...] = (0.9, 0.999), - milestones: Tuple[int, ...] = (), - max_epochs: int = 1000, -): - """ - Initialize the optimizer (optionally from checkpoint state) - and the learning rate scheduler. - - Args: - model: The model with optionally loaded weights - optimizer_state: The state dict for the optimizer. If None - it has not been loaded from checkpoint - last_epoch: If the model was loaded from checkpoint this will be the - number of the last epoch that was saved - breed: The type of optimizer to use e.g. 
adam - weight_decay: The optimizer weight_decay (L2 penalty on model weights) - lr_policy: The policy to use for learning rate. Currently, only "multistep: - is supported. - lr: The value for the initial learning rate - gamma: Multiplicative factor of learning rate decay - momentum: Momentum factor for SGD optimizer - betas: Coefficients used for computing running averages of gradient and its square - in the Adam optimizer - milestones: List of increasing epoch indices at which the learning rate is - modified - max_epochs: The maximum number of epochs to run the optimizer for - - Returns: - optimizer: Optimizer module, optionally loaded from checkpoint - scheduler: Learning rate scheduler module - - Raise: - ValueError if `breed` or `lr_policy` are not supported. - """ - - # Get the parameters to optimize - if hasattr(model, "_get_param_groups"): # use the model function - # pyre-ignore[29] - p_groups = model._get_param_groups(lr, wd=weight_decay) - else: - allprm = [prm for prm in model.parameters() if prm.requires_grad] - p_groups = [{"params": allprm, "lr": lr}] - - # Intialize the optimizer - if breed == "sgd": - optimizer = torch.optim.SGD( - p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay - ) - elif breed == "adagrad": - optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay) - elif breed == "adam": - optimizer = torch.optim.Adam( - p_groups, lr=lr, betas=betas, weight_decay=weight_decay - ) - else: - raise ValueError("no such solver type %s" % breed) - logger.info(" -> solver type = %s" % breed) - - # Load state from checkpoint - if optimizer_state is not None: - logger.info(" -> setting loaded optimizer state") - optimizer.load_state_dict(optimizer_state) - - # Initialize the learning rate scheduler - if lr_policy == "multistep": - scheduler = torch.optim.lr_scheduler.MultiStepLR( - optimizer, - milestones=milestones, - gamma=gamma, - ) - else: - raise ValueError("no such lr policy %s" % lr_policy) - - # When loading from 
checkpoint, this will make sure that the - # lr is correctly set even after returning - for _ in range(last_epoch): - scheduler.step() - - optimizer.zero_grad() - return optimizer, scheduler - - -enable_get_default_args(init_optimizer) diff --git a/projects/implicitron_trainer/impl/optimizer_factory.py b/projects/implicitron_trainer/impl/optimizer_factory.py new file mode 100644 index 00000000..74de28c0 --- /dev/null +++ b/projects/implicitron_trainer/impl/optimizer_factory.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from typing import Any, Dict, Optional, Tuple + +import torch.optim + +from accelerate import Accelerator + +from pytorch3d.implicitron.models.base_model import ImplicitronModelBase +from pytorch3d.implicitron.tools import model_io +from pytorch3d.implicitron.tools.config import ( + registry, + ReplaceableBase, + run_auto_creation, +) + +logger = logging.getLogger(__name__) + + +class OptimizerFactoryBase(ReplaceableBase): + def __call__( + self, model: ImplicitronModelBase, **kwargs + ) -> Tuple[torch.optim.Optimizer, Any]: + """ + Initialize the optimizer and lr scheduler. + + Args: + model: The model with optionally loaded weights. + + Returns: + An optimizer module (optionally loaded from a checkpoint) and + a learning rate scheduler module (should be a subclass of torch.optim's + lr_scheduler._LRScheduler). + """ + raise NotImplementedError() + + +@registry.register +class ImplicitronOptimizerFactory(OptimizerFactoryBase): + """ + A factory that initializes the optimizer and lr scheduler. + + Members: + betas: Beta parameters for the Adam optimizer. + breed: The type of optimizer to use. We currently support SGD, Adagrad + and Adam. 
+ exponential_lr_step_size: With Exponential policy only, + lr = lr * gamma ** (epoch/step_size) + gamma: Multiplicative factor of learning rate decay. + lr: The value for the initial learning rate. + lr_policy: The policy to use for learning rate. We currently support + MultiStepLR and Exponential policies. + momentum: A momentum value (for SGD only). + multistep_lr_milestones: With MultiStepLR policy only: list of + increasing epoch indices at which the learning rate is modified. + momentum: Momentum factor for SGD optimizer. + resume: If True, attempt to load the last checkpoint from `exp_dir` + passed to __call__. Failure to do so will return a newly initialized + optimizer. + resume_epoch: If `resume` is True: Resume optimizer at this epoch. If + `resume_epoch` <= 0, then resume from the latest checkpoint. + weight_decay: The optimizer weight_decay (L2 penalty on model weights). + """ + + betas: Tuple[float, ...] = (0.9, 0.999) + breed: str = "Adam" + exponential_lr_step_size: int = 250 + gamma: float = 0.1 + lr: float = 0.0005 + lr_policy: str = "MultiStepLR" + momentum: float = 0.9 + multistep_lr_milestones: tuple = () + resume: bool = False + resume_epoch: int = -1 + weight_decay: float = 0.0 + + def __post_init__(self): + run_auto_creation(self) + + def __call__( + self, + last_epoch: int, + model: ImplicitronModelBase, + accelerator: Optional[Accelerator] = None, + exp_dir: Optional[str] = None, + **kwargs, + ) -> Tuple[torch.optim.Optimizer, Any]: + """ + Initialize the optimizer (optionally from a checkpoint) and the lr scheduluer. + + Args: + last_epoch: If the model was loaded from checkpoint this will be the + number of the last epoch that was saved. + model: The model with optionally loaded weights. + accelerator: An optional Accelerator instance. + exp_dir: Root experiment directory. 
+ + Returns: + An optimizer module (optionally loaded from a checkpoint) and + a learning rate scheduler module (should be a subclass of torch.optim's + lr_scheduler._LRScheduler). + """ + # Get the parameters to optimize + if hasattr(model, "_get_param_groups"): # use the model function + # pyre-ignore[29] + p_groups = model._get_param_groups(self.lr, wd=self.weight_decay) + else: + allprm = [prm for prm in model.parameters() if prm.requires_grad] + p_groups = [{"params": allprm, "lr": self.lr}] + + # Intialize the optimizer + if self.breed == "SGD": + optimizer = torch.optim.SGD( + p_groups, + lr=self.lr, + momentum=self.momentum, + weight_decay=self.weight_decay, + ) + elif self.breed == "Adagrad": + optimizer = torch.optim.Adagrad( + p_groups, lr=self.lr, weight_decay=self.weight_decay + ) + elif self.breed == "Adam": + optimizer = torch.optim.Adam( + p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay + ) + else: + raise ValueError("no such solver type %s" % self.breed) + logger.info(" -> solver type = %s" % self.breed) + + # Load state from checkpoint + optimizer_state = self._get_optimizer_state(exp_dir, accelerator) + if optimizer_state is not None: + logger.info(" -> setting loaded optimizer state") + optimizer.load_state_dict(optimizer_state) + + # Initialize the learning rate scheduler + if self.lr_policy.casefold() == "MultiStepLR".casefold(): + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=self.multistep_lr_milestones, + gamma=self.gamma, + ) + elif self.lr_policy.casefold() == "Exponential".casefold(): + scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, + lambda epoch: self.gamma ** (epoch / self.exponential_lr_step_size), + verbose=False, + ) + else: + raise ValueError("no such lr policy %s" % self.lr_policy) + + # When loading from checkpoint, this will make sure that the + # lr is correctly set even after returning. 
+ for _ in range(last_epoch): + scheduler.step() + + optimizer.zero_grad() + + return optimizer, scheduler + + def _get_optimizer_state( + self, + exp_dir: Optional[str], + accelerator: Optional[Accelerator] = None, + ) -> Optional[Dict[str, Any]]: + """ + Load an optimizer state from a checkpoint. + """ + if exp_dir is None or not self.resume: + return None + if self.resume_epoch > 0: + save_path = model_io.get_checkpoint(exp_dir, self.resume_epoch) + else: + save_path = model_io.find_last_checkpoint(exp_dir) + optimizer_state = None + if save_path is not None: + logger.info(f"Found previous optimizer state {save_path}.") + logger.info(" -> resuming") + opt_path = model_io.get_optimizer_path(save_path) + + if os.path.isfile(opt_path): + map_location = None + if accelerator is not None and not accelerator.is_local_main_process: + map_location = { + "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index + } + optimizer_state = torch.load(opt_path, map_location) + else: + optimizer_state = None + return optimizer_state diff --git a/projects/implicitron_trainer/impl/training_loop.py b/projects/implicitron_trainer/impl/training_loop.py new file mode 100644 index 00000000..9a0601a6 --- /dev/null +++ b/projects/implicitron_trainer/impl/training_loop.py @@ -0,0 +1,365 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import random +import time +from typing import Any, Optional + +import numpy as np +import torch +from accelerate import Accelerator +from pytorch3d.implicitron.dataset.data_source import Task +from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase +from pytorch3d.implicitron.models.base_model import ImplicitronModelBase +from pytorch3d.implicitron.models.generic_model import EvaluationMode +from pytorch3d.implicitron.tools import model_io, vis_utils +from pytorch3d.implicitron.tools.config import ( + registry, + ReplaceableBase, + run_auto_creation, +) +from pytorch3d.implicitron.tools.stats import Stats +from pytorch3d.renderer.cameras import CamerasBase +from torch.utils.data import DataLoader + +logger = logging.getLogger(__name__) + + +class TrainingLoopBase(ReplaceableBase): + def run( + self, + train_loader: DataLoader, + val_loader: Optional[DataLoader], + test_loader: Optional[DataLoader], + model: ImplicitronModelBase, + optimizer: torch.optim.Optimizer, + scheduler: Any, + **kwargs, + ) -> None: + raise NotImplementedError() + + +@registry.register +class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13] + """ + Members: + eval_only: If True, only run evaluation using the test dataloader. + evaluator: An EvaluatorBase instance, used to evaluate training results. + max_epochs: Train for this many epochs. Note that if the model was + loaded from a checkpoint, we will restart training at the appropriate + epoch and run for (max_epochs - checkpoint_epoch) epochs. + seed: A random seed to ensure reproducibility. + store_checkpoints: If True, store model and optimizer state checkpoints. + store_checkpoints_purge: If >= 0, remove any checkpoints older or equal + to this many epochs. + test_interval: Evaluate on a test dataloader each `test_interval` epochs. + test_when_finished: If True, evaluate on a test dataloader when training + completes. + validation_interval: Validate each `validation_interval` epochs. 
+ clip_grad: Optionally clip the gradient norms. + If set to a value <=0.0, no clipping + metric_print_interval: The batch interval at which the stats should be + logged. + visualize_interval: The batch interval at which the visualizations + should be plotted + """ + + # Parameters of the outer training loop. + eval_only: bool = False + evaluator: EvaluatorBase + evaluator_class_type: str = "ImplicitronEvaluator" + max_epochs: int = 1000 + seed: int = 0 + store_checkpoints: bool = True + store_checkpoints_purge: int = 1 + test_interval: int = -1 + test_when_finished: bool = False + validation_interval: int = 1 + + # Parameters of a single training-validation step. + clip_grad: float = 0.0 + metric_print_interval: int = 5 + visualize_interval: int = 1000 + + def __post_init__(self): + run_auto_creation(self) + + def run( + self, + *, + train_loader: DataLoader, + val_loader: Optional[DataLoader], + test_loader: Optional[DataLoader], + model: ImplicitronModelBase, + optimizer: torch.optim.Optimizer, + scheduler: Any, + accelerator: Optional[Accelerator], + all_train_cameras: Optional[CamerasBase], + device: torch.device, + exp_dir: str, + stats: Stats, + task: Task, + **kwargs, + ): + """ + Entry point to run the training and validation loops + based on the specified config file. + """ + _seed_all_random_engines(self.seed) + start_epoch = stats.epoch + 1 + assert scheduler.last_epoch == stats.epoch + 1 + assert scheduler.last_epoch == start_epoch + + # only run evaluation on the test dataloader + if self.eval_only: + if test_loader is not None: + self.evaluator.run( + all_train_cameras=all_train_cameras, + dataloader=test_loader, + device=device, + dump_to_json=True, + epoch=stats.epoch, + exp_dir=exp_dir, + model=model, + task=task, + ) + return + else: + raise ValueError( + "Cannot evaluate and dump results to json, no test data provided." 
+ ) + + # loop through epochs + for epoch in range(start_epoch, self.max_epochs): + # automatic new_epoch and plotting of stats at every epoch start + with stats: + + # Make sure to re-seed random generators to ensure reproducibility + # even after restart. + _seed_all_random_engines(self.seed + epoch) + + cur_lr = float(scheduler.get_last_lr()[-1]) + logger.debug(f"scheduler lr = {cur_lr:1.2e}") + + # train loop + self._training_or_validation_epoch( + accelerator=accelerator, + device=device, + epoch=epoch, + loader=train_loader, + model=model, + optimizer=optimizer, + stats=stats, + validation=False, + ) + + # val loop (optional) + if val_loader is not None and epoch % self.validation_interval == 0: + self._training_or_validation_epoch( + accelerator=accelerator, + device=device, + epoch=epoch, + loader=val_loader, + model=model, + optimizer=optimizer, + stats=stats, + validation=True, + ) + + # eval loop (optional) + if ( + test_loader is not None + and self.test_interval > 0 + and epoch % self.test_interval == 0 + ): + self.evaluator.run( + all_train_cameras=all_train_cameras, + device=device, + dataloader=test_loader, + model=model, + task=task, + ) + + assert stats.epoch == epoch, "inconsistent stats!" + self._checkpoint(accelerator, epoch, exp_dir, model, optimizer, stats) + + scheduler.step() + new_lr = float(scheduler.get_last_lr()[-1]) + if new_lr != cur_lr: + logger.info(f"LR change! {cur_lr} -> {new_lr}") + + if self.test_when_finished: + if test_loader is not None: + self.evaluator.run( + all_train_cameras=all_train_cameras, + device=device, + dump_to_json=True, + epoch=stats.epoch, + exp_dir=exp_dir, + dataloader=test_loader, + model=model, + task=task, + ) + else: + raise ValueError( + "Cannot evaluate and dump results to json, no test data provided." 
+ ) + + def _training_or_validation_epoch( + self, + epoch: int, + loader: DataLoader, + model: ImplicitronModelBase, + optimizer: torch.optim.Optimizer, + stats: Stats, + validation: bool, + *, + accelerator: Optional[Accelerator], + bp_var: str = "objective", + device: torch.device, + **kwargs, + ) -> None: + """ + This is the main loop for training and evaluation including: + model forward pass, loss computation, backward pass and visualization. + + Args: + epoch: The index of the current epoch + loader: The dataloader to use for the loop + model: The model module optionally loaded from checkpoint + optimizer: The optimizer module optionally loaded from checkpoint + stats: The stats struct, also optionally loaded from checkpoint + validation: If true, run the loop with the model in eval mode + and skip the backward pass + accelerator: An optional Accelerator instance. + bp_var: The name of the key in the model output `preds` dict which + should be used as the loss for the backward pass. + device: The device on which to run the model. 
+ """ + + if validation: + model.eval() + trainmode = "val" + else: + model.train() + trainmode = "train" + + t_start = time.time() + + # get the visdom env name + visdom_env_imgs = stats.visdom_env + "_images_" + trainmode + viz = vis_utils.get_visdom_connection( + server=stats.visdom_server, + port=stats.visdom_port, + ) + + # Iterate through the batches + n_batches = len(loader) + for it, net_input in enumerate(loader): + last_iter = it == n_batches - 1 + + # move to gpu where possible (in place) + net_input = net_input.to(device) + + # run the forward pass + if not validation: + optimizer.zero_grad() + preds = model( + **{**net_input, "evaluation_mode": EvaluationMode.TRAINING} + ) + else: + with torch.no_grad(): + preds = model( + **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION} + ) + + # make sure we dont overwrite something + assert all(k not in preds for k in net_input.keys()) + # merge everything into one big dict + preds.update(net_input) + + # update the stats logger + stats.update(preds, time_start=t_start, stat_set=trainmode) + # pyre-ignore [16] + assert stats.it[trainmode] == it, "inconsistent stat iteration number!" + + # print textual status update + if it % self.metric_print_interval == 0 or last_iter: + stats.print(stat_set=trainmode, max_it=n_batches) + + # visualize results + if ( + (accelerator is None or accelerator.is_local_main_process) + and self.visualize_interval > 0 + and it % self.visualize_interval == 0 + ): + prefix = f"e{stats.epoch}_it{stats.it[trainmode]}" + if hasattr(model, "visualize"): + # pyre-ignore [29] + model.visualize( + viz, + visdom_env_imgs, + preds, + prefix, + ) + + # optimizer step + if not validation: + loss = preds[bp_var] + assert torch.isfinite(loss).all(), "Non-finite loss!" + # backprop + if accelerator is None: + loss.backward() + else: + accelerator.backward(loss) + if self.clip_grad > 0.0: + # Optionally clip the gradient norms. 
+ total_norm = torch.nn.utils.clip_grad_norm( + model.parameters(), self.clip_grad + ) + if total_norm > self.clip_grad: + logger.debug( + f"Clipping gradient: {total_norm}" + + f" with coef {self.clip_grad / float(total_norm)}." + ) + + optimizer.step() + + def _checkpoint( + self, + accelerator: Optional[Accelerator], + epoch: int, + exp_dir: str, + model: ImplicitronModelBase, + optimizer: torch.optim.Optimizer, + stats: Stats, + ): + """ + Save a model and its corresponding Stats object to a file, if + `self.store_checkpoints` is True. In addition, if + `self.store_checkpoints_purge` is True, remove any checkpoints older + than `self.store_checkpoints_purge` epochs old. + """ + if self.store_checkpoints and ( + accelerator is None or accelerator.is_local_main_process + ): + if self.store_checkpoints_purge > 0: + for prev_epoch in range(epoch - self.store_checkpoints_purge): + model_io.purge_epoch(exp_dir, prev_epoch) + outfile = model_io.get_checkpoint(exp_dir, epoch) + unwrapped_model = ( + model if accelerator is None else accelerator.unwrap_model(model) + ) + model_io.safe_save_model( + unwrapped_model, stats, outfile, optimizer=optimizer + ) + + +def _seed_all_random_engines(seed: int) -> None: + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index 1b2ebeb8..4ef2b510 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -1,296 +1,14 @@ -generic_model_args: - mask_images: true - mask_depths: true - render_image_width: 400 - render_image_height: 400 - mask_threshold: 0.5 - output_rasterized_mc: false - bg_color: - - 0.0 - - 0.0 - - 0.0 - num_passes: 1 - chunk_size_grid: 4096 - render_features_dimensions: 3 - tqdm_trigger_threshold: 16 - n_train_target_views: 1 - sampling_mode_training: mask_sample - sampling_mode_evaluation: full_grid - 
global_encoder_class_type: null - raysampler_class_type: AdaptiveRaySampler - renderer_class_type: MultiPassEmissionAbsorptionRenderer - image_feature_extractor_class_type: null - view_pooler_enabled: false - implicit_function_class_type: NeuralRadianceFieldImplicitFunction - view_metrics_class_type: ViewMetrics - regularization_metrics_class_type: RegularizationMetrics - loss_weights: - loss_rgb_mse: 1.0 - loss_prev_stage_rgb_mse: 1.0 - loss_mask_bce: 0.0 - loss_prev_stage_mask_bce: 0.0 - log_vars: - - loss_rgb_psnr_fg - - loss_rgb_psnr - - loss_rgb_mse - - loss_rgb_huber - - loss_depth_abs - - loss_depth_abs_fg - - loss_mask_neg_iou - - loss_mask_bce - - loss_mask_beta_prior - - loss_eikonal - - loss_density_tv - - loss_depth_neg_penalty - - loss_autodecoder_norm - - loss_prev_stage_rgb_mse - - loss_prev_stage_rgb_psnr_fg - - loss_prev_stage_rgb_psnr - - loss_prev_stage_mask_bce - - objective - - epoch - - sec/it - global_encoder_HarmonicTimeEncoder_args: - n_harmonic_functions: 10 - append_input: true - time_divisor: 1.0 - global_encoder_SequenceAutodecoder_args: - autodecoder_args: - encoding_dim: 0 - n_instances: 0 - init_scale: 1.0 - ignore_input: false - raysampler_AdaptiveRaySampler_args: - image_width: 400 - image_height: 400 - sampling_mode_training: mask_sample - sampling_mode_evaluation: full_grid - n_pts_per_ray_training: 64 - n_pts_per_ray_evaluation: 64 - n_rays_per_image_sampled_from_mask: 1024 - stratified_point_sampling_training: true - stratified_point_sampling_evaluation: false - scene_extent: 8.0 - scene_center: - - 0.0 - - 0.0 - - 0.0 - raysampler_NearFarRaySampler_args: - image_width: 400 - image_height: 400 - sampling_mode_training: mask_sample - sampling_mode_evaluation: full_grid - n_pts_per_ray_training: 64 - n_pts_per_ray_evaluation: 64 - n_rays_per_image_sampled_from_mask: 1024 - stratified_point_sampling_training: true - stratified_point_sampling_evaluation: false - min_depth: 0.1 - max_depth: 8.0 - renderer_LSTMRenderer_args: - 
num_raymarch_steps: 10 - init_depth: 17.0 - init_depth_noise_std: 0.0005 - hidden_size: 16 - n_feature_channels: 256 - bg_color: null - verbose: false - renderer_MultiPassEmissionAbsorptionRenderer_args: - raymarcher_class_type: EmissionAbsorptionRaymarcher - n_pts_per_ray_fine_training: 64 - n_pts_per_ray_fine_evaluation: 64 - stratified_sampling_coarse_training: true - stratified_sampling_coarse_evaluation: false - append_coarse_samples_to_fine: true - density_noise_std_train: 0.0 - return_weights: false - raymarcher_CumsumRaymarcher_args: - surface_thickness: 1 - bg_color: - - 0.0 - background_opacity: 0.0 - density_relu: true - blend_output: false - raymarcher_EmissionAbsorptionRaymarcher_args: - surface_thickness: 1 - bg_color: - - 0.0 - background_opacity: 10000000000.0 - density_relu: true - blend_output: false - renderer_SignedDistanceFunctionRenderer_args: - render_features_dimensions: 3 - ray_tracer_args: - object_bounding_sphere: 1.0 - sdf_threshold: 5.0e-05 - line_search_step: 0.5 - line_step_iters: 1 - sphere_tracing_iters: 10 - n_steps: 100 - n_secant_steps: 8 - ray_normal_coloring_network_args: - feature_vector_size: 3 - mode: idr - d_in: 9 - d_out: 3 - dims: - - 512 - - 512 - - 512 - - 512 - weight_norm: true - n_harmonic_functions_dir: 0 - pooled_feature_dim: 0 - bg_color: - - 0.0 - soft_mask_alpha: 50.0 - image_feature_extractor_ResNetFeatureExtractor_args: - name: resnet34 - pretrained: true - stages: - - 1 - - 2 - - 3 - - 4 - normalize_image: true - image_rescale: 0.16 - first_max_pool: true - proj_dim: 32 - l2_norm: true - add_masks: true - add_images: true - global_average_pool: false - feature_rescale: 1.0 - view_pooler_args: - feature_aggregator_class_type: AngleWeightedReductionFeatureAggregator - view_sampler_args: - masked_sampling: false - sampling_mode: bilinear - feature_aggregator_AngleWeightedIdentityFeatureAggregator_args: - exclude_target_view: true - exclude_target_view_mask_features: true - concatenate_output: true - 
weight_by_ray_angle_gamma: 1.0 - min_ray_angle_weight: 0.1 - feature_aggregator_AngleWeightedReductionFeatureAggregator_args: - exclude_target_view: true - exclude_target_view_mask_features: true - concatenate_output: true - reduction_functions: - - AVG - - STD - weight_by_ray_angle_gamma: 1.0 - min_ray_angle_weight: 0.1 - feature_aggregator_IdentityFeatureAggregator_args: - exclude_target_view: true - exclude_target_view_mask_features: true - concatenate_output: true - feature_aggregator_ReductionFeatureAggregator_args: - exclude_target_view: true - exclude_target_view_mask_features: true - concatenate_output: true - reduction_functions: - - AVG - - STD - implicit_function_IdrFeatureField_args: - feature_vector_size: 3 - d_in: 3 - d_out: 1 - dims: - - 512 - - 512 - - 512 - - 512 - - 512 - - 512 - - 512 - - 512 - geometric_init: true - bias: 1.0 - skip_in: [] - weight_norm: true - n_harmonic_functions_xyz: 0 - pooled_feature_dim: 0 - encoding_dim: 0 - implicit_function_NeRFormerImplicitFunction_args: - n_harmonic_functions_xyz: 10 - n_harmonic_functions_dir: 4 - n_hidden_neurons_dir: 128 - latent_dim: 0 - input_xyz: true - xyz_ray_dir_in_camera_coords: false - color_dim: 3 - transformer_dim_down_factor: 2.0 - n_hidden_neurons_xyz: 80 - n_layers_xyz: 2 - append_xyz: - - 1 - implicit_function_NeuralRadianceFieldImplicitFunction_args: - n_harmonic_functions_xyz: 10 - n_harmonic_functions_dir: 4 - n_hidden_neurons_dir: 128 - latent_dim: 0 - input_xyz: true - xyz_ray_dir_in_camera_coords: false - color_dim: 3 - transformer_dim_down_factor: 1.0 - n_hidden_neurons_xyz: 256 - n_layers_xyz: 8 - append_xyz: - - 5 - implicit_function_SRNHyperNetImplicitFunction_args: - hypernet_args: - n_harmonic_functions: 3 - n_hidden_units: 256 - n_layers: 2 - n_hidden_units_hypernet: 256 - n_layers_hypernet: 1 - in_features: 3 - out_features: 256 - latent_dim_hypernet: 0 - latent_dim: 0 - xyz_in_camera_coords: false - pixel_generator_args: - n_harmonic_functions: 4 - n_hidden_units: 256 - 
n_hidden_units_color: 128 - n_layers: 2 - in_features: 256 - out_features: 3 - ray_dir_in_camera_coords: false - implicit_function_SRNImplicitFunction_args: - raymarch_function_args: - n_harmonic_functions: 3 - n_hidden_units: 256 - n_layers: 2 - in_features: 3 - out_features: 256 - latent_dim: 0 - xyz_in_camera_coords: false - raymarch_function: null - pixel_generator_args: - n_harmonic_functions: 4 - n_hidden_units: 256 - n_hidden_units_color: 128 - n_layers: 2 - in_features: 256 - out_features: 3 - ray_dir_in_camera_coords: false - view_metrics_ViewMetrics_args: {} - regularization_metrics_RegularizationMetrics_args: {} -solver_args: - breed: adam - weight_decay: 0.0 - lr_policy: multistep - lr: 0.0005 - gamma: 0.1 - momentum: 0.9 - betas: - - 0.9 - - 0.999 - milestones: [] - max_epochs: 1000 -data_source_args: +data_source_class_type: ImplicitronDataSource +model_factory_class_type: ImplicitronModelFactory +optimizer_factory_class_type: ImplicitronOptimizerFactory +training_loop_class_type: ImplicitronTrainingLoop +detect_anomaly: false +exp_dir: ./data/default_experiment/ +hydra: + run: + dir: . + output_subdir: null +data_source_ImplicitronDataSource_args: dataset_map_provider_class_type: ??? data_loader_map_provider_class_type: SequenceDataLoaderMapProvider dataset_map_provider_BlenderDatasetMapProvider_args: @@ -396,30 +114,322 @@ data_source_args: sample_consecutive_frames: false consecutive_frames_max_gap: 0 consecutive_frames_max_gap_seconds: 0.1 -architecture: generic -detect_anomaly: false -eval_only: false -exp_dir: ./data/default_experiment/ -exp_idx: 0 -gpu_idx: 0 -metric_print_interval: 5 -resume: true -resume_epoch: -1 -seed: 0 -store_checkpoints: true -store_checkpoints_purge: 1 -test_interval: -1 -test_when_finished: false -validation_interval: 1 -visdom_env: '' -visdom_port: 8097 -visdom_server: http://127.0.0.1 -visualize_interval: 1000 -clip_grad: 0.0 -camera_difficulty_bin_breaks: -- 0.97 -- 0.98 -hydra: - run: - dir: . 
- output_subdir: null +model_factory_ImplicitronModelFactory_args: + force_load: false + model_class_type: GenericModel + resume: false + resume_epoch: -1 + visdom_env: '' + visdom_port: 8097 + visdom_server: http://127.0.0.1 + model_GenericModel_args: + mask_images: true + mask_depths: true + render_image_width: 400 + render_image_height: 400 + mask_threshold: 0.5 + output_rasterized_mc: false + bg_color: + - 0.0 + - 0.0 + - 0.0 + num_passes: 1 + chunk_size_grid: 4096 + render_features_dimensions: 3 + tqdm_trigger_threshold: 16 + n_train_target_views: 1 + sampling_mode_training: mask_sample + sampling_mode_evaluation: full_grid + global_encoder_class_type: null + raysampler_class_type: AdaptiveRaySampler + renderer_class_type: MultiPassEmissionAbsorptionRenderer + image_feature_extractor_class_type: null + view_pooler_enabled: false + implicit_function_class_type: NeuralRadianceFieldImplicitFunction + view_metrics_class_type: ViewMetrics + regularization_metrics_class_type: RegularizationMetrics + loss_weights: + loss_rgb_mse: 1.0 + loss_prev_stage_rgb_mse: 1.0 + loss_mask_bce: 0.0 + loss_prev_stage_mask_bce: 0.0 + log_vars: + - loss_rgb_psnr_fg + - loss_rgb_psnr + - loss_rgb_mse + - loss_rgb_huber + - loss_depth_abs + - loss_depth_abs_fg + - loss_mask_neg_iou + - loss_mask_bce + - loss_mask_beta_prior + - loss_eikonal + - loss_density_tv + - loss_depth_neg_penalty + - loss_autodecoder_norm + - loss_prev_stage_rgb_mse + - loss_prev_stage_rgb_psnr_fg + - loss_prev_stage_rgb_psnr + - loss_prev_stage_mask_bce + - objective + - epoch + - sec/it + global_encoder_HarmonicTimeEncoder_args: + n_harmonic_functions: 10 + append_input: true + time_divisor: 1.0 + global_encoder_SequenceAutodecoder_args: + autodecoder_args: + encoding_dim: 0 + n_instances: 0 + init_scale: 1.0 + ignore_input: false + raysampler_AdaptiveRaySampler_args: + image_width: 400 + image_height: 400 + sampling_mode_training: mask_sample + sampling_mode_evaluation: full_grid + n_pts_per_ray_training: 64 
+ n_pts_per_ray_evaluation: 64 + n_rays_per_image_sampled_from_mask: 1024 + stratified_point_sampling_training: true + stratified_point_sampling_evaluation: false + scene_extent: 8.0 + scene_center: + - 0.0 + - 0.0 + - 0.0 + raysampler_NearFarRaySampler_args: + image_width: 400 + image_height: 400 + sampling_mode_training: mask_sample + sampling_mode_evaluation: full_grid + n_pts_per_ray_training: 64 + n_pts_per_ray_evaluation: 64 + n_rays_per_image_sampled_from_mask: 1024 + stratified_point_sampling_training: true + stratified_point_sampling_evaluation: false + min_depth: 0.1 + max_depth: 8.0 + renderer_LSTMRenderer_args: + num_raymarch_steps: 10 + init_depth: 17.0 + init_depth_noise_std: 0.0005 + hidden_size: 16 + n_feature_channels: 256 + bg_color: null + verbose: false + renderer_MultiPassEmissionAbsorptionRenderer_args: + raymarcher_class_type: EmissionAbsorptionRaymarcher + n_pts_per_ray_fine_training: 64 + n_pts_per_ray_fine_evaluation: 64 + stratified_sampling_coarse_training: true + stratified_sampling_coarse_evaluation: false + append_coarse_samples_to_fine: true + density_noise_std_train: 0.0 + return_weights: false + raymarcher_CumsumRaymarcher_args: + surface_thickness: 1 + bg_color: + - 0.0 + background_opacity: 0.0 + density_relu: true + blend_output: false + raymarcher_EmissionAbsorptionRaymarcher_args: + surface_thickness: 1 + bg_color: + - 0.0 + background_opacity: 10000000000.0 + density_relu: true + blend_output: false + renderer_SignedDistanceFunctionRenderer_args: + render_features_dimensions: 3 + ray_tracer_args: + object_bounding_sphere: 1.0 + sdf_threshold: 5.0e-05 + line_search_step: 0.5 + line_step_iters: 1 + sphere_tracing_iters: 10 + n_steps: 100 + n_secant_steps: 8 + ray_normal_coloring_network_args: + feature_vector_size: 3 + mode: idr + d_in: 9 + d_out: 3 + dims: + - 512 + - 512 + - 512 + - 512 + weight_norm: true + n_harmonic_functions_dir: 0 + pooled_feature_dim: 0 + bg_color: + - 0.0 + soft_mask_alpha: 50.0 + 
image_feature_extractor_ResNetFeatureExtractor_args: + name: resnet34 + pretrained: true + stages: + - 1 + - 2 + - 3 + - 4 + normalize_image: true + image_rescale: 0.16 + first_max_pool: true + proj_dim: 32 + l2_norm: true + add_masks: true + add_images: true + global_average_pool: false + feature_rescale: 1.0 + view_pooler_args: + feature_aggregator_class_type: AngleWeightedReductionFeatureAggregator + view_sampler_args: + masked_sampling: false + sampling_mode: bilinear + feature_aggregator_AngleWeightedIdentityFeatureAggregator_args: + exclude_target_view: true + exclude_target_view_mask_features: true + concatenate_output: true + weight_by_ray_angle_gamma: 1.0 + min_ray_angle_weight: 0.1 + feature_aggregator_AngleWeightedReductionFeatureAggregator_args: + exclude_target_view: true + exclude_target_view_mask_features: true + concatenate_output: true + reduction_functions: + - AVG + - STD + weight_by_ray_angle_gamma: 1.0 + min_ray_angle_weight: 0.1 + feature_aggregator_IdentityFeatureAggregator_args: + exclude_target_view: true + exclude_target_view_mask_features: true + concatenate_output: true + feature_aggregator_ReductionFeatureAggregator_args: + exclude_target_view: true + exclude_target_view_mask_features: true + concatenate_output: true + reduction_functions: + - AVG + - STD + implicit_function_IdrFeatureField_args: + feature_vector_size: 3 + d_in: 3 + d_out: 1 + dims: + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + geometric_init: true + bias: 1.0 + skip_in: [] + weight_norm: true + n_harmonic_functions_xyz: 0 + pooled_feature_dim: 0 + encoding_dim: 0 + implicit_function_NeRFormerImplicitFunction_args: + n_harmonic_functions_xyz: 10 + n_harmonic_functions_dir: 4 + n_hidden_neurons_dir: 128 + latent_dim: 0 + input_xyz: true + xyz_ray_dir_in_camera_coords: false + color_dim: 3 + transformer_dim_down_factor: 2.0 + n_hidden_neurons_xyz: 80 + n_layers_xyz: 2 + append_xyz: + - 1 + implicit_function_NeuralRadianceFieldImplicitFunction_args: + 
n_harmonic_functions_xyz: 10 + n_harmonic_functions_dir: 4 + n_hidden_neurons_dir: 128 + latent_dim: 0 + input_xyz: true + xyz_ray_dir_in_camera_coords: false + color_dim: 3 + transformer_dim_down_factor: 1.0 + n_hidden_neurons_xyz: 256 + n_layers_xyz: 8 + append_xyz: + - 5 + implicit_function_SRNHyperNetImplicitFunction_args: + hypernet_args: + n_harmonic_functions: 3 + n_hidden_units: 256 + n_layers: 2 + n_hidden_units_hypernet: 256 + n_layers_hypernet: 1 + in_features: 3 + out_features: 256 + latent_dim_hypernet: 0 + latent_dim: 0 + xyz_in_camera_coords: false + pixel_generator_args: + n_harmonic_functions: 4 + n_hidden_units: 256 + n_hidden_units_color: 128 + n_layers: 2 + in_features: 256 + out_features: 3 + ray_dir_in_camera_coords: false + implicit_function_SRNImplicitFunction_args: + raymarch_function_args: + n_harmonic_functions: 3 + n_hidden_units: 256 + n_layers: 2 + in_features: 3 + out_features: 256 + latent_dim: 0 + xyz_in_camera_coords: false + raymarch_function: null + pixel_generator_args: + n_harmonic_functions: 4 + n_hidden_units: 256 + n_hidden_units_color: 128 + n_layers: 2 + in_features: 256 + out_features: 3 + ray_dir_in_camera_coords: false + view_metrics_ViewMetrics_args: {} + regularization_metrics_RegularizationMetrics_args: {} +optimizer_factory_ImplicitronOptimizerFactory_args: + betas: + - 0.9 + - 0.999 + breed: Adam + exponential_lr_step_size: 250 + gamma: 0.1 + lr: 0.0005 + lr_policy: MultiStepLR + momentum: 0.9 + multistep_lr_milestones: [] + resume: false + resume_epoch: -1 + weight_decay: 0.0 +training_loop_ImplicitronTrainingLoop_args: + eval_only: false + evaluator_class_type: ImplicitronEvaluator + max_epochs: 1000 + seed: 0 + store_checkpoints: true + store_checkpoints_purge: 1 + test_interval: -1 + test_when_finished: false + validation_interval: 1 + clip_grad: 0.0 + metric_print_interval: 5 + visualize_interval: 1000 + evaluator_ImplicitronEvaluator_args: + camera_difficulty_bin_breaks: + - 0.97 + - 0.98 diff --git 
a/projects/implicitron_trainer/tests/test_experiment.py b/projects/implicitron_trainer/tests/test_experiment.py index 7d3da06f..836a9530 100644 --- a/projects/implicitron_trainer/tests/test_experiment.py +++ b/projects/implicitron_trainer/tests/test_experiment.py @@ -12,6 +12,7 @@ from hydra import compose, initialize_config_dir from omegaconf import OmegaConf from .. import experiment +from .utils import intercept_logs def interactive_testing_requested() -> bool: @@ -33,7 +34,10 @@ DEBUG: bool = False # TODO: # - add enough files to skateboard_first_5 that this works on RE. # - share common code with PyTorch3D tests? -# - deal with the temporary output files this test creates + + +def _parse_float_from_log(line): + return float(line.split()[-1]) class TestExperiment(unittest.TestCase): @@ -44,15 +48,18 @@ class TestExperiment(unittest.TestCase): # Test making minimal changes to the dataclass defaults. if not interactive_testing_requested() or not internal: return - cfg = OmegaConf.structured(experiment.ExperimentConfig) - cfg.data_source_args.dataset_map_provider_class_type = ( + + # Manually override config values. Note that this is not necessary out- + # side of the tests! 
+ cfg = OmegaConf.structured(experiment.Experiment) + cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = ( "JsonIndexDatasetMapProvider" ) dataset_args = ( - cfg.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args + cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args ) dataloader_args = ( - cfg.data_source_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args + cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args ) dataset_args.category = "skateboard" dataset_args.test_restrict_sequence_id = 0 @@ -62,18 +69,80 @@ class TestExperiment(unittest.TestCase): dataset_args.dataset_JsonIndexDataset_args.image_width = 80 dataloader_args.dataset_length_train = 1 dataloader_args.dataset_length_val = 1 - cfg.solver_args.max_epochs = 2 + cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2 + cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False + cfg.optimizer_factory_ImplicitronOptimizerFactory_args.multistep_lr_milestones = [ + 0, + 1, + ] - experiment.run_training(cfg) + if DEBUG: + experiment.dump_cfg(cfg) + with intercept_logs( + logger_name="projects.implicitron_trainer.impl.training_loop", + regexp="LR change!", + ) as intercepted_logs: + experiment_runner = experiment.Experiment(**cfg) + experiment_runner.run() + + # Make sure LR decreased on 0th and 1st epoch 10fold. + self.assertEqual(intercepted_logs[0].split()[-1], "5e-06") + + def test_exponential_lr(self): + # Test making minimal changes to the dataclass defaults. 
+ if not interactive_testing_requested(): + return + cfg = OmegaConf.structured(experiment.Experiment) + cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = ( + "JsonIndexDatasetMapProvider" + ) + dataset_args = ( + cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args + ) + dataloader_args = ( + cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args + ) + dataset_args.category = "skateboard" + dataset_args.test_restrict_sequence_id = 0 + dataset_args.dataset_root = "manifold://co3d/tree/extracted" + dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5 + dataset_args.dataset_JsonIndexDataset_args.image_height = 80 + dataset_args.dataset_JsonIndexDataset_args.image_width = 80 + dataloader_args.dataset_length_train = 1 + dataloader_args.dataset_length_val = 1 + cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2 + cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False + cfg.optimizer_factory_ImplicitronOptimizerFactory_args.lr_policy = "Exponential" + cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = ( + 2 + ) + + if DEBUG: + experiment.dump_cfg(cfg) + with intercept_logs( + logger_name="projects.implicitron_trainer.impl.training_loop", + regexp="LR change!", + ) as intercepted_logs: + experiment_runner = experiment.Experiment(**cfg) + experiment_runner.run() + + # Make sure we followed the exponential lr schedule with gamma=0.1, + # exponential_lr_step_size=2 -- so after two epochs, should + # decrease lr 10x to 5e-5. + self.assertEqual(intercepted_logs[0].split()[-1], "0.00015811388300841897") + self.assertEqual(intercepted_logs[1].split()[-1], "5e-05") def test_yaml_contents(self): - cfg = OmegaConf.structured(experiment.ExperimentConfig) + # Check that the default config values, defined by Experiment and its + # members, is what we expect it to be. 
+ cfg = OmegaConf.structured(experiment.Experiment) yaml = OmegaConf.to_yaml(cfg, sort_keys=False) if DEBUG: (DATA_DIR / "experiment.yaml").write_text(yaml) self.assertEqual(yaml, (DATA_DIR / "experiment.yaml").read_text()) def test_load_configs(self): + # Check that all the pre-prepared configs are valid. config_files = [] for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"): @@ -89,3 +158,17 @@ class TestExperiment(unittest.TestCase): with self.subTest(file.name): with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)): compose(file.name) + + +class TestNerfRepro(unittest.TestCase): + @unittest.skip("This test reproduces full NERF training.") + def test_nerf_blender(self): + # Train vanilla NERF. + # Set env vars BLENDER_DATASET_ROOT and BLENDER_SINGLESEQ_CLASS first! + if not interactive_testing_requested(): + return + with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)): + cfg = compose(config_name="repro_singleseq_nerf_blender", overrides=[]) + experiment_runner = experiment.Experiment(**cfg) + experiment.dump_cfg(cfg) + experiment_runner.run() diff --git a/projects/implicitron_trainer/tests/utils.py b/projects/implicitron_trainer/tests/utils.py new file mode 100644 index 00000000..c31a4938 --- /dev/null +++ b/projects/implicitron_trainer/tests/utils.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import logging +import re +from typing import List + + +@contextlib.contextmanager +def intercept_logs(logger_name: str, regexp: str): + # Intercept logs that match a regexp, from a given logger. 
+ intercepted_messages = [] + logger = logging.getLogger(logger_name) + + class LoggerInterceptor(logging.Filter): + def filter(self, record): + message = record.getMessage() + if re.search(regexp, message): + intercepted_messages.append(message) + return True + + interceptor = LoggerInterceptor() + logger.addFilter(interceptor) + try: + yield intercepted_messages + finally: + logger.removeFilter(interceptor) diff --git a/projects/implicitron_trainer/visualize_reconstruction.py b/projects/implicitron_trainer/visualize_reconstruction.py index 83c10358..f4bb543d 100644 --- a/projects/implicitron_trainer/visualize_reconstruction.py +++ b/projects/implicitron_trainer/visualize_reconstruction.py @@ -22,7 +22,6 @@ import numpy as np import torch import torch.nn.functional as Fu from omegaconf import OmegaConf -from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData from pytorch3d.implicitron.dataset.utils import is_train_frame from pytorch3d.implicitron.models.base_model import EvaluationMode @@ -37,7 +36,7 @@ from pytorch3d.implicitron.tools.vis_utils import ( ) from tqdm import tqdm -from .experiment import init_model +from .experiment import Experiment def render_sequence( @@ -344,13 +343,14 @@ def export_scenes( os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx) # Load the previously trained model - model, _, _ = init_model(cfg=config, force_load=True, load_model_only=True) + experiment = Experiment(config) + model = experiment.model_factory(force_load=True, load_model_only=True) model.cuda() model.eval() # Setup the dataset - datasource = ImplicitronDataSource(**config.data_source_args) - dataset_map = datasource.dataset_map_provider.get_dataset_map() + data_source = experiment.data_source + dataset_map, _ = data_source.get_datasets_and_dataloaders() dataset = dataset_map[split] if dataset is None: raise ValueError(f"{split} dataset not provided") diff --git 
a/pytorch3d/implicitron/dataset/data_source.py b/pytorch3d/implicitron/dataset/data_source.py index a83789cc..9696597a 100644 --- a/pytorch3d/implicitron/dataset/data_source.py +++ b/pytorch3d/implicitron/dataset/data_source.py @@ -40,6 +40,9 @@ class DataSourceBase(ReplaceableBase): """ raise NotImplementedError() + def get_task(self) -> Task: + raise NotImplementedError() + @registry.register class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13] diff --git a/pytorch3d/implicitron/evaluation/evaluator.py b/pytorch3d/implicitron/evaluation/evaluator.py new file mode 100644 index 00000000..12875179 --- /dev/null +++ b/pytorch3d/implicitron/evaluation/evaluator.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy + +import json +import logging +import os +from typing import Any, Dict, List, Optional, Tuple + +import lpips +import torch +import tqdm + +from pytorch3d.implicitron.dataset import utils as ds_utils +from pytorch3d.implicitron.dataset.data_source import Task + +from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate +from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase +from pytorch3d.implicitron.tools.config import ( + registry, + ReplaceableBase, + run_auto_creation, +) +from pytorch3d.renderer.cameras import CamerasBase +from torch.utils.data import DataLoader + +logger = logging.getLogger(__name__) + + +class EvaluatorBase(ReplaceableBase): + """ + Evaluate a trained model on given data. Returns a dict of loss/objective + names and their values. + """ + + def run( + self, model: ImplicitronModelBase, dataloader: DataLoader, **kwargs + ) -> Dict[str, Any]: + """ + Evaluate the results of Implicitron training. 
+ """ + raise NotImplementedError() + + +@registry.register +class ImplicitronEvaluator(EvaluatorBase): + """ + Evaluate the results of Implicitron training. + + Members: + camera_difficulty_bin_breaks: low/medium vals to divide camera difficulties into + [0-eps, low, medium, 1+eps]. + """ + + camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98 + + def __post_init__(self): + run_auto_creation(self) + + def run( + self, + model: ImplicitronModelBase, + dataloader: DataLoader, + task: Task, + all_train_cameras: Optional[CamerasBase], + device: torch.device, + dump_to_json: bool = False, + exp_dir: Optional[str] = None, + epoch: Optional[int] = None, + **kwargs, + ) -> Dict[str, Any]: + """ + Evaluate the results of Implicitron training. Optionally, dump results to + exp_dir/results_test.json. + + Args: + model: A (trained) model to evaluate. + dataloader: A test dataloader. + task: Type of the novel-view synthesis task we're working on. + all_train_cameras: Camera instances we used for training. + device: A torch device. + dump_to_json: If True, will dump the results to a json file. + exp_dir: Root experiment directory. + epoch: Evaluation epoch (to be stored in the results dict). + + Returns: + A dictionary of results. 
+ """ + lpips_model = lpips.LPIPS(net="vgg") + lpips_model = lpips_model.to(device) + + model.eval() + + per_batch_eval_results = [] + logger.info("Evaluating model ...") + for frame_data in tqdm.tqdm(dataloader): + frame_data = frame_data.to(device) + + # mask out the unknown images so that the model does not see them + frame_data_for_eval = _get_eval_frame_data(frame_data) + + with torch.no_grad(): + preds = model( + **{ + **frame_data_for_eval, + "evaluation_mode": EvaluationMode.EVALUATION, + } + ) + implicitron_render = copy.deepcopy(preds["implicitron_render"]) + per_batch_eval_results.append( + evaluate.eval_batch( + frame_data, + implicitron_render, + bg_color="black", + lpips_model=lpips_model, + source_cameras=all_train_cameras, + ) + ) + + _, category_result = evaluate.summarize_nvs_eval_results( + per_batch_eval_results, task, self.camera_difficulty_bin_breaks + ) + + results = category_result["results"] + if dump_to_json: + _dump_to_json(epoch, exp_dir, results) + + return category_result["results"] + + +def _dump_to_json( + epoch: Optional[int], exp_dir: Optional[str], results: List[Dict[str, Any]] +) -> None: + if epoch is not None: + for r in results: + r["eval_epoch"] = int(epoch) + logger.info("Evaluation results") + + evaluate.pretty_print_nvs_metrics(results) + if exp_dir is None: + raise ValueError("Cannot save results to json without a specified save path.") + with open(os.path.join(exp_dir, "results_test.json"), "w") as f: + json.dump(results, f) + + +def _get_eval_frame_data(frame_data: Any) -> Any: + """ + Masks the unknown image data to make sure we cannot use it at model evaluation time. 
+ """ + frame_data_for_eval = copy.deepcopy(frame_data) + is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as( + frame_data.image_rgb + )[:, None, None, None] + for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"): + value_masked = getattr(frame_data_for_eval, k).clone() * is_known + setattr(frame_data_for_eval, k, value_masked) + return frame_data_for_eval diff --git a/pytorch3d/implicitron/models/base_model.py b/pytorch3d/implicitron/models/base_model.py index ffd7d19f..2e5bce2c 100644 --- a/pytorch3d/implicitron/models/base_model.py +++ b/pytorch3d/implicitron/models/base_model.py @@ -37,10 +37,12 @@ class ImplicitronRender: ) -class ImplicitronModelBase(ReplaceableBase): +class ImplicitronModelBase(ReplaceableBase, torch.nn.Module): """ Replaceable abstract base for all image generation / rendering models. - `forward()` method produces a render with a depth map. + `forward()` method produces a render with a depth map. Derives from Module + so we can rely on basic functionality provided by torch for model + optimization. 
""" def __init__(self) -> None: diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py index a0aa96d7..7b85780a 100644 --- a/pytorch3d/implicitron/models/generic_model.py +++ b/pytorch3d/implicitron/models/generic_model.py @@ -16,10 +16,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union import torch import tqdm -from pytorch3d.implicitron.models.metrics import ( # noqa - RegularizationMetrics, +from pytorch3d.implicitron.models.metrics import ( RegularizationMetricsBase, - ViewMetrics, ViewMetricsBase, ) from pytorch3d.implicitron.tools import image_utils, vis_utils @@ -67,7 +65,7 @@ logger = logging.getLogger(__name__) @registry.register -class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 +class GenericModel(ImplicitronModelBase): # pyre-ignore: 13 """ GenericModel is a wrapper for the neural implicit rendering and reconstruction pipeline which consists diff --git a/pytorch3d/implicitron/models/model_dbir.py b/pytorch3d/implicitron/models/model_dbir.py index 62dd0a7f..c14dab9d 100644 --- a/pytorch3d/implicitron/models/model_dbir.py +++ b/pytorch3d/implicitron/models/model_dbir.py @@ -22,7 +22,7 @@ from .renderer.base import EvaluationMode @registry.register -class ModelDBIR(ImplicitronModelBase, torch.nn.Module): +class ModelDBIR(ImplicitronModelBase): """ A simple depth-based image rendering model. diff --git a/pytorch3d/implicitron/models/renderer/ray_sampler.py b/pytorch3d/implicitron/models/renderer/ray_sampler.py index d1f8201d..b876d906 100644 --- a/pytorch3d/implicitron/models/renderer/ray_sampler.py +++ b/pytorch3d/implicitron/models/renderer/ray_sampler.py @@ -218,7 +218,7 @@ class AdaptiveRaySampler(AbstractMaskRaySampler): def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]: """ - Returns the adaptivelly calculated near/far planes. + Returns the adaptively calculated near/far planes. 
""" min_depth, max_depth = camera_utils.get_min_max_depth_bounds( cameras, self._scene_center, self.scene_extent diff --git a/pytorch3d/implicitron/tools/stats.py b/pytorch3d/implicitron/tools/stats.py index 4f01f55f..6907dda0 100644 --- a/pytorch3d/implicitron/tools/stats.py +++ b/pytorch3d/implicitron/tools/stats.py @@ -74,6 +74,7 @@ class Stats(object): """ stats logging object useful for gathering statistics of training a deep net in pytorch Example: + ``` # init stats structure that logs statistics 'objective' and 'top1e' stats = Stats( ('objective','top1e') ) network = init_net() # init a pytorch module (=nueral network) @@ -94,6 +95,7 @@ class Stats(object): # stores the training plots into '/tmp/epoch_stats.pdf' # and plots into a visdom server running at localhost (if running) stats.plot_stats(plot_file='/tmp/epoch_stats.pdf') + ``` """ def __init__( diff --git a/pytorch3d/implicitron/tools/vis_utils.py b/pytorch3d/implicitron/tools/vis_utils.py index 0672257f..585a15a3 100644 --- a/pytorch3d/implicitron/tools/vis_utils.py +++ b/pytorch3d/implicitron/tools/vis_utils.py @@ -14,20 +14,22 @@ from visdom import Visdom logger = logging.getLogger(__name__) -def get_visdom_env(cfg): +def get_visdom_env(visdom_env: str, exp_dir: str) -> str: """ Parse out visdom environment name from the input config. Args: - cfg: The global config file. + visdom_env: Name of the visdom environment, could be empty string. + exp_dir: Root experiment directory. Returns: - visdom_env: The name of the visdom environment. + visdom_env: The name of the visdom environment. If the given visdom_env is + empty, return the name of the bottom directory in exp_dir. """ - if len(cfg.visdom_env) == 0: - visdom_env = cfg.exp_dir.split("/")[-1] + if len(visdom_env) == 0: + visdom_env = exp_dir.split("/")[-1] else: - visdom_env = cfg.visdom_env + visdom_env = visdom_env return visdom_env