Replace pluggable components to create a proper Configurable hierarchy.

Summary:
This large diff rewrites a significant portion of Implicitron's config hierarchy. The new hierarchy, and some of the default implementation classes, are as follows:
```
Experiment
    data_source: ImplicitronDataSource
        dataset_map_provider
        data_loader_map_provider
    model_factory: ImplicitronModelFactory
        model: GenericModel
    optimizer_factory: ImplicitronOptimizerFactory
    training_loop: ImplicitronTrainingLoop
        evaluator: ImplicitronEvaluator
```
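
The config diffs in this commit address each level of this hierarchy with keys of the form `<member>_<ClassName>_args`. As a minimal sketch of that naming scheme (the helper names `args_key` and `override_path` are hypothetical and are not part of Implicitron), the nested key for a model setting could be assembled like this:

```python
from typing import List, Tuple

# Hypothetical mirror of one branch of the hierarchy above:
# (member name, default implementation class) from the root down to the model.
PATH_TO_MODEL: List[Tuple[str, str]] = [
    ("model_factory", "ImplicitronModelFactory"),
    ("model", "GenericModel"),
]


def args_key(member: str, class_name: str) -> str:
    """Return the config key that holds the arguments of a pluggable member."""
    return f"{member}_{class_name}_args"


def override_path(chain: List[Tuple[str, str]], leaf: str) -> str:
    """Join the per-level keys and a leaf field into a dotted config path."""
    return ".".join([args_key(member, cls) for member, cls in chain] + [leaf])


if __name__ == "__main__":
    # Dotted paths like this are what a command-line override would target.
    print(override_path(PATH_TO_MODEL, "render_image_height"))
```

The same pattern yields `training_loop_ImplicitronTrainingLoop_args` and `optimizer_factory_ImplicitronOptimizerFactory_args`, which is where `max_epochs` and the solver settings move in the updated config files.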

1) Experiment (formerly ExperimentConfig) is now a top-level Configurable whose members are mainly high-level factory Configurables, most of them new.
2) Experiment's job is to run the factories, do some accelerate setup, and then pass the results to the main training loop.
3) ImplicitronOptimizerFactory and ImplicitronModelFactory are new high-level factories that create the optimizer, scheduler, model, and stats objects.
4) TrainingLoop is a new Configurable that runs the main training loop and the inner train-validate step.
5) Evaluator is a new Configurable that TrainingLoop uses to run validation/test steps.
6) GenericModel is no longer the only model choice. Instead, ImplicitronModelBase (by default instantiated with GenericModel) is a member of Experiment and can easily be replaced by a user's custom implementation.

All the new Configurables are children of ReplaceableBase, and can be easily replaced with custom implementations.

In addition, I added support for the exponential LR schedule, updated the config files and the test, and added a config file that reproduces NeRF results along with a test that runs the repro experiment.
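
The commit message does not spell out the schedule's formula. As a hedged sketch, assuming the common form in which the learning rate is multiplied by `gamma` once every `step_size` epochs with smooth interpolation in between (the function name `exponential_lr` is hypothetical), the decay could look like this, using the `lr`, `gamma`, and `exponential_lr_step_size` values that appear in the config diffs:

```python
def exponential_lr(base_lr: float, gamma: float, step_size: int, epoch: int) -> float:
    """Assumed schedule: lr decays by a factor of `gamma` every `step_size` epochs."""
    return base_lr * gamma ** (epoch / step_size)


if __name__ == "__main__":
    # lr=0.0005 and gamma=0.1 come from the base config; step size 2000 comes
    # from the new NeRF repro config. The formula itself is an assumption.
    for epoch in (0, 500, 1000, 2000):
        print(epoch, exponential_lr(0.0005, 0.1, 2000, epoch))
```

Under this assumption, a run with `max_epochs: 2000` and a step size of 2000 would end at one tenth of the initial learning rate.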

Reviewed By: bottler

Differential Revision: D37723227

fbshipit-source-id: b36bee880d6aa53efdd2abfaae4489d8ab1e8a27
Krzysztof Chalupka 2022-07-29 17:32:51 -07:00 committed by Facebook GitHub Bot
parent 6b481595f0
commit 1b0584f7bd
42 changed files with 2045 additions and 1478 deletions

@ -2,10 +2,10 @@ defaults:
- default_config - default_config
- _self_ - _self_
exp_dir: ./data/exps/base/ exp_dir: ./data/exps/base/
architecture: generic training_loop_ImplicitronTrainingLoop_args:
visualize_interval: 0 visualize_interval: 0
visdom_port: 8097 max_epochs: 1000
data_source_args: data_source_ImplicitronDataSource_args:
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
dataset_map_provider_class_type: JsonIndexDatasetMapProvider dataset_map_provider_class_type: JsonIndexDatasetMapProvider
data_loader_map_provider_SequenceDataLoaderMapProvider_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args:
@ -21,55 +21,61 @@ data_source_args:
load_point_clouds: false load_point_clouds: false
mask_depths: false mask_depths: false
mask_images: false mask_images: false
generic_model_args: model_factory_ImplicitronModelFactory_args:
loss_weights: visdom_port: 8097
loss_mask_bce: 1.0 model_GenericModel_args:
loss_prev_stage_mask_bce: 1.0 loss_weights:
loss_autodecoder_norm: 0.01 loss_mask_bce: 1.0
loss_rgb_mse: 1.0 loss_prev_stage_mask_bce: 1.0
loss_prev_stage_rgb_mse: 1.0 loss_autodecoder_norm: 0.01
output_rasterized_mc: false loss_rgb_mse: 1.0
chunk_size_grid: 102400 loss_prev_stage_rgb_mse: 1.0
render_image_height: 400 output_rasterized_mc: false
render_image_width: 400 chunk_size_grid: 102400
num_passes: 2 render_image_height: 400
implicit_function_NeuralRadianceFieldImplicitFunction_args: render_image_width: 400
n_harmonic_functions_xyz: 10 num_passes: 2
n_harmonic_functions_dir: 4 implicit_function_NeuralRadianceFieldImplicitFunction_args:
n_hidden_neurons_xyz: 256 n_harmonic_functions_xyz: 10
n_hidden_neurons_dir: 128 n_harmonic_functions_dir: 4
n_layers_xyz: 8 n_hidden_neurons_xyz: 256
append_xyz: n_hidden_neurons_dir: 128
- 5 n_layers_xyz: 8
latent_dim: 0 append_xyz:
raysampler_AdaptiveRaySampler_args: - 5
n_rays_per_image_sampled_from_mask: 1024 latent_dim: 0
scene_extent: 8.0 raysampler_AdaptiveRaySampler_args:
n_pts_per_ray_training: 64 n_rays_per_image_sampled_from_mask: 1024
n_pts_per_ray_evaluation: 64 scene_extent: 8.0
stratified_point_sampling_training: true n_pts_per_ray_training: 64
stratified_point_sampling_evaluation: false n_pts_per_ray_evaluation: 64
renderer_MultiPassEmissionAbsorptionRenderer_args: stratified_point_sampling_training: true
n_pts_per_ray_fine_training: 64 stratified_point_sampling_evaluation: false
n_pts_per_ray_fine_evaluation: 64 renderer_MultiPassEmissionAbsorptionRenderer_args:
append_coarse_samples_to_fine: true n_pts_per_ray_fine_training: 64
density_noise_std_train: 1.0 n_pts_per_ray_fine_evaluation: 64
view_pooler_args: append_coarse_samples_to_fine: true
view_sampler_args: density_noise_std_train: 1.0
masked_sampling: false view_pooler_args:
image_feature_extractor_ResNetFeatureExtractor_args: view_sampler_args:
stages: masked_sampling: false
- 1 image_feature_extractor_ResNetFeatureExtractor_args:
- 2 stages:
- 3 - 1
- 4 - 2
proj_dim: 16 - 3
image_rescale: 0.32 - 4
first_max_pool: false proj_dim: 16
solver_args: image_rescale: 0.32
breed: adam first_max_pool: false
lr: 0.0005 optimizer_factory_ImplicitronOptimizerFactory_args:
lr_policy: multistep breed: Adam
max_epochs: 2000
momentum: 0.9
weight_decay: 0.0 weight_decay: 0.0
lr_policy: MultiStepLR
multistep_lr_milestones: []
lr: 0.0005
gamma: 0.1
momentum: 0.9
betas:
- 0.9
- 0.999

@ -1,17 +1,18 @@
generic_model_args: model_factory_ImplicitronModelFactory_args:
image_feature_extractor_class_type: ResNetFeatureExtractor model_GenericModel_args:
image_feature_extractor_ResNetFeatureExtractor_args: image_feature_extractor_class_type: ResNetFeatureExtractor
add_images: true image_feature_extractor_ResNetFeatureExtractor_args:
add_masks: true add_images: true
first_max_pool: true add_masks: true
image_rescale: 0.375 first_max_pool: true
l2_norm: true image_rescale: 0.375
name: resnet34 l2_norm: true
normalize_image: true name: resnet34
pretrained: true normalize_image: true
stages: pretrained: true
- 1 stages:
- 2 - 1
- 3 - 2
- 4 - 3
proj_dim: 32 - 4
proj_dim: 32

@ -1,17 +1,18 @@
generic_model_args: model_factory_ImplicitronModelFactory_args:
image_feature_extractor_class_type: ResNetFeatureExtractor model_GenericModel_args:
image_feature_extractor_ResNetFeatureExtractor_args: image_feature_extractor_class_type: ResNetFeatureExtractor
add_images: true image_feature_extractor_ResNetFeatureExtractor_args:
add_masks: true add_images: true
first_max_pool: false add_masks: true
image_rescale: 0.375 first_max_pool: false
l2_norm: true image_rescale: 0.375
name: resnet34 l2_norm: true
normalize_image: true name: resnet34
pretrained: true normalize_image: true
stages: pretrained: true
- 1 stages:
- 2 - 1
- 3 - 2
- 4 - 3
proj_dim: 16 - 4
proj_dim: 16

@ -1,18 +1,19 @@
generic_model_args: model_factory_ImplicitronModelFactory_args:
image_feature_extractor_class_type: ResNetFeatureExtractor model_GenericModel_args:
image_feature_extractor_ResNetFeatureExtractor_args: image_feature_extractor_class_type: ResNetFeatureExtractor
stages: image_feature_extractor_ResNetFeatureExtractor_args:
- 1 stages:
- 2 - 1
- 3 - 2
first_max_pool: false - 3
proj_dim: -1 first_max_pool: false
l2_norm: false proj_dim: -1
image_rescale: 0.375 l2_norm: false
name: resnet34 image_rescale: 0.375
normalize_image: true name: resnet34
pretrained: true normalize_image: true
view_pooler_args: pretrained: true
feature_aggregator_AngleWeightedReductionFeatureAggregator_args: view_pooler_args:
reduction_functions: feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
- AVG reduction_functions:
- AVG

@ -1,7 +1,7 @@
defaults: defaults:
- repro_base.yaml - repro_base.yaml
- _self_ - _self_
data_source_args: data_source_ImplicitronDataSource_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 10 batch_size: 10
dataset_length_train: 1000 dataset_length_train: 1000
@ -26,10 +26,12 @@ data_source_args:
n_frames_per_sequence: -1 n_frames_per_sequence: -1
test_on_train: true test_on_train: true
test_restrict_sequence_id: 0 test_restrict_sequence_id: 0
solver_args: optimizer_factory_ImplicitronOptimizerFactory_args:
max_epochs: 3000 multistep_lr_milestones:
milestones:
- 1000 - 1000
camera_difficulty_bin_breaks: training_loop_ImplicitronTrainingLoop_args:
- 0.666667 max_epochs: 3000
- 0.833334 evaluator_ImplicitronEvaluator_args:
camera_difficulty_bin_breaks:
- 0.666667
- 0.833334

@ -1,65 +1,66 @@
defaults: defaults:
- repro_multiseq_base.yaml - repro_multiseq_base.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
loss_weights: model_GenericModel_args:
loss_mask_bce: 100.0 loss_weights:
loss_kl: 0.0 loss_mask_bce: 100.0
loss_rgb_mse: 1.0 loss_kl: 0.0
loss_eikonal: 0.1 loss_rgb_mse: 1.0
chunk_size_grid: 65536 loss_eikonal: 0.1
num_passes: 1 chunk_size_grid: 65536
output_rasterized_mc: true num_passes: 1
sampling_mode_training: mask_sample output_rasterized_mc: true
global_encoder_class_type: SequenceAutodecoder sampling_mode_training: mask_sample
global_encoder_SequenceAutodecoder_args: global_encoder_class_type: SequenceAutodecoder
autodecoder_args: global_encoder_SequenceAutodecoder_args:
n_instances: 20000 autodecoder_args:
init_scale: 1.0 n_instances: 20000
encoding_dim: 256 init_scale: 1.0
implicit_function_IdrFeatureField_args: encoding_dim: 256
n_harmonic_functions_xyz: 6 implicit_function_IdrFeatureField_args:
bias: 0.6 n_harmonic_functions_xyz: 6
d_in: 3 bias: 0.6
d_out: 1 d_in: 3
dims: d_out: 1
- 512
- 512
- 512
- 512
- 512
- 512
- 512
- 512
geometric_init: true
pooled_feature_dim: 0
skip_in:
- 6
weight_norm: true
renderer_SignedDistanceFunctionRenderer_args:
ray_tracer_args:
line_search_step: 0.5
line_step_iters: 3
n_secant_steps: 8
n_steps: 100
object_bounding_sphere: 8.0
sdf_threshold: 5.0e-05
ray_normal_coloring_network_args:
d_in: 9
d_out: 3
dims: dims:
- 512 - 512
- 512 - 512
- 512 - 512
- 512 - 512
mode: idr - 512
n_harmonic_functions_dir: 4 - 512
- 512
- 512
geometric_init: true
pooled_feature_dim: 0 pooled_feature_dim: 0
skip_in:
- 6
weight_norm: true weight_norm: true
raysampler_AdaptiveRaySampler_args: renderer_SignedDistanceFunctionRenderer_args:
n_rays_per_image_sampled_from_mask: 1024 ray_tracer_args:
n_pts_per_ray_training: 0 line_search_step: 0.5
n_pts_per_ray_evaluation: 0 line_step_iters: 3
scene_extent: 8.0 n_secant_steps: 8
renderer_class_type: SignedDistanceFunctionRenderer n_steps: 100
implicit_function_class_type: IdrFeatureField object_bounding_sphere: 8.0
sdf_threshold: 5.0e-05
ray_normal_coloring_network_args:
d_in: 9
d_out: 3
dims:
- 512
- 512
- 512
- 512
mode: idr
n_harmonic_functions_dir: 4
pooled_feature_dim: 0
weight_norm: true
raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 1024
n_pts_per_ray_training: 0
n_pts_per_ray_evaluation: 0
scene_extent: 8.0
renderer_class_type: SignedDistanceFunctionRenderer
implicit_function_class_type: IdrFeatureField

@ -1,11 +1,12 @@
defaults: defaults:
- repro_multiseq_base.yaml - repro_multiseq_base.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
chunk_size_grid: 16000 model_GenericModel_args:
view_pooler_enabled: false chunk_size_grid: 16000
global_encoder_class_type: SequenceAutodecoder view_pooler_enabled: false
global_encoder_SequenceAutodecoder_args: global_encoder_class_type: SequenceAutodecoder
autodecoder_args: global_encoder_SequenceAutodecoder_args:
n_instances: 20000 autodecoder_args:
encoding_dim: 256 n_instances: 20000
encoding_dim: 256

@ -2,9 +2,11 @@ defaults:
- repro_multiseq_base.yaml - repro_multiseq_base.yaml
- repro_feat_extractor_unnormed.yaml - repro_feat_extractor_unnormed.yaml
- _self_ - _self_
clip_grad: 1.0 model_factory_ImplicitronModelFactory_args:
generic_model_args: model_GenericModel_args:
chunk_size_grid: 16000 chunk_size_grid: 16000
view_pooler_enabled: true view_pooler_enabled: true
raysampler_AdaptiveRaySampler_args: raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 850 n_rays_per_image_sampled_from_mask: 850
training_loop_ImplicitronTrainingLoop_args:
clip_grad: 1.0

@ -2,16 +2,17 @@ defaults:
- repro_multiseq_base.yaml - repro_multiseq_base.yaml
- repro_feat_extractor_transformer.yaml - repro_feat_extractor_transformer.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
chunk_size_grid: 16000 model_GenericModel_args:
raysampler_AdaptiveRaySampler_args: chunk_size_grid: 16000
n_rays_per_image_sampled_from_mask: 800 raysampler_AdaptiveRaySampler_args:
n_pts_per_ray_training: 32 n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_evaluation: 32 n_pts_per_ray_training: 32
renderer_MultiPassEmissionAbsorptionRenderer_args: n_pts_per_ray_evaluation: 32
n_pts_per_ray_fine_training: 16 renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_evaluation: 16 n_pts_per_ray_fine_training: 16
implicit_function_class_type: NeRFormerImplicitFunction n_pts_per_ray_fine_evaluation: 16
view_pooler_enabled: true implicit_function_class_type: NeRFormerImplicitFunction
view_pooler_args: view_pooler_enabled: true
feature_aggregator_class_type: IdentityFeatureAggregator view_pooler_args:
feature_aggregator_class_type: IdentityFeatureAggregator

@ -1,6 +1,7 @@
defaults: defaults:
- repro_multiseq_nerformer.yaml - repro_multiseq_nerformer.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
view_pooler_args: model_GenericModel_args:
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator view_pooler_args:
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator

@ -1,34 +1,35 @@
defaults: defaults:
- repro_multiseq_base.yaml - repro_multiseq_base.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
chunk_size_grid: 16000 model_GenericModel_args:
view_pooler_enabled: false chunk_size_grid: 16000
n_train_target_views: -1 view_pooler_enabled: false
num_passes: 1 n_train_target_views: -1
loss_weights: num_passes: 1
loss_rgb_mse: 200.0 loss_weights:
loss_prev_stage_rgb_mse: 0.0 loss_rgb_mse: 200.0
loss_mask_bce: 1.0 loss_prev_stage_rgb_mse: 0.0
loss_prev_stage_mask_bce: 0.0 loss_mask_bce: 1.0
loss_autodecoder_norm: 0.001 loss_prev_stage_mask_bce: 0.0
depth_neg_penalty: 10000.0 loss_autodecoder_norm: 0.001
global_encoder_class_type: SequenceAutodecoder depth_neg_penalty: 10000.0
global_encoder_SequenceAutodecoder_args: global_encoder_class_type: SequenceAutodecoder
autodecoder_args: global_encoder_SequenceAutodecoder_args:
encoding_dim: 256 autodecoder_args:
n_instances: 20000 encoding_dim: 256
raysampler_class_type: NearFarRaySampler n_instances: 20000
raysampler_NearFarRaySampler_args: raysampler_class_type: NearFarRaySampler
n_rays_per_image_sampled_from_mask: 2048 raysampler_NearFarRaySampler_args:
min_depth: 0.05 n_rays_per_image_sampled_from_mask: 2048
max_depth: 0.05 min_depth: 0.05
n_pts_per_ray_training: 1 max_depth: 0.05
n_pts_per_ray_evaluation: 1 n_pts_per_ray_training: 1
stratified_point_sampling_training: false n_pts_per_ray_evaluation: 1
stratified_point_sampling_evaluation: false stratified_point_sampling_training: false
renderer_class_type: LSTMRenderer stratified_point_sampling_evaluation: false
implicit_function_class_type: SRNHyperNetImplicitFunction renderer_class_type: LSTMRenderer
solver_args: implicit_function_class_type: SRNHyperNetImplicitFunction
breed: adam optimizer_factory_ImplicitronOptimizerFactory_args:
breed: Adam
lr: 5.0e-05 lr: 5.0e-05

@ -1,10 +1,11 @@
defaults: defaults:
- repro_multiseq_srn_ad_hypernet.yaml - repro_multiseq_srn_ad_hypernet.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
num_passes: 1 model_GenericModel_args:
implicit_function_SRNHyperNetImplicitFunction_args: num_passes: 1
pixel_generator_args: implicit_function_SRNHyperNetImplicitFunction_args:
n_harmonic_functions: 0 pixel_generator_args:
hypernet_args: n_harmonic_functions: 0
n_harmonic_functions: 0 hypernet_args:
n_harmonic_functions: 0

@ -2,29 +2,30 @@ defaults:
- repro_multiseq_base.yaml - repro_multiseq_base.yaml
- repro_feat_extractor_normed.yaml - repro_feat_extractor_normed.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
chunk_size_grid: 32000 model_GenericModel_args:
num_passes: 1 chunk_size_grid: 32000
n_train_target_views: -1 num_passes: 1
loss_weights: n_train_target_views: -1
loss_rgb_mse: 200.0 loss_weights:
loss_prev_stage_rgb_mse: 0.0 loss_rgb_mse: 200.0
loss_mask_bce: 1.0 loss_prev_stage_rgb_mse: 0.0
loss_prev_stage_mask_bce: 0.0 loss_mask_bce: 1.0
loss_autodecoder_norm: 0.0 loss_prev_stage_mask_bce: 0.0
depth_neg_penalty: 10000.0 loss_autodecoder_norm: 0.0
raysampler_class_type: NearFarRaySampler depth_neg_penalty: 10000.0
raysampler_NearFarRaySampler_args: raysampler_class_type: NearFarRaySampler
n_rays_per_image_sampled_from_mask: 2048 raysampler_NearFarRaySampler_args:
min_depth: 0.05 n_rays_per_image_sampled_from_mask: 2048
max_depth: 0.05 min_depth: 0.05
n_pts_per_ray_training: 1 max_depth: 0.05
n_pts_per_ray_evaluation: 1 n_pts_per_ray_training: 1
stratified_point_sampling_training: false n_pts_per_ray_evaluation: 1
stratified_point_sampling_evaluation: false stratified_point_sampling_training: false
renderer_class_type: LSTMRenderer stratified_point_sampling_evaluation: false
implicit_function_class_type: SRNImplicitFunction renderer_class_type: LSTMRenderer
view_pooler_enabled: true implicit_function_class_type: SRNImplicitFunction
solver_args: view_pooler_enabled: true
breed: adam optimizer_factory_ImplicitronOptimizerFactory_args:
breed: Adam
lr: 5.0e-05 lr: 5.0e-05

@ -1,10 +1,11 @@
defaults: defaults:
- repro_multiseq_srn_wce.yaml - repro_multiseq_srn_wce.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
num_passes: 1 model_GenericModel_args:
implicit_function_SRNImplicitFunction_args: num_passes: 1
pixel_generator_args: implicit_function_SRNImplicitFunction_args:
n_harmonic_functions: 0 pixel_generator_args:
raymarch_function_args: n_harmonic_functions: 0
n_harmonic_functions: 0 raymarch_function_args:
n_harmonic_functions: 0

@ -1,7 +1,7 @@
defaults: defaults:
- repro_base - repro_base
- _self_ - _self_
data_source_args: data_source_ImplicitronDataSource_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 1 batch_size: 1
dataset_length_train: 1000 dataset_length_train: 1000
@ -12,28 +12,30 @@ data_source_args:
n_frames_per_sequence: -1 n_frames_per_sequence: -1
test_restrict_sequence_id: 0 test_restrict_sequence_id: 0
test_on_train: false test_on_train: false
generic_model_args: model_factory_ImplicitronModelFactory_args:
render_image_height: 800 model_GenericModel_args:
render_image_width: 800 render_image_height: 800
log_vars: render_image_width: 800
- loss_rgb_psnr_fg log_vars:
- loss_rgb_psnr - loss_rgb_psnr_fg
- loss_eikonal - loss_rgb_psnr
- loss_prev_stage_rgb_psnr - loss_eikonal
- loss_mask_bce - loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce - loss_mask_bce
- loss_rgb_mse - loss_prev_stage_mask_bce
- loss_prev_stage_rgb_mse - loss_rgb_mse
- loss_depth_abs - loss_prev_stage_rgb_mse
- loss_depth_abs_fg - loss_depth_abs
- loss_kl - loss_depth_abs_fg
- loss_mask_neg_iou - loss_kl
- objective - loss_mask_neg_iou
- epoch - objective
- sec/it - epoch
solver_args: - sec/it
optimizer_factory_ImplicitronOptimizerFactory_args:
lr: 0.0005 lr: 0.0005
max_epochs: 400 multistep_lr_milestones:
milestones:
- 200 - 200
- 300 - 300
training_loop_ImplicitronTrainingLoop_args:
max_epochs: 400

@ -1,57 +1,58 @@
defaults: defaults:
- repro_singleseq_base - repro_singleseq_base
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
loss_weights: model_GenericModel_args:
loss_mask_bce: 100.0 loss_weights:
loss_kl: 0.0 loss_mask_bce: 100.0
loss_rgb_mse: 1.0 loss_kl: 0.0
loss_eikonal: 0.1 loss_rgb_mse: 1.0
chunk_size_grid: 65536 loss_eikonal: 0.1
num_passes: 1 chunk_size_grid: 65536
view_pooler_enabled: false num_passes: 1
implicit_function_IdrFeatureField_args: view_pooler_enabled: false
n_harmonic_functions_xyz: 6 implicit_function_IdrFeatureField_args:
bias: 0.6 n_harmonic_functions_xyz: 6
d_in: 3 bias: 0.6
d_out: 1 d_in: 3
dims: d_out: 1
- 512
- 512
- 512
- 512
- 512
- 512
- 512
- 512
geometric_init: true
pooled_feature_dim: 0
skip_in:
- 6
weight_norm: true
renderer_SignedDistanceFunctionRenderer_args:
ray_tracer_args:
line_search_step: 0.5
line_step_iters: 3
n_secant_steps: 8
n_steps: 100
object_bounding_sphere: 8.0
sdf_threshold: 5.0e-05
ray_normal_coloring_network_args:
d_in: 9
d_out: 3
dims: dims:
- 512 - 512
- 512 - 512
- 512 - 512
- 512 - 512
mode: idr - 512
n_harmonic_functions_dir: 4 - 512
- 512
- 512
geometric_init: true
pooled_feature_dim: 0 pooled_feature_dim: 0
skip_in:
- 6
weight_norm: true weight_norm: true
raysampler_AdaptiveRaySampler_args: renderer_SignedDistanceFunctionRenderer_args:
n_rays_per_image_sampled_from_mask: 1024 ray_tracer_args:
n_pts_per_ray_training: 0 line_search_step: 0.5
n_pts_per_ray_evaluation: 0 line_step_iters: 3
renderer_class_type: SignedDistanceFunctionRenderer n_secant_steps: 8
implicit_function_class_type: IdrFeatureField n_steps: 100
object_bounding_sphere: 8.0
sdf_threshold: 5.0e-05
ray_normal_coloring_network_args:
d_in: 9
d_out: 3
dims:
- 512
- 512
- 512
- 512
mode: idr
n_harmonic_functions_dir: 4
pooled_feature_dim: 0
weight_norm: true
raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 1024
n_pts_per_ray_training: 0
n_pts_per_ray_evaluation: 0
renderer_class_type: SignedDistanceFunctionRenderer
implicit_function_class_type: IdrFeatureField

@ -0,0 +1,38 @@
defaults:
- repro_singleseq_base
- _self_
exp_dir: "./data/nerf_blender_publ/${oc.env:BLENDER_SINGLESEQ_CLASS}"
data_source_ImplicitronDataSource_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args:
dataset_length_train: 100
dataset_map_provider_class_type: BlenderDatasetMapProvider
dataset_map_provider_BlenderDatasetMapProvider_args:
base_dir: ${oc.env:BLENDER_DATASET_ROOT}
object_name: ${oc.env:BLENDER_SINGLESEQ_CLASS}
path_manager_factory_class_type: PathManagerFactory
n_known_frames_for_test: null
path_manager_factory_PathManagerFactory_args:
silence_logs: true
model_factory_ImplicitronModelFactory_args:
model_GenericModel_args:
raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 4096
scene_extent: 2.0
renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 128
n_pts_per_ray_fine_evaluation: 128
loss_weights:
loss_rgb_mse: 1.0
loss_prev_stage_rgb_mse: 1.0
loss_mask_bce: 0.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.00
optimizer_factory_ImplicitronOptimizerFactory_args:
exponential_lr_step_size: 2000
training_loop_ImplicitronTrainingLoop_args:
max_epochs: 2000
visualize_interval: 0
validation_interval: 30

@ -2,8 +2,9 @@ defaults:
- repro_singleseq_wce_base.yaml - repro_singleseq_wce_base.yaml
- repro_feat_extractor_unnormed.yaml - repro_feat_extractor_unnormed.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
chunk_size_grid: 16000 model_GenericModel_args:
view_pooler_enabled: true chunk_size_grid: 16000
raysampler_AdaptiveRaySampler_args: view_pooler_enabled: true
n_rays_per_image_sampled_from_mask: 850 raysampler_AdaptiveRaySampler_args:
n_rays_per_image_sampled_from_mask: 850

@ -2,16 +2,17 @@ defaults:
- repro_singleseq_wce_base.yaml - repro_singleseq_wce_base.yaml
- repro_feat_extractor_transformer.yaml - repro_feat_extractor_transformer.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
chunk_size_grid: 16000 model_GenericModel_args:
view_pooler_enabled: true chunk_size_grid: 16000
implicit_function_class_type: NeRFormerImplicitFunction view_pooler_enabled: true
raysampler_AdaptiveRaySampler_args: implicit_function_class_type: NeRFormerImplicitFunction
n_rays_per_image_sampled_from_mask: 800 raysampler_AdaptiveRaySampler_args:
n_pts_per_ray_training: 32 n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_evaluation: 32 n_pts_per_ray_training: 32
renderer_MultiPassEmissionAbsorptionRenderer_args: n_pts_per_ray_evaluation: 32
n_pts_per_ray_fine_training: 16 renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_evaluation: 16 n_pts_per_ray_fine_training: 16
view_pooler_args: n_pts_per_ray_fine_evaluation: 16
feature_aggregator_class_type: IdentityFeatureAggregator view_pooler_args:
feature_aggregator_class_type: IdentityFeatureAggregator

@ -1,28 +1,29 @@
defaults: defaults:
- repro_singleseq_base.yaml - repro_singleseq_base.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
num_passes: 1 model_GenericModel_args:
chunk_size_grid: 32000 num_passes: 1
view_pooler_enabled: false chunk_size_grid: 32000
loss_weights: view_pooler_enabled: false
loss_rgb_mse: 200.0 loss_weights:
loss_prev_stage_rgb_mse: 0.0 loss_rgb_mse: 200.0
loss_mask_bce: 1.0 loss_prev_stage_rgb_mse: 0.0
loss_prev_stage_mask_bce: 0.0 loss_mask_bce: 1.0
loss_autodecoder_norm: 0.0 loss_prev_stage_mask_bce: 0.0
depth_neg_penalty: 10000.0 loss_autodecoder_norm: 0.0
raysampler_class_type: NearFarRaySampler depth_neg_penalty: 10000.0
raysampler_NearFarRaySampler_args: raysampler_class_type: NearFarRaySampler
n_rays_per_image_sampled_from_mask: 2048 raysampler_NearFarRaySampler_args:
min_depth: 0.05 n_rays_per_image_sampled_from_mask: 2048
max_depth: 0.05 min_depth: 0.05
n_pts_per_ray_training: 1 max_depth: 0.05
n_pts_per_ray_evaluation: 1 n_pts_per_ray_training: 1
stratified_point_sampling_training: false n_pts_per_ray_evaluation: 1
stratified_point_sampling_evaluation: false stratified_point_sampling_training: false
renderer_class_type: LSTMRenderer stratified_point_sampling_evaluation: false
implicit_function_class_type: SRNImplicitFunction renderer_class_type: LSTMRenderer
solver_args: implicit_function_class_type: SRNImplicitFunction
breed: adam optimizer_factory_ImplicitronOptimizerFactory_args:
breed: Adam
lr: 5.0e-05 lr: 5.0e-05

@ -1,10 +1,11 @@
defaults: defaults:
- repro_singleseq_srn.yaml - repro_singleseq_srn.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
num_passes: 1 model_GenericModel_args:
implicit_function_SRNImplicitFunction_args: num_passes: 1
pixel_generator_args: implicit_function_SRNImplicitFunction_args:
n_harmonic_functions: 0 pixel_generator_args:
raymarch_function_args: n_harmonic_functions: 0
n_harmonic_functions: 0 raymarch_function_args:
n_harmonic_functions: 0

@ -2,28 +2,29 @@ defaults:
- repro_singleseq_wce_base - repro_singleseq_wce_base
- repro_feat_extractor_normed.yaml - repro_feat_extractor_normed.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
num_passes: 1 model_GenericModel_args:
chunk_size_grid: 32000 num_passes: 1
view_pooler_enabled: true chunk_size_grid: 32000
loss_weights: view_pooler_enabled: true
loss_rgb_mse: 200.0 loss_weights:
loss_prev_stage_rgb_mse: 0.0 loss_rgb_mse: 200.0
loss_mask_bce: 1.0 loss_prev_stage_rgb_mse: 0.0
loss_prev_stage_mask_bce: 0.0 loss_mask_bce: 1.0
loss_autodecoder_norm: 0.0 loss_prev_stage_mask_bce: 0.0
depth_neg_penalty: 10000.0 loss_autodecoder_norm: 0.0
raysampler_class_type: NearFarRaySampler depth_neg_penalty: 10000.0
raysampler_NearFarRaySampler_args: raysampler_class_type: NearFarRaySampler
n_rays_per_image_sampled_from_mask: 2048 raysampler_NearFarRaySampler_args:
min_depth: 0.05 n_rays_per_image_sampled_from_mask: 2048
max_depth: 0.05 min_depth: 0.05
n_pts_per_ray_training: 1 max_depth: 0.05
n_pts_per_ray_evaluation: 1 n_pts_per_ray_training: 1
stratified_point_sampling_training: false n_pts_per_ray_evaluation: 1
stratified_point_sampling_evaluation: false stratified_point_sampling_training: false
renderer_class_type: LSTMRenderer stratified_point_sampling_evaluation: false
implicit_function_class_type: SRNImplicitFunction renderer_class_type: LSTMRenderer
solver_args: implicit_function_class_type: SRNImplicitFunction
breed: adam optimizer_factory_ImplicitronOptimizerFactory_args:
breed: Adam
lr: 5.0e-05 lr: 5.0e-05

@ -1,10 +1,11 @@
defaults: defaults:
- repro_singleseq_srn_wce.yaml - repro_singleseq_srn_wce.yaml
- _self_ - _self_
generic_model_args: model_factory_ImplicitronModelFactory_args:
num_passes: 1 model_GenericModel_args:
implicit_function_SRNImplicitFunction_args: num_passes: 1
pixel_generator_args: implicit_function_SRNImplicitFunction_args:
n_harmonic_functions: 0 pixel_generator_args:
raymarch_function_args: n_harmonic_functions: 0
n_harmonic_functions: 0 raymarch_function_args:
n_harmonic_functions: 0

@ -1,7 +1,7 @@
defaults: defaults:
- repro_singleseq_base - repro_singleseq_base
- _self_ - _self_
data_source_args: data_source_ImplicitronDataSource_args:
data_loader_map_provider_SequenceDataLoaderMapProvider_args: data_loader_map_provider_SequenceDataLoaderMapProvider_args:
batch_size: 10 batch_size: 10
dataset_length_train: 1000 dataset_length_train: 1000

@ -8,27 +8,28 @@
"""" """"
This file is the entry point for launching experiments with Implicitron. This file is the entry point for launching experiments with Implicitron.
Main functions
---------------
- `run_training` is the wrapper for the train, val, test loops
and checkpointing
- `trainvalidate` is the inner loop which runs the model forward/backward
pass, visualizations and metric printing
Launch Training Launch Training
--------------- ---------------
Experiment config .yaml files are located in the Experiment config .yaml files are located in the
`projects/implicitron_trainer/configs` folder. To launch `projects/implicitron_trainer/configs` folder. To launch an experiment,
an experiment, specify the name of the file. Specific config values can specify the name of the file. Specific config values can also be overridden
also be overridden from the command line, for example: from the command line, for example:
``` ```
./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84 ./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84
``` ```
To run an experiment on a specific GPU, specify the `gpu_idx` key To run an experiment on a specific GPU, specify the `gpu_idx` key in the
in the config file / CLI. To run on a different device, specify the config file / CLI. To run on a different device, specify the device in
device in `run_training`. `run_training`.
Main functions
---------------
- The Experiment class defines `run` which creates the model, optimizer, and other
objects used in training, then starts TrainingLoop's `run` function.
- TrainingLoop takes care of the actual training logic: forward and backward passes,
evaluation and testing, as well as model checkpointing, visualization, and metric
printing.
Outputs Outputs
-------- --------
@ -45,43 +46,38 @@ The outputs of the experiment are saved and logged in multiple ways:
config file. config file.
""" """
-import copy
-import json
 import logging
 import os
-import random
-import time
 import warnings
-from typing import Any, Dict, Optional, Tuple
+from dataclasses import field
 import hydra
-import lpips
-import numpy as np
 import torch
-import tqdm
 from accelerate import Accelerator
 from omegaconf import DictConfig, OmegaConf
 from packaging import version
-from pytorch3d.implicitron.dataset import utils as ds_utils
-from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
-from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
-from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
-from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
-from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
+from pytorch3d.implicitron.dataset.data_source import (
+    DataSourceBase,
+    ImplicitronDataSource,
+)
+from pytorch3d.implicitron.models.generic_model import ImplicitronModelBase
 from pytorch3d.implicitron.models.renderer.multipass_ea import (
     MultiPassEmissionAbsorptionRenderer,
 )
 from pytorch3d.implicitron.models.renderer.ray_sampler import AdaptiveRaySampler
-from pytorch3d.implicitron.tools import model_io, vis_utils
 from pytorch3d.implicitron.tools.config import (
+    Configurable,
     expand_args_fields,
     remove_unused_components,
+    run_auto_creation,
 )
-from pytorch3d.implicitron.tools.stats import Stats
-from pytorch3d.renderer.cameras import CamerasBase
-from .impl.experiment_config import ExperimentConfig
-from .impl.optimization import init_optimizer
+from .impl.model_factory import ModelFactoryBase
+from .impl.optimizer_factory import OptimizerFactoryBase
+from .impl.training_loop import TrainingLoopBase
 logger = logging.getLogger(__name__)
@ -100,551 +96,146 @@ except ModuleNotFoundError:
no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None no_accelerate = os.environ.get("PYTORCH3D_NO_ACCELERATE") is not None
def init_model( class Experiment(Configurable): # pyre-ignore: 13
*,
cfg: DictConfig,
accelerator: Optional[Accelerator] = None,
force_load: bool = False,
clear_stats: bool = False,
load_model_only: bool = False,
) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
""" """
Returns an instance of `GenericModel`. This class is at the top level of Implicitron's config hierarchy. Its
members are high-level components necessary for training an implicit rendering
network.
If `cfg.resume` is set or `force_load` is true, Members:
attempts to load the last checkpoint from `cfg.exp_dir`. Failure to do so data_source: An object that produces datasets and dataloaders.
will return the model with initial weights, unless `force_load` is passed, model_factory: An object that produces an implicit rendering model as
in which case a FileNotFoundError is raised. well as its corresponding Stats object.
optimizer_factory: An object that produces the optimizer and lr
Args: scheduler.
force_load: If true, force load model from checkpoint even if training_loop: An object that runs training given the outputs produced
cfg.resume is false. by the data_source, model_factory and optimizer_factory.
clear_stats: If true, clear the stats object loaded from checkpoint detect_anomaly: Whether torch.autograd should detect anomalies. Useful
load_model_only: If true, load only the model weights from checkpoint for debugging, but might slow down the training.
and do not load the state of the optimizer and stats. exp_dir: Root experimentation directory. Checkpoints and training stats
will be saved here.
Returns:
model: The model with optionally loaded weights from checkpoint
stats: The stats structure (optionally loaded from checkpoint)
optimizer_state: The optimizer state dict containing
`state` and `param_groups` keys (optionally loaded from checkpoint)
Raise:
FileNotFoundError if `force_load` is passed but checkpoint is not found.
""" """
# Initialize the model data_source: DataSourceBase
if cfg.architecture == "generic": data_source_class_type: str = "ImplicitronDataSource"
model = GenericModel(**cfg.generic_model_args) model_factory: ModelFactoryBase
else: model_factory_class_type: str = "ImplicitronModelFactory"
raise ValueError(f"No such arch {cfg.architecture}.") optimizer_factory: OptimizerFactoryBase
optimizer_factory_class_type: str = "ImplicitronOptimizerFactory"
training_loop: TrainingLoopBase
training_loop_class_type: str = "ImplicitronTrainingLoop"
# Determine the network outputs that should be logged detect_anomaly: bool = False
if hasattr(model, "log_vars"): exp_dir: str = "./data/default_experiment/"
log_vars = copy.deepcopy(list(model.log_vars))
else:
log_vars = ["objective"]
visdom_env_charts = vis_utils.get_visdom_env(cfg) + "_charts" hydra: dict = field(
default_factory=lambda: {
# Init the stats struct "run": {"dir": "."}, # Make hydra not change the working dir.
stats = Stats( "output_subdir": None, # disable storing the .hydra logs
log_vars, }
visdom_env=visdom_env_charts,
verbose=False,
visdom_server=cfg.visdom_server,
visdom_port=cfg.visdom_port,
) )
# Retrieve the last checkpoint def __post_init__(self):
if cfg.resume_epoch > 0: run_auto_creation(self)
model_path = model_io.get_checkpoint(cfg.exp_dir, cfg.resume_epoch)
else:
model_path = model_io.find_last_checkpoint(cfg.exp_dir)
optimizer_state = None def run(self) -> None:
if model_path is not None: # Make sure the config settings are self-consistent.
logger.info("found previous model %s" % model_path) self._check_config_consistent()
if force_load or cfg.resume:
logger.info(" -> resuming")
map_location = None # Initialize the accelerator if desired.
if accelerator is not None and not accelerator.is_local_main_process: if no_accelerate:
map_location = { accelerator = None
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index device = torch.device("cuda:0")
}
if load_model_only:
model_state_dict = torch.load(
model_io.get_model_path(model_path), map_location=map_location
)
stats_load, optimizer_state = None, None
else:
model_state_dict, stats_load, optimizer_state = model_io.load_model(
model_path, map_location=map_location
)
# Determine if stats should be reset
if not clear_stats:
if stats_load is None:
logger.info("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
last_epoch = model_io.parse_epoch_from_model_path(model_path)
logger.info(f"Estimated resume epoch = {last_epoch}")
# Reset the stats struct
for _ in range(last_epoch + 1):
stats.new_epoch()
assert last_epoch == stats.epoch
else:
stats = stats_load
# Update stats properties incase it was reset on load
stats.visdom_env = visdom_env_charts
stats.visdom_server = cfg.visdom_server
stats.visdom_port = cfg.visdom_port
stats.plot_file = os.path.join(cfg.exp_dir, "train_stats.pdf")
stats.synchronize_logged_vars(log_vars)
else:
logger.info(" -> clearing stats")
try:
# TODO: fix on creation of the buffers
# after the hack above, this will not pass in most cases
# ... but this is fine for now
model.load_state_dict(model_state_dict, strict=True)
except RuntimeError as e:
logger.error(e)
logger.info("Cant load state dict in strict mode! -> trying non-strict")
model.load_state_dict(model_state_dict, strict=False)
model.log_vars = log_vars
else: else:
logger.info(" -> but not resuming -> starting from scratch") accelerator = Accelerator(device_placement=False)
elif force_load: logger.info(accelerator.state)
raise FileNotFoundError(f"Cannot find a checkpoint in {cfg.exp_dir}!") device = accelerator.device
return model, stats, optimizer_state logger.info(f"Running experiment on device: {device}")
os.makedirs(self.exp_dir, exist_ok=True)
# set the debug mode
if self.detect_anomaly:
logger.info("Anomaly detection!")
torch.autograd.set_detect_anomaly(self.detect_anomaly)
def trainvalidate( # Initialize the datasets and dataloaders.
model, datasets, dataloaders = self.data_source.get_datasets_and_dataloaders()
stats,
epoch,
loader,
optimizer,
validation: bool,
*,
accelerator: Optional[Accelerator],
device: torch.device,
bp_var: str = "objective",
metric_print_interval: int = 5,
visualize_interval: int = 100,
visdom_env_root: str = "trainvalidate",
clip_grad: float = 0.0,
**kwargs,
) -> None:
"""
This is the main loop for training and evaluation including:
model forward pass, loss computation, backward pass and visualization.
Args: # Init the model and the corresponding Stats object.
model: The model module optionally loaded from checkpoint model = self.model_factory(
stats: The stats struct, also optionally loaded from checkpoint accelerator=accelerator,
epoch: The index of the current epoch exp_dir=self.exp_dir,
loader: The dataloader to use for the loop
optimizer: The optimizer module optionally loaded from checkpoint
validation: If true, run the loop with the model in eval mode
and skip the backward pass
bp_var: The name of the key in the model output `preds` dict which
should be used as the loss for the backward pass.
metric_print_interval: The batch interval at which the stats should be
logged.
visualize_interval: The batch interval at which the visualizations
should be plotted
visdom_env_root: The name of the visdom environment to use for plotting
clip_grad: Optionally clip the gradient norms.
If set to a value <=0.0, no clipping
device: The device on which to run the model.
Returns:
None
"""
if validation:
model.eval()
trainmode = "val"
else:
model.train()
trainmode = "train"
t_start = time.time()
# get the visdom env name
visdom_env_imgs = visdom_env_root + "_images_" + trainmode
viz = vis_utils.get_visdom_connection(
server=stats.visdom_server,
port=stats.visdom_port,
)
# Iterate through the batches
n_batches = len(loader)
for it, net_input in enumerate(loader):
last_iter = it == n_batches - 1
# move to gpu where possible (in place)
net_input = net_input.to(device)
# run the forward pass
if not validation:
optimizer.zero_grad()
preds = model(**{**net_input, "evaluation_mode": EvaluationMode.TRAINING})
else:
with torch.no_grad():
preds = model(
**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
)
# make sure we dont overwrite something
assert all(k not in preds for k in net_input.keys())
# merge everything into one big dict
preds.update(net_input)
# update the stats logger
stats.update(preds, time_start=t_start, stat_set=trainmode)
assert stats.it[trainmode] == it, "inconsistent stat iteration number!"
# print textual status update
if it % metric_print_interval == 0 or last_iter:
stats.print(stat_set=trainmode, max_it=n_batches)
# visualize results
if (
(accelerator is None or accelerator.is_local_main_process)
and visualize_interval > 0
and it % visualize_interval == 0
):
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
model.visualize(
viz,
visdom_env_imgs,
preds,
prefix,
)
# optimizer step
if not validation:
loss = preds[bp_var]
assert torch.isfinite(loss).all(), "Non-finite loss!"
# backprop
if accelerator is None:
loss.backward()
else:
accelerator.backward(loss)
if clip_grad > 0.0:
# Optionally clip the gradient norms.
total_norm = torch.nn.utils.clip_grad_norm(
model.parameters(), clip_grad
)
if total_norm > clip_grad:
logger.info(
f"Clipping gradient: {total_norm}"
+ f" with coef {clip_grad / float(total_norm)}."
)
optimizer.step()
def run_training(cfg: DictConfig) -> None:
"""
Entry point to run the training and validation loops
based on the specified config file.
"""
# Initialize the accelerator
if no_accelerate:
accelerator = None
device = torch.device("cuda:0")
else:
accelerator = Accelerator(device_placement=False)
logger.info(accelerator.state)
device = accelerator.device
logger.info(f"Running experiment on device: {device}")
# set the debug mode
if cfg.detect_anomaly:
logger.info("Anomaly detection!")
torch.autograd.set_detect_anomaly(cfg.detect_anomaly)
# create the output folder
os.makedirs(cfg.exp_dir, exist_ok=True)
_seed_all_random_engines(cfg.seed)
remove_unused_components(cfg)
# dump the exp config to the exp dir
try:
cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml")
OmegaConf.save(config=cfg, f=cfg_filename)
except PermissionError:
warnings.warn("Cant dump config due to insufficient permissions!")
# setup datasets
datasource = ImplicitronDataSource(**cfg.data_source_args)
datasets, dataloaders = datasource.get_datasets_and_dataloaders()
task = datasource.get_task()
# init the model
model, stats, optimizer_state = init_model(cfg=cfg, accelerator=accelerator)
start_epoch = stats.epoch + 1
# move model to gpu
model.to(device)
# only run evaluation on the test dataloader
if cfg.eval_only:
_eval_and_dump(
cfg,
task,
datasource.all_train_cameras,
datasets,
dataloaders,
model,
stats,
device=device,
) )
return
# init the optimizer stats = self.model_factory.load_stats(
optimizer, scheduler = init_optimizer( exp_dir=self.exp_dir,
model, log_vars=model.log_vars,
optimizer_state=optimizer_state, )
last_epoch=start_epoch, start_epoch = stats.epoch + 1
**cfg.solver_args,
)
# check the scheduler and stats have been initialized correctly model.to(device)
assert scheduler.last_epoch == stats.epoch + 1
assert scheduler.last_epoch == start_epoch
# Wrap all modules in the distributed library # Init the optimizer and LR scheduler.
# Note: we don't pass the scheduler to prepare as it optimizer, scheduler = self.optimizer_factory(
# doesn't need to be stepped at each optimizer step accelerator=accelerator,
train_loader = dataloaders.train exp_dir=self.exp_dir,
val_loader = dataloaders.val last_epoch=start_epoch,
if accelerator is not None: model=model,
( )
model,
optimizer,
train_loader,
val_loader,
) = accelerator.prepare(model, optimizer, train_loader, val_loader)
past_scheduler_lrs = [] # Wrap all modules in the distributed library
# loop through epochs # Note: we don't pass the scheduler to prepare as it
for epoch in range(start_epoch, cfg.solver_args.max_epochs): # doesn't need to be stepped at each optimizer step
# automatic new_epoch and plotting of stats at every epoch start train_loader = dataloaders.train
with stats: val_loader = dataloaders.val
test_loader = dataloaders.test
# Make sure to re-seed random generators to ensure reproducibility if accelerator is not None:
# even after restart. (
_seed_all_random_engines(cfg.seed + epoch)
cur_lr = float(scheduler.get_last_lr()[-1])
logger.info(f"scheduler lr = {cur_lr:1.2e}")
past_scheduler_lrs.append(cur_lr)
# train loop
trainvalidate(
model, model,
stats,
epoch,
train_loader,
optimizer, optimizer,
False, train_loader,
visdom_env_root=vis_utils.get_visdom_env(cfg), val_loader,
device=device, ) = accelerator.prepare(model, optimizer, train_loader, val_loader)
accelerator=accelerator,
**cfg,
)
# val loop (optional) task = self.data_source.get_task()
if val_loader is not None and epoch % cfg.validation_interval == 0: all_train_cameras = self.data_source.all_train_cameras
trainvalidate(
model,
stats,
epoch,
val_loader,
optimizer,
True,
visdom_env_root=vis_utils.get_visdom_env(cfg),
device=device,
accelerator=accelerator,
**cfg,
)
# eval loop (optional) # Enter the main training loop.
if ( self.training_loop.run(
dataloaders.test is not None train_loader=train_loader,
and cfg.test_interval > 0 val_loader=val_loader,
and epoch % cfg.test_interval == 0 test_loader=test_loader,
): model=model,
_run_eval( optimizer=optimizer,
model, scheduler=scheduler,
datasource.all_train_cameras, all_train_cameras=all_train_cameras,
dataloaders.test, accelerator=accelerator,
task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
device=device,
)
assert stats.epoch == epoch, "inconsistent stats!"
# delete previous models if required
# save model only on the main process
if cfg.store_checkpoints and (
accelerator is None or accelerator.is_local_main_process
):
if cfg.store_checkpoints_purge > 0:
for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
model_io.purge_epoch(cfg.exp_dir, prev_epoch)
outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
unwrapped_model = (
model if accelerator is None else accelerator.unwrap_model(model)
)
model_io.safe_save_model(
unwrapped_model, stats, outfile, optimizer=optimizer
)
scheduler.step()
new_lr = float(scheduler.get_last_lr()[-1])
if new_lr != cur_lr:
logger.info(f"LR change! {cur_lr} -> {new_lr}")
if cfg.test_when_finished:
_eval_and_dump(
cfg,
task,
datasource.all_train_cameras,
datasets,
dataloaders,
model,
stats,
device=device, device=device,
exp_dir=self.exp_dir,
stats=stats,
task=task,
) )
def _check_config_consistent(self) -> None:
def _eval_and_dump( if hasattr(self.optimizer_factory, "resume") and hasattr(
cfg, self.model_factory, "resume"
task: Task, ):
all_train_cameras: Optional[CamerasBase], assert (
datasets: DatasetMap, # pyre-ignore [16]
dataloaders: DataLoaderMap, not self.optimizer_factory.resume
model, # pyre-ignore [16]
stats, or self.model_factory.resume
device, ), "Cannot resume the optimizer without resuming the model."
) -> None: if hasattr(self.optimizer_factory, "resume_epoch") and hasattr(
""" self.model_factory, "resume_epoch"
Run the evaluation loop with the test data loader and ):
save the predictions to the `exp_dir`. assert (
""" # pyre-ignore [16]
self.optimizer_factory.resume_epoch
dataloader = dataloaders.test # pyre-ignore [16]
== self.model_factory.resume_epoch
if dataloader is None: ), "Optimizer and model must resume from the same epoch."
raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')
results = _run_eval(
model,
all_train_cameras,
dataloader,
task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
device=device,
)
# add the evaluation epoch to the results
for r in results:
r["eval_epoch"] = int(stats.epoch)
logger.info("Evaluation results")
evaluate.pretty_print_nvs_metrics(results)
with open(os.path.join(cfg.exp_dir, "results_test.json"), "w") as f:
json.dump(results, f)
def _get_eval_frame_data(frame_data):
"""
Masks the unknown image data to make sure we cannot use it at model evaluation time.
"""
frame_data_for_eval = copy.deepcopy(frame_data)
is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
frame_data.image_rgb
)[:, None, None, None]
for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
value_masked = getattr(frame_data_for_eval, k).clone() * is_known
setattr(frame_data_for_eval, k, value_masked)
return frame_data_for_eval
def _run_eval(
model,
all_train_cameras,
loader,
task: Task,
camera_difficulty_bin_breaks: Tuple[float, float],
device,
):
"""
Run the evaluation loop on the test dataloader
"""
lpips_model = lpips.LPIPS(net="vgg")
lpips_model = lpips_model.to(device)
model.eval()
per_batch_eval_results = []
logger.info("Evaluating model ...")
for frame_data in tqdm.tqdm(loader):
frame_data = frame_data.to(device)
# mask out the unknown images so that the model does not see them
frame_data_for_eval = _get_eval_frame_data(frame_data)
with torch.no_grad():
preds = model(
**{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
)
# TODO: Cannot use accelerate gather for two reasons:.
# (1) TypeError: Can't apply _gpu_gather_one on object of type
# <class 'pytorch3d.implicitron.models.base_model.ImplicitronRender'>,
# only of nested list/tuple/dicts of objects that satisfy is_torch_tensor.
# (2) Same error above but for frame_data which contains Cameras.
implicitron_render = copy.deepcopy(preds["implicitron_render"])
per_batch_eval_results.append(
evaluate.eval_batch(
frame_data,
implicitron_render,
bg_color="black",
lpips_model=lpips_model,
source_cameras=all_train_cameras,
)
)
_, category_result = evaluate.summarize_nvs_eval_results(
per_batch_eval_results, task, camera_difficulty_bin_breaks
)
return category_result["results"]
def _seed_all_random_engines(seed: int) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
def _setup_envvars_for_cluster() -> bool: def _setup_envvars_for_cluster() -> bool:
@ -678,9 +269,20 @@ def _setup_envvars_for_cluster() -> bool:
     return True
-expand_args_fields(ExperimentConfig)
+def dump_cfg(cfg: DictConfig) -> None:
+    remove_unused_components(cfg)
+    # dump the exp config to the exp dir
+    os.makedirs(cfg.exp_dir, exist_ok=True)
+    try:
+        cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml")
+        OmegaConf.save(config=cfg, f=cfg_filename)
+    except PermissionError:
+        warnings.warn("Can't dump config due to insufficient permissions!")
+expand_args_fields(Experiment)
 cs = hydra.core.config_store.ConfigStore.instance()
-cs.store(name="default_config", node=ExperimentConfig)
+cs.store(name="default_config", node=Experiment)
 @hydra.main(config_path="./configs/", config_name="default_config")
@ -694,12 +296,14 @@ def experiment(cfg: DictConfig) -> None:
         logger.info("Running locally")
     # TODO: The following may be needed for hydra/submitit it to work
-    expand_args_fields(GenericModel)
+    expand_args_fields(ImplicitronModelBase)
     expand_args_fields(AdaptiveRaySampler)
     expand_args_fields(MultiPassEmissionAbsorptionRenderer)
     expand_args_fields(ImplicitronDataSource)
-    run_training(cfg)
+    experiment = Experiment(**cfg)
+    dump_cfg(cfg)
+    experiment.run()
 if __name__ == "__main__":


@ -1,49 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import field
from typing import Any, Dict, Tuple
from omegaconf import DictConfig
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import Configurable, get_default_args_field
from .optimization import init_optimizer
class ExperimentConfig(Configurable):
generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer)
data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
architecture: str = "generic"
detect_anomaly: bool = False
eval_only: bool = False
exp_dir: str = "./data/default_experiment/"
exp_idx: int = 0
gpu_idx: int = 0
metric_print_interval: int = 5
resume: bool = True
resume_epoch: int = -1
seed: int = 0
store_checkpoints: bool = True
store_checkpoints_purge: int = 1
test_interval: int = -1
test_when_finished: bool = False
validation_interval: int = 1
visdom_env: str = ""
visdom_port: int = 8097
visdom_server: str = "http://127.0.0.1"
visualize_interval: int = 1000
clip_grad: float = 0.0
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
hydra: Dict[str, Any] = field(
default_factory=lambda: {
"run": {"dir": "."}, # Make hydra not change the working dir.
"output_subdir": None, # disable storing the .hydra logs
}
)
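The flat `ExperimentConfig` removed here is superseded by `Experiment` members whose implementation is selected through a `<member>_class_type` string (for example `optimizer_factory_class_type`). The pattern can be sketched in plain Python as follows. This is a simplified stand-in, not the actual `pytorch3d.implicitron.tools.config` machinery: `TOY_REGISTRY`, `register`, `ToyOptimizerFactoryBase`, `HalvingFactory`, `IdentityFactory`, and `ToyExperiment` are all hypothetical names introduced for illustration.

```python
from dataclasses import dataclass
from typing import Dict, Type

# Hypothetical stand-in for Implicitron's registry: maps class names to classes.
TOY_REGISTRY: Dict[str, Type] = {}


def register(cls: Type) -> Type:
    """Record a class under its name so that a config string can select it."""
    TOY_REGISTRY[cls.__name__] = cls
    return cls


class ToyOptimizerFactoryBase:
    """Plays the role of a replaceable base class such as OptimizerFactoryBase."""

    def __call__(self, base_lr: float) -> float:
        raise NotImplementedError()


@register
class IdentityFactory(ToyOptimizerFactoryBase):
    def __call__(self, base_lr: float) -> float:
        return base_lr


@register
class HalvingFactory(ToyOptimizerFactoryBase):
    def __call__(self, base_lr: float) -> float:
        return base_lr / 2


@dataclass
class ToyExperiment:
    # Mirrors the `<member>_class_type: str` fields of Experiment.
    optimizer_factory_class_type: str = "IdentityFactory"

    def __post_init__(self) -> None:
        # Assumed analogue of run_auto_creation: look up the configured
        # class by name and instantiate the corresponding member.
        cls = TOY_REGISTRY[self.optimizer_factory_class_type]
        self.optimizer_factory = cls()

    def run(self, base_lr: float) -> float:
        return self.optimizer_factory(base_lr)
```

Swapping the implementation then only requires changing the string, e.g. `ToyExperiment(optimizer_factory_class_type="HalvingFactory")`, which is the property the commit relies on when it says components can be replaced with custom implementations.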


@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from typing import List, Optional
import torch.optim
from accelerate import Accelerator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats
logger = logging.getLogger(__name__)
class ModelFactoryBase(ReplaceableBase):
def __call__(self, **kwargs) -> ImplicitronModelBase:
"""
Initialize the model (possibly from a previously saved state).
Returns: An instance of ImplicitronModelBase.
"""
raise NotImplementedError()
def load_stats(self, **kwargs) -> Stats:
"""
Initialize or load a Stats object.
"""
raise NotImplementedError()
@registry.register
class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
"""
A factory class that initializes an implicit rendering model.
Members:
force_load: If True, load the checkpoint even if `resume` is False, and
raise a FileNotFoundError if no model checkpoint can be found.
model: An ImplicitronModelBase object.
resume: If True, attempt to load the last checkpoint from `exp_dir`
passed to __call__. Failure to do so will return a model with initial
weights unless `force_load` is True.
resume_epoch: If `resume` is True: Resume a model at this epoch, or if
`resume_epoch` <= 0, then resume from the latest checkpoint.
visdom_env: The name of the Visdom environment to use for plotting.
visdom_port: The Visdom port.
visdom_server: Address of the Visdom server.
"""
force_load: bool = False
model: ImplicitronModelBase
model_class_type: str = "GenericModel"
resume: bool = False
resume_epoch: int = -1
visdom_env: str = ""
visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
visdom_server: str = "http://127.0.0.1"
def __post_init__(self):
run_auto_creation(self)
def __call__(
self,
exp_dir: str,
accelerator: Optional[Accelerator] = None,
) -> ImplicitronModelBase:
"""
Returns an instance of `ImplicitronModelBase`, possibly loaded from a
checkpoint (if self.resume, self.resume_epoch specify so).
Args:
exp_dir: Root experiment directory.
accelerator: An Accelerator object.
Returns:
model: The model with optionally loaded weights from checkpoint
Raises:
FileNotFoundError if `force_load` is True but no checkpoint is found.
"""
# Determine the network outputs that should be logged
if hasattr(self.model, "log_vars"):
log_vars = list(self.model.log_vars) # pyre-ignore [6]
else:
log_vars = ["objective"]
# Retrieve the last checkpoint
if self.resume_epoch > 0:
model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
else:
model_path = model_io.find_last_checkpoint(exp_dir)
if model_path is not None:
logger.info("found previous model %s" % model_path)
if self.force_load or self.resume:
logger.info(" -> resuming")
map_location = None
if accelerator is not None and not accelerator.is_local_main_process:
map_location = {
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
}
model_state_dict = torch.load(
model_io.get_model_path(model_path), map_location=map_location
)
try:
self.model.load_state_dict(model_state_dict, strict=True)
except RuntimeError as e:
logger.error(e)
logger.info(
"Can't load state dict in strict mode! -> trying non-strict"
)
self.model.load_state_dict(model_state_dict, strict=False)
self.model.log_vars = log_vars # pyre-ignore [16]
else:
logger.info(" -> but not resuming -> starting from scratch")
elif self.force_load:
raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")
return self.model
def load_stats(
self,
log_vars: List[str],
exp_dir: str,
clear_stats: bool = False,
**kwargs,
) -> Stats:
"""
Load Stats that correspond to the model's log_vars.
Args:
log_vars: A list of variable names to log. Should be a subset of the
`preds` returned by the forward function of the corresponding
ImplicitronModelBase instance.
exp_dir: Root experiment directory.
clear_stats: If True, do not load stats from the checkpoint specified
by self.resume and self.resume_epoch; instead, create a fresh stats
object.
Returns:
stats: The stats structure (optionally loaded from checkpoint)
"""
# Init the stats struct
visdom_env_charts = (
vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
)
stats = Stats(
# log_vars should be a list, but OmegaConf might load them as ListConfig
list(log_vars),
visdom_env=visdom_env_charts,
verbose=False,
visdom_server=self.visdom_server,
visdom_port=self.visdom_port,
)
if self.resume_epoch > 0:
model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
else:
model_path = model_io.find_last_checkpoint(exp_dir)
if model_path is not None:
stats_path = model_io.get_stats_path(model_path)
stats_load = model_io.load_stats(stats_path)
# Determine if stats should be reset
if not clear_stats:
if stats_load is None:
logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
last_epoch = model_io.parse_epoch_from_model_path(model_path)
logger.info(f"Estimated resume epoch = {last_epoch}")
# Reset the stats struct
for _ in range(last_epoch + 1):
stats.new_epoch()
assert last_epoch == stats.epoch
else:
stats = stats_load
# Update stats properties in case it was reset on load
stats.visdom_env = visdom_env_charts
stats.visdom_server = self.visdom_server
stats.visdom_port = self.visdom_port
stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
stats.synchronize_logged_vars(log_vars)
else:
logger.info(" -> clearing stats")
return stats
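The branching in `ImplicitronModelFactory.__call__` is easy to misread, because `force_load` both triggers loading and turns a missing checkpoint into an error. A minimal sketch of just that decision, with the actual I/O removed; `checkpoint_action` and the `"load"` / `"scratch"` labels are hypothetical and only mirror the `if`/`elif` structure shown in the method:

```python
from typing import Optional


def checkpoint_action(
    model_path: Optional[str], resume: bool, force_load: bool
) -> str:
    """Return "load" or "scratch", or raise, following the branches in __call__."""
    if model_path is not None:
        if force_load or resume:
            # A checkpoint exists and loading was requested either way.
            return "load"
        # A checkpoint exists but resuming was not requested.
        return "scratch"
    if force_load:
        # No checkpoint was found although loading was mandatory.
        raise FileNotFoundError("Cannot find a checkpoint!")
    # No checkpoint and no obligation to load one.
    return "scratch"
```

Note that with `resume=True` and no checkpoint on disk the sketch silently starts from scratch; only `force_load=True` makes the missing checkpoint an error.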


@ -1,109 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, Optional, Tuple
import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import enable_get_default_args
logger = logging.getLogger(__name__)
def init_optimizer(
model: GenericModel,
optimizer_state: Optional[Dict[str, Any]],
last_epoch: int,
breed: str = "adam",
weight_decay: float = 0.0,
lr_policy: str = "multistep",
lr: float = 0.0005,
gamma: float = 0.1,
momentum: float = 0.9,
betas: Tuple[float, ...] = (0.9, 0.999),
milestones: Tuple[int, ...] = (),
max_epochs: int = 1000,
):
"""
Initialize the optimizer (optionally from checkpoint state)
and the learning rate scheduler.
Args:
model: The model with optionally loaded weights
optimizer_state: The state dict for the optimizer. If None
it has not been loaded from checkpoint
last_epoch: If the model was loaded from checkpoint this will be the
number of the last epoch that was saved
breed: The type of optimizer to use e.g. adam
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
lr_policy: The policy to use for learning rate. Currently, only "multistep"
is supported.
lr: The value for the initial learning rate
gamma: Multiplicative factor of learning rate decay
momentum: Momentum factor for SGD optimizer
betas: Coefficients used for computing running averages of gradient and its square
in the Adam optimizer
milestones: List of increasing epoch indices at which the learning rate is
modified
max_epochs: The maximum number of epochs to run the optimizer for
Returns:
optimizer: Optimizer module, optionally loaded from checkpoint
scheduler: Learning rate scheduler module
Raise:
ValueError if `breed` or `lr_policy` are not supported.
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
# pyre-ignore[29]
p_groups = model._get_param_groups(lr, wd=weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
p_groups = [{"params": allprm, "lr": lr}]
# Initialize the optimizer
if breed == "sgd":
optimizer = torch.optim.SGD(
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
)
elif breed == "adagrad":
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
elif breed == "adam":
optimizer = torch.optim.Adam(
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
)
else:
raise ValueError("no such solver type %s" % breed)
logger.info(" -> solver type = %s" % breed)
# Load state from checkpoint
if optimizer_state is not None:
logger.info(" -> setting loaded optimizer state")
optimizer.load_state_dict(optimizer_state)
# Initialize the learning rate scheduler
if lr_policy == "multistep":
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=milestones,
gamma=gamma,
)
else:
raise ValueError("no such lr policy %s" % lr_policy)
# When loading from checkpoint, this will make sure that the
# lr is correctly set even after returning
for _ in range(last_epoch):
scheduler.step()
optimizer.zero_grad()
return optimizer, scheduler
enable_get_default_args(init_optimizer)
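The replay loop at the end of `init_optimizer` calls `scheduler.step()` `last_epoch` times so that the learning rate matches the resumed epoch. The value it lands on can also be written in closed form. The sketch below uses plain Python and assumes PyTorch's documented `MultiStepLR` behavior (the rate is multiplied by `gamma` once for every milestone reached). `multistep_lr` and `exponential_lr` are hypothetical helper names, and the exponential rule is copied from the `exponential_lr_step_size` docstring of the new `ImplicitronOptimizerFactory`, not checked against its implementation.

```python
from bisect import bisect_right
from typing import Sequence


def multistep_lr(
    base_lr: float, gamma: float, milestones: Sequence[int], epoch: int
) -> float:
    """Learning rate after `epoch` scheduler steps under a multistep policy."""
    # Every milestone that has been reached multiplies the rate by gamma once.
    n_decays = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma**n_decays


def exponential_lr(
    base_lr: float, gamma: float, step_size: int, epoch: int
) -> float:
    """Learning rate under the rule lr = lr * gamma ** (epoch / step_size)."""
    return base_lr * gamma ** (epoch / step_size)
```

With the defaults of `init_optimizer` (`lr=0.0005`, `gamma=0.1`) and illustrative milestones `(10, 20)`, resuming at epoch 10 gives one decay and resuming at epoch 25 gives two. The exponential rule decays continuously and reaches one full factor of `gamma` after `step_size` epochs.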


@ -0,0 +1,197 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from typing import Any, Dict, Optional, Tuple
import torch.optim
from accelerate import Accelerator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
logger = logging.getLogger(__name__)
class OptimizerFactoryBase(ReplaceableBase):
def __call__(
self, model: ImplicitronModelBase, **kwargs
) -> Tuple[torch.optim.Optimizer, Any]:
"""
Initialize the optimizer and lr scheduler.
Args:
model: The model with optionally loaded weights.
Returns:
An optimizer module (optionally loaded from a checkpoint) and
a learning rate scheduler module (should be a subclass of torch.optim's
lr_scheduler._LRScheduler).
"""
raise NotImplementedError()
@registry.register
class ImplicitronOptimizerFactory(OptimizerFactoryBase):
"""
A factory that initializes the optimizer and lr scheduler.
Members:
betas: Beta parameters for the Adam optimizer.
breed: The type of optimizer to use. We currently support SGD, Adagrad
and Adam.
exponential_lr_step_size: With Exponential policy only,
lr = lr * gamma ** (epoch/step_size)
gamma: Multiplicative factor of learning rate decay.
lr: The value for the initial learning rate.
lr_policy: The policy to use for learning rate. We currently support
MultiStepLR and Exponential policies.
momentum: Momentum factor for the SGD optimizer (used with SGD only).
multistep_lr_milestones: With MultiStepLR policy only: list of
increasing epoch indices at which the learning rate is modified.
resume: If True, attempt to load the last checkpoint from `exp_dir`
passed to __call__. Failure to do so will return a newly initialized
optimizer.
resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
`resume_epoch` <= 0, then resume from the latest checkpoint.
weight_decay: The optimizer weight_decay (L2 penalty on model weights).
"""
betas: Tuple[float, ...] = (0.9, 0.999)
breed: str = "Adam"
exponential_lr_step_size: int = 250
gamma: float = 0.1
lr: float = 0.0005
lr_policy: str = "MultiStepLR"
momentum: float = 0.9
multistep_lr_milestones: tuple = ()
resume: bool = False
resume_epoch: int = -1
weight_decay: float = 0.0
def __post_init__(self):
run_auto_creation(self)
def __call__(
self,
last_epoch: int,
model: ImplicitronModelBase,
accelerator: Optional[Accelerator] = None,
exp_dir: Optional[str] = None,
**kwargs,
) -> Tuple[torch.optim.Optimizer, Any]:
"""
Initialize the optimizer (optionally from a checkpoint) and the lr scheduler.
Args:
last_epoch: If the model was loaded from checkpoint this will be the
number of the last epoch that was saved.
model: The model with optionally loaded weights.
accelerator: An optional Accelerator instance.
exp_dir: Root experiment directory.
Returns:
An optimizer module (optionally loaded from a checkpoint) and
a learning rate scheduler module (should be a subclass of torch.optim's
lr_scheduler._LRScheduler).
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
# pyre-ignore[29]
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
p_groups = [{"params": allprm, "lr": self.lr}]
# Initialize the optimizer
if self.breed == "SGD":
optimizer = torch.optim.SGD(
p_groups,
lr=self.lr,
momentum=self.momentum,
weight_decay=self.weight_decay,
)
elif self.breed == "Adagrad":
optimizer = torch.optim.Adagrad(
p_groups, lr=self.lr, weight_decay=self.weight_decay
)
elif self.breed == "Adam":
optimizer = torch.optim.Adam(
p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
)
else:
raise ValueError("no such solver type %s" % self.breed)
logger.info(" -> solver type = %s" % self.breed)
# Load state from checkpoint
optimizer_state = self._get_optimizer_state(exp_dir, accelerator)
if optimizer_state is not None:
logger.info(" -> setting loaded optimizer state")
optimizer.load_state_dict(optimizer_state)
# Initialize the learning rate scheduler
if self.lr_policy.casefold() == "MultiStepLR".casefold():
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=self.multistep_lr_milestones,
gamma=self.gamma,
)
elif self.lr_policy.casefold() == "Exponential".casefold():
scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lambda epoch: self.gamma ** (epoch / self.exponential_lr_step_size),
verbose=False,
)
else:
raise ValueError("no such lr policy %s" % self.lr_policy)
# When loading from checkpoint, this will make sure that the
# lr is correctly set even after returning.
for _ in range(last_epoch):
scheduler.step()
optimizer.zero_grad()
return optimizer, scheduler
def _get_optimizer_state(
self,
exp_dir: Optional[str],
accelerator: Optional[Accelerator] = None,
) -> Optional[Dict[str, Any]]:
"""
Load an optimizer state from a checkpoint.
"""
if exp_dir is None or not self.resume:
return None
if self.resume_epoch > 0:
save_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
else:
save_path = model_io.find_last_checkpoint(exp_dir)
optimizer_state = None
if save_path is not None:
logger.info(f"Found previous optimizer state {save_path}.")
logger.info(" -> resuming")
opt_path = model_io.get_optimizer_path(save_path)
if os.path.isfile(opt_path):
map_location = None
if accelerator is not None and not accelerator.is_local_main_process:
map_location = {
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
}
optimizer_state = torch.load(opt_path, map_location)
else:
optimizer_state = None
return optimizer_state
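
The two learning rate policies above can be sketched without torch. This is a minimal illustration, not the factory's code: `multistep_lr` and `exponential_lr` are hypothetical helper names, and the multi-step version assumes the usual closed form (one decay by `gamma` per milestone already reached). The numbers are the defaults from this diff (`lr=0.0005`, `gamma=0.1`, `exponential_lr_step_size=250`) and the milestones `[0, 1]` set in the test.

```python
import bisect
import math


def multistep_lr(base_lr, gamma, milestones, epoch):
    """Assumed closed form of a multi-step schedule: decay once per milestone reached."""
    return base_lr * gamma ** bisect.bisect_right(sorted(milestones), epoch)


def exponential_lr(base_lr, gamma, step_size, epoch):
    """The "Exponential" policy as documented: lr * gamma ** (epoch / step_size)."""
    return base_lr * gamma ** (epoch / step_size)


# Multi-step with milestones [0, 1]: the rate drops at epoch 0 and again at epoch 1.
for epoch in (0, 1, 2):
    print("multistep", epoch, multistep_lr(0.0005, 0.1, [0, 1], epoch))

# Exponential: a smooth decay that reaches one factor of gamma after step_size epochs.
for epoch in (0, 125, 250):
    print("exponential", epoch, exponential_lr(0.0005, 0.1, 250, epoch))
```

Under these assumptions, two decays of 0.0005 by a factor of 0.1 give roughly 5e-06, which is consistent with the value the test expects at the end of the first "LR change!" log line.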

@ -0,0 +1,365 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import random
import time
from typing import Any, Optional
import numpy as np
import torch
from accelerate import Accelerator
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import EvaluationMode
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase
from torch.utils.data import DataLoader
logger = logging.getLogger(__name__)
class TrainingLoopBase(ReplaceableBase):
def run(
self,
train_loader: DataLoader,
val_loader: Optional[DataLoader],
test_loader: Optional[DataLoader],
model: ImplicitronModelBase,
optimizer: torch.optim.Optimizer,
scheduler: Any,
**kwargs,
) -> None:
raise NotImplementedError()
@registry.register
class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
"""
Members:
eval_only: If True, only run evaluation using the test dataloader.
evaluator: An EvaluatorBase instance, used to evaluate training results.
max_epochs: Train for this many epochs. Note that if the model was
loaded from a checkpoint, we will restart training at the appropriate
epoch and run for (max_epochs - checkpoint_epoch) epochs.
seed: A random seed to ensure reproducibility.
store_checkpoints: If True, store model and optimizer state checkpoints.
store_checkpoints_purge: If > 0, remove any checkpoints that are more
than this many epochs older than the current epoch.
test_interval: Evaluate on a test dataloader each `test_interval` epochs.
test_when_finished: If True, evaluate on a test dataloader when training
completes.
validation_interval: Validate each `validation_interval` epochs.
clip_grad: Optionally clip the gradient norms.
If set to a value <= 0.0, no clipping is applied.
metric_print_interval: The batch interval at which the stats should be
logged.
visualize_interval: The batch interval at which the visualizations
should be plotted
"""
# Parameters of the outer training loop.
eval_only: bool = False
evaluator: EvaluatorBase
evaluator_class_type: str = "ImplicitronEvaluator"
max_epochs: int = 1000
seed: int = 0
store_checkpoints: bool = True
store_checkpoints_purge: int = 1
test_interval: int = -1
test_when_finished: bool = False
validation_interval: int = 1
# Parameters of a single training-validation step.
clip_grad: float = 0.0
metric_print_interval: int = 5
visualize_interval: int = 1000
def __post_init__(self):
run_auto_creation(self)
def run(
self,
*,
train_loader: DataLoader,
val_loader: Optional[DataLoader],
test_loader: Optional[DataLoader],
model: ImplicitronModelBase,
optimizer: torch.optim.Optimizer,
scheduler: Any,
accelerator: Optional[Accelerator],
all_train_cameras: Optional[CamerasBase],
device: torch.device,
exp_dir: str,
stats: Stats,
task: Task,
**kwargs,
):
"""
Entry point to run the training and validation loops
based on the specified config file.
"""
_seed_all_random_engines(self.seed)
start_epoch = stats.epoch + 1
assert scheduler.last_epoch == stats.epoch + 1
assert scheduler.last_epoch == start_epoch
# only run evaluation on the test dataloader
if self.eval_only:
if test_loader is not None:
self.evaluator.run(
all_train_cameras=all_train_cameras,
dataloader=test_loader,
device=device,
dump_to_json=True,
epoch=stats.epoch,
exp_dir=exp_dir,
model=model,
task=task,
)
return
else:
raise ValueError(
"Cannot evaluate and dump results to json, no test data provided."
)
# loop through epochs
for epoch in range(start_epoch, self.max_epochs):
# automatic new_epoch and plotting of stats at every epoch start
with stats:
# Make sure to re-seed random generators to ensure reproducibility
# even after restart.
_seed_all_random_engines(self.seed + epoch)
cur_lr = float(scheduler.get_last_lr()[-1])
logger.debug(f"scheduler lr = {cur_lr:1.2e}")
# train loop
self._training_or_validation_epoch(
accelerator=accelerator,
device=device,
epoch=epoch,
loader=train_loader,
model=model,
optimizer=optimizer,
stats=stats,
validation=False,
)
# val loop (optional)
if val_loader is not None and epoch % self.validation_interval == 0:
self._training_or_validation_epoch(
accelerator=accelerator,
device=device,
epoch=epoch,
loader=val_loader,
model=model,
optimizer=optimizer,
stats=stats,
validation=True,
)
# eval loop (optional)
if (
test_loader is not None
and self.test_interval > 0
and epoch % self.test_interval == 0
):
self.evaluator.run(
all_train_cameras=all_train_cameras,
device=device,
dataloader=test_loader,
model=model,
task=task,
)
assert stats.epoch == epoch, "inconsistent stats!"
self._checkpoint(accelerator, epoch, exp_dir, model, optimizer, stats)
scheduler.step()
new_lr = float(scheduler.get_last_lr()[-1])
if new_lr != cur_lr:
logger.info(f"LR change! {cur_lr} -> {new_lr}")
if self.test_when_finished:
if test_loader is not None:
self.evaluator.run(
all_train_cameras=all_train_cameras,
device=device,
dump_to_json=True,
epoch=stats.epoch,
exp_dir=exp_dir,
dataloader=test_loader,
model=model,
task=task,
)
else:
raise ValueError(
"Cannot evaluate and dump results to json, no test data provided."
)
def _training_or_validation_epoch(
self,
epoch: int,
loader: DataLoader,
model: ImplicitronModelBase,
optimizer: torch.optim.Optimizer,
stats: Stats,
validation: bool,
*,
accelerator: Optional[Accelerator],
bp_var: str = "objective",
device: torch.device,
**kwargs,
) -> None:
"""
This is the main loop for training and evaluation including:
model forward pass, loss computation, backward pass and visualization.
Args:
epoch: The index of the current epoch
loader: The dataloader to use for the loop
model: The model module optionally loaded from checkpoint
optimizer: The optimizer module optionally loaded from checkpoint
stats: The stats struct, also optionally loaded from checkpoint
validation: If true, run the loop with the model in eval mode
and skip the backward pass
accelerator: An optional Accelerator instance.
bp_var: The name of the key in the model output `preds` dict which
should be used as the loss for the backward pass.
device: The device on which to run the model.
"""
if validation:
model.eval()
trainmode = "val"
else:
model.train()
trainmode = "train"
t_start = time.time()
# get the visdom env name
visdom_env_imgs = stats.visdom_env + "_images_" + trainmode
viz = vis_utils.get_visdom_connection(
server=stats.visdom_server,
port=stats.visdom_port,
)
# Iterate through the batches
n_batches = len(loader)
for it, net_input in enumerate(loader):
last_iter = it == n_batches - 1
# move to gpu where possible (in place)
net_input = net_input.to(device)
# run the forward pass
if not validation:
optimizer.zero_grad()
preds = model(
**{**net_input, "evaluation_mode": EvaluationMode.TRAINING}
)
else:
with torch.no_grad():
preds = model(
**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
)
# make sure we don't overwrite something
assert all(k not in preds for k in net_input.keys())
# merge everything into one big dict
preds.update(net_input)
# update the stats logger
stats.update(preds, time_start=t_start, stat_set=trainmode)
# pyre-ignore [16]
assert stats.it[trainmode] == it, "inconsistent stat iteration number!"
# print textual status update
if it % self.metric_print_interval == 0 or last_iter:
stats.print(stat_set=trainmode, max_it=n_batches)
# visualize results
if (
(accelerator is None or accelerator.is_local_main_process)
and self.visualize_interval > 0
and it % self.visualize_interval == 0
):
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
if hasattr(model, "visualize"):
# pyre-ignore [29]
model.visualize(
viz,
visdom_env_imgs,
preds,
prefix,
)
# optimizer step
if not validation:
loss = preds[bp_var]
assert torch.isfinite(loss).all(), "Non-finite loss!"
# backprop
if accelerator is None:
loss.backward()
else:
accelerator.backward(loss)
if self.clip_grad > 0.0:
# Optionally clip the gradient norms.
total_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), self.clip_grad
)
if total_norm > self.clip_grad:
logger.debug(
f"Clipping gradient: {total_norm}"
+ f" with coef {self.clip_grad / float(total_norm)}."
)
optimizer.step()
def _checkpoint(
self,
accelerator: Optional[Accelerator],
epoch: int,
exp_dir: str,
model: ImplicitronModelBase,
optimizer: torch.optim.Optimizer,
stats: Stats,
):
"""
Save a model and its corresponding Stats object to a file, if
`self.store_checkpoints` is True. In addition, if
`self.store_checkpoints_purge` > 0, remove any checkpoints that are
more than `self.store_checkpoints_purge` epochs older than the current one.
"""
if self.store_checkpoints and (
accelerator is None or accelerator.is_local_main_process
):
if self.store_checkpoints_purge > 0:
for prev_epoch in range(epoch - self.store_checkpoints_purge):
model_io.purge_epoch(exp_dir, prev_epoch)
outfile = model_io.get_checkpoint(exp_dir, epoch)
unwrapped_model = (
model if accelerator is None else accelerator.unwrap_model(model)
)
model_io.safe_save_model(
unwrapped_model, stats, outfile, optimizer=optimizer
)
def _seed_all_random_engines(seed: int) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
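
The per-epoch decisions made by `ImplicitronTrainingLoop.run` and `_checkpoint` can be summarized in a small stand-alone sketch. `epoch_plan` is a hypothetical helper, not part of the diff. It assumes that validation and test loaders are present and that `store_checkpoints` is True, and it only mirrors the interval and purge conditions shown above.

```python
def epoch_plan(start_epoch, max_epochs, validation_interval, test_interval, purge):
    """Yield (epoch, validate, test, purged_epochs) for each epoch of the loop."""
    for epoch in range(start_epoch, max_epochs):
        # Validation runs on every `validation_interval`-th epoch.
        validate = epoch % validation_interval == 0
        # Testing is disabled unless `test_interval` is positive.
        test = test_interval > 0 and epoch % test_interval == 0
        # Checkpoints of epochs below `epoch - purge` are removed when purge > 0.
        purged = list(range(epoch - purge)) if purge > 0 else []
        yield epoch, validate, test, purged


# A fresh 4-epoch run, validating every second epoch and keeping one old checkpoint.
for step in epoch_plan(0, 4, validation_interval=2, test_interval=-1, purge=1):
    print(step)

# Resuming after epoch 2 of 4: only epoch 3 remains to be run.
print(list(epoch_plan(3, 4, validation_interval=1, test_interval=2, purge=1)))
```

Because the loop starts at `stats.epoch + 1`, a resumed run executes only the remaining `max_epochs - start_epoch` epochs, as described in the `max_epochs` docstring.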

@ -1,296 +1,14 @@
generic_model_args: data_source_class_type: ImplicitronDataSource
mask_images: true model_factory_class_type: ImplicitronModelFactory
mask_depths: true optimizer_factory_class_type: ImplicitronOptimizerFactory
render_image_width: 400 training_loop_class_type: ImplicitronTrainingLoop
render_image_height: 400 detect_anomaly: false
mask_threshold: 0.5 exp_dir: ./data/default_experiment/
output_rasterized_mc: false hydra:
bg_color: run:
- 0.0 dir: .
- 0.0 output_subdir: null
- 0.0 data_source_ImplicitronDataSource_args:
num_passes: 1
chunk_size_grid: 4096
render_features_dimensions: 3
tqdm_trigger_threshold: 16
n_train_target_views: 1
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
global_encoder_class_type: null
raysampler_class_type: AdaptiveRaySampler
renderer_class_type: MultiPassEmissionAbsorptionRenderer
image_feature_extractor_class_type: null
view_pooler_enabled: false
implicit_function_class_type: NeuralRadianceFieldImplicitFunction
view_metrics_class_type: ViewMetrics
regularization_metrics_class_type: RegularizationMetrics
loss_weights:
loss_rgb_mse: 1.0
loss_prev_stage_rgb_mse: 1.0
loss_mask_bce: 0.0
loss_prev_stage_mask_bce: 0.0
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
global_encoder_HarmonicTimeEncoder_args:
n_harmonic_functions: 10
append_input: true
time_divisor: 1.0
global_encoder_SequenceAutodecoder_args:
autodecoder_args:
encoding_dim: 0
n_instances: 0
init_scale: 1.0
ignore_input: false
raysampler_AdaptiveRaySampler_args:
image_width: 400
image_height: 400
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
scene_extent: 8.0
scene_center:
- 0.0
- 0.0
- 0.0
raysampler_NearFarRaySampler_args:
image_width: 400
image_height: 400
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
min_depth: 0.1
max_depth: 8.0
renderer_LSTMRenderer_args:
num_raymarch_steps: 10
init_depth: 17.0
init_depth_noise_std: 0.0005
hidden_size: 16
n_feature_channels: 256
bg_color: null
verbose: false
renderer_MultiPassEmissionAbsorptionRenderer_args:
raymarcher_class_type: EmissionAbsorptionRaymarcher
n_pts_per_ray_fine_training: 64
n_pts_per_ray_fine_evaluation: 64
stratified_sampling_coarse_training: true
stratified_sampling_coarse_evaluation: false
append_coarse_samples_to_fine: true
density_noise_std_train: 0.0
return_weights: false
raymarcher_CumsumRaymarcher_args:
surface_thickness: 1
bg_color:
- 0.0
background_opacity: 0.0
density_relu: true
blend_output: false
raymarcher_EmissionAbsorptionRaymarcher_args:
surface_thickness: 1
bg_color:
- 0.0
background_opacity: 10000000000.0
density_relu: true
blend_output: false
renderer_SignedDistanceFunctionRenderer_args:
render_features_dimensions: 3
ray_tracer_args:
object_bounding_sphere: 1.0
sdf_threshold: 5.0e-05
line_search_step: 0.5
line_step_iters: 1
sphere_tracing_iters: 10
n_steps: 100
n_secant_steps: 8
ray_normal_coloring_network_args:
feature_vector_size: 3
mode: idr
d_in: 9
d_out: 3
dims:
- 512
- 512
- 512
- 512
weight_norm: true
n_harmonic_functions_dir: 0
pooled_feature_dim: 0
bg_color:
- 0.0
soft_mask_alpha: 50.0
image_feature_extractor_ResNetFeatureExtractor_args:
name: resnet34
pretrained: true
stages:
- 1
- 2
- 3
- 4
normalize_image: true
image_rescale: 0.16
first_max_pool: true
proj_dim: 32
l2_norm: true
add_masks: true
add_images: true
global_average_pool: false
feature_rescale: 1.0
view_pooler_args:
feature_aggregator_class_type: AngleWeightedReductionFeatureAggregator
view_sampler_args:
masked_sampling: false
sampling_mode: bilinear
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
reduction_functions:
- AVG
- STD
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
feature_aggregator_IdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
feature_aggregator_ReductionFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
reduction_functions:
- AVG
- STD
implicit_function_IdrFeatureField_args:
feature_vector_size: 3
d_in: 3
d_out: 1
dims:
- 512
- 512
- 512
- 512
- 512
- 512
- 512
- 512
geometric_init: true
bias: 1.0
skip_in: []
weight_norm: true
n_harmonic_functions_xyz: 0
pooled_feature_dim: 0
encoding_dim: 0
implicit_function_NeRFormerImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_dir: 128
latent_dim: 0
input_xyz: true
xyz_ray_dir_in_camera_coords: false
color_dim: 3
transformer_dim_down_factor: 2.0
n_hidden_neurons_xyz: 80
n_layers_xyz: 2
append_xyz:
- 1
implicit_function_NeuralRadianceFieldImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_dir: 128
latent_dim: 0
input_xyz: true
xyz_ray_dir_in_camera_coords: false
color_dim: 3
transformer_dim_down_factor: 1.0
n_hidden_neurons_xyz: 256
n_layers_xyz: 8
append_xyz:
- 5
implicit_function_SRNHyperNetImplicitFunction_args:
hypernet_args:
n_harmonic_functions: 3
n_hidden_units: 256
n_layers: 2
n_hidden_units_hypernet: 256
n_layers_hypernet: 1
in_features: 3
out_features: 256
latent_dim_hypernet: 0
latent_dim: 0
xyz_in_camera_coords: false
pixel_generator_args:
n_harmonic_functions: 4
n_hidden_units: 256
n_hidden_units_color: 128
n_layers: 2
in_features: 256
out_features: 3
ray_dir_in_camera_coords: false
implicit_function_SRNImplicitFunction_args:
raymarch_function_args:
n_harmonic_functions: 3
n_hidden_units: 256
n_layers: 2
in_features: 3
out_features: 256
latent_dim: 0
xyz_in_camera_coords: false
raymarch_function: null
pixel_generator_args:
n_harmonic_functions: 4
n_hidden_units: 256
n_hidden_units_color: 128
n_layers: 2
in_features: 256
out_features: 3
ray_dir_in_camera_coords: false
view_metrics_ViewMetrics_args: {}
regularization_metrics_RegularizationMetrics_args: {}
solver_args:
breed: adam
weight_decay: 0.0
lr_policy: multistep
lr: 0.0005
gamma: 0.1
momentum: 0.9
betas:
- 0.9
- 0.999
milestones: []
max_epochs: 1000
data_source_args:
dataset_map_provider_class_type: ??? dataset_map_provider_class_type: ???
data_loader_map_provider_class_type: SequenceDataLoaderMapProvider data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
dataset_map_provider_BlenderDatasetMapProvider_args: dataset_map_provider_BlenderDatasetMapProvider_args:
@ -396,30 +114,322 @@ data_source_args:
sample_consecutive_frames: false sample_consecutive_frames: false
consecutive_frames_max_gap: 0 consecutive_frames_max_gap: 0
consecutive_frames_max_gap_seconds: 0.1 consecutive_frames_max_gap_seconds: 0.1
architecture: generic model_factory_ImplicitronModelFactory_args:
detect_anomaly: false force_load: false
eval_only: false model_class_type: GenericModel
exp_dir: ./data/default_experiment/ resume: false
exp_idx: 0 resume_epoch: -1
gpu_idx: 0 visdom_env: ''
metric_print_interval: 5 visdom_port: 8097
resume: true visdom_server: http://127.0.0.1
resume_epoch: -1 model_GenericModel_args:
seed: 0 mask_images: true
store_checkpoints: true mask_depths: true
store_checkpoints_purge: 1 render_image_width: 400
test_interval: -1 render_image_height: 400
test_when_finished: false mask_threshold: 0.5
validation_interval: 1 output_rasterized_mc: false
visdom_env: '' bg_color:
visdom_port: 8097 - 0.0
visdom_server: http://127.0.0.1 - 0.0
visualize_interval: 1000 - 0.0
clip_grad: 0.0 num_passes: 1
camera_difficulty_bin_breaks: chunk_size_grid: 4096
- 0.97 render_features_dimensions: 3
- 0.98 tqdm_trigger_threshold: 16
hydra: n_train_target_views: 1
run: sampling_mode_training: mask_sample
dir: . sampling_mode_evaluation: full_grid
output_subdir: null global_encoder_class_type: null
raysampler_class_type: AdaptiveRaySampler
renderer_class_type: MultiPassEmissionAbsorptionRenderer
image_feature_extractor_class_type: null
view_pooler_enabled: false
implicit_function_class_type: NeuralRadianceFieldImplicitFunction
view_metrics_class_type: ViewMetrics
regularization_metrics_class_type: RegularizationMetrics
loss_weights:
loss_rgb_mse: 1.0
loss_prev_stage_rgb_mse: 1.0
loss_mask_bce: 0.0
loss_prev_stage_mask_bce: 0.0
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
global_encoder_HarmonicTimeEncoder_args:
n_harmonic_functions: 10
append_input: true
time_divisor: 1.0
global_encoder_SequenceAutodecoder_args:
autodecoder_args:
encoding_dim: 0
n_instances: 0
init_scale: 1.0
ignore_input: false
raysampler_AdaptiveRaySampler_args:
image_width: 400
image_height: 400
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
scene_extent: 8.0
scene_center:
- 0.0
- 0.0
- 0.0
raysampler_NearFarRaySampler_args:
image_width: 400
image_height: 400
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
min_depth: 0.1
max_depth: 8.0
renderer_LSTMRenderer_args:
num_raymarch_steps: 10
init_depth: 17.0
init_depth_noise_std: 0.0005
hidden_size: 16
n_feature_channels: 256
bg_color: null
verbose: false
renderer_MultiPassEmissionAbsorptionRenderer_args:
raymarcher_class_type: EmissionAbsorptionRaymarcher
n_pts_per_ray_fine_training: 64
n_pts_per_ray_fine_evaluation: 64
stratified_sampling_coarse_training: true
stratified_sampling_coarse_evaluation: false
append_coarse_samples_to_fine: true
density_noise_std_train: 0.0
return_weights: false
raymarcher_CumsumRaymarcher_args:
surface_thickness: 1
bg_color:
- 0.0
background_opacity: 0.0
density_relu: true
blend_output: false
raymarcher_EmissionAbsorptionRaymarcher_args:
surface_thickness: 1
bg_color:
- 0.0
background_opacity: 10000000000.0
density_relu: true
blend_output: false
renderer_SignedDistanceFunctionRenderer_args:
render_features_dimensions: 3
ray_tracer_args:
object_bounding_sphere: 1.0
sdf_threshold: 5.0e-05
line_search_step: 0.5
line_step_iters: 1
sphere_tracing_iters: 10
n_steps: 100
n_secant_steps: 8
ray_normal_coloring_network_args:
feature_vector_size: 3
mode: idr
d_in: 9
d_out: 3
dims:
- 512
- 512
- 512
- 512
weight_norm: true
n_harmonic_functions_dir: 0
pooled_feature_dim: 0
bg_color:
- 0.0
soft_mask_alpha: 50.0
image_feature_extractor_ResNetFeatureExtractor_args:
name: resnet34
pretrained: true
stages:
- 1
- 2
- 3
- 4
normalize_image: true
image_rescale: 0.16
first_max_pool: true
proj_dim: 32
l2_norm: true
add_masks: true
add_images: true
global_average_pool: false
feature_rescale: 1.0
view_pooler_args:
feature_aggregator_class_type: AngleWeightedReductionFeatureAggregator
view_sampler_args:
masked_sampling: false
sampling_mode: bilinear
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
reduction_functions:
- AVG
- STD
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
feature_aggregator_IdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
feature_aggregator_ReductionFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
reduction_functions:
- AVG
- STD
implicit_function_IdrFeatureField_args:
feature_vector_size: 3
d_in: 3
d_out: 1
dims:
- 512
- 512
- 512
- 512
- 512
- 512
- 512
- 512
geometric_init: true
bias: 1.0
skip_in: []
weight_norm: true
n_harmonic_functions_xyz: 0
pooled_feature_dim: 0
encoding_dim: 0
implicit_function_NeRFormerImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_dir: 128
latent_dim: 0
input_xyz: true
xyz_ray_dir_in_camera_coords: false
color_dim: 3
transformer_dim_down_factor: 2.0
n_hidden_neurons_xyz: 80
n_layers_xyz: 2
append_xyz:
- 1
implicit_function_NeuralRadianceFieldImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_dir: 128
latent_dim: 0
input_xyz: true
xyz_ray_dir_in_camera_coords: false
color_dim: 3
transformer_dim_down_factor: 1.0
n_hidden_neurons_xyz: 256
n_layers_xyz: 8
append_xyz:
- 5
implicit_function_SRNHyperNetImplicitFunction_args:
hypernet_args:
n_harmonic_functions: 3
n_hidden_units: 256
n_layers: 2
n_hidden_units_hypernet: 256
n_layers_hypernet: 1
in_features: 3
out_features: 256
latent_dim_hypernet: 0
latent_dim: 0
xyz_in_camera_coords: false
pixel_generator_args:
n_harmonic_functions: 4
n_hidden_units: 256
n_hidden_units_color: 128
n_layers: 2
in_features: 256
out_features: 3
ray_dir_in_camera_coords: false
implicit_function_SRNImplicitFunction_args:
raymarch_function_args:
n_harmonic_functions: 3
n_hidden_units: 256
n_layers: 2
in_features: 3
out_features: 256
latent_dim: 0
xyz_in_camera_coords: false
raymarch_function: null
pixel_generator_args:
n_harmonic_functions: 4
n_hidden_units: 256
n_hidden_units_color: 128
n_layers: 2
in_features: 256
out_features: 3
ray_dir_in_camera_coords: false
view_metrics_ViewMetrics_args: {}
regularization_metrics_RegularizationMetrics_args: {}
optimizer_factory_ImplicitronOptimizerFactory_args:
betas:
- 0.9
- 0.999
breed: Adam
exponential_lr_step_size: 250
gamma: 0.1
lr: 0.0005
lr_policy: MultiStepLR
momentum: 0.9
multistep_lr_milestones: []
resume: false
resume_epoch: -1
weight_decay: 0.0
training_loop_ImplicitronTrainingLoop_args:
eval_only: false
evaluator_class_type: ImplicitronEvaluator
max_epochs: 1000
seed: 0
store_checkpoints: true
store_checkpoints_purge: 1
test_interval: -1
test_when_finished: false
validation_interval: 1
clip_grad: 0.0
metric_print_interval: 5
visualize_interval: 1000
evaluator_ImplicitronEvaluator_args:
camera_difficulty_bin_breaks:
- 0.97
- 0.98
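
The new config keys follow a `<member>_<ClassName>_args` pattern, for example `training_loop_ImplicitronTrainingLoop_args`. The sketch below illustrates that naming with plain dictionaries standing in for the real OmegaConf config. `args_key` and `override` are hypothetical helpers introduced only for this illustration.

```python
def args_key(member: str, class_name: str) -> str:
    """Build the key under which the arguments of a pluggable member are stored."""
    return f"{member}_{class_name}_args"


def override(cfg: dict, path: list, field: str, value) -> None:
    """Set `field` on the node reached by following (member, class_name) pairs."""
    node = cfg
    for member, class_name in path:
        node = node.setdefault(args_key(member, class_name), {})
    node[field] = value


cfg = {}
# Top-level members of Experiment.
override(cfg, [("training_loop", "ImplicitronTrainingLoop")], "max_epochs", 2)
override(cfg, [("optimizer_factory", "ImplicitronOptimizerFactory")], "lr_policy", "Exponential")
# A nested member: the dataset map provider lives inside the data source.
override(
    cfg,
    [("data_source", "ImplicitronDataSource"), ("dataset_map_provider", "JsonIndexDatasetMapProvider")],
    "category",
    "skateboard",
)
print(cfg)
```

The same paths appear as attribute chains in the test, such as `cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs`.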

@ -12,6 +12,7 @@ from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf from omegaconf import OmegaConf
from .. import experiment from .. import experiment
from .utils import intercept_logs
def interactive_testing_requested() -> bool: def interactive_testing_requested() -> bool:
@ -33,7 +34,10 @@ DEBUG: bool = False
# TODO: # TODO:
# - add enough files to skateboard_first_5 that this works on RE. # - add enough files to skateboard_first_5 that this works on RE.
# - share common code with PyTorch3D tests? # - share common code with PyTorch3D tests?
# - deal with the temporary output files this test creates
def _parse_float_from_log(line):
return float(line.split()[-1])
class TestExperiment(unittest.TestCase): class TestExperiment(unittest.TestCase):
@ -44,15 +48,18 @@ class TestExperiment(unittest.TestCase):
# Test making minimal changes to the dataclass defaults. # Test making minimal changes to the dataclass defaults.
if not interactive_testing_requested() or not internal: if not interactive_testing_requested() or not internal:
return return
cfg = OmegaConf.structured(experiment.ExperimentConfig)
cfg.data_source_args.dataset_map_provider_class_type = ( # Manually override config values. Note that this is not necessary out-
# side of the tests!
cfg = OmegaConf.structured(experiment.Experiment)
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
"JsonIndexDatasetMapProvider" "JsonIndexDatasetMapProvider"
) )
dataset_args = ( dataset_args = (
cfg.data_source_args.dataset_map_provider_JsonIndexDatasetMapProvider_args cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
) )
dataloader_args = ( dataloader_args = (
cfg.data_source_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
) )
dataset_args.category = "skateboard" dataset_args.category = "skateboard"
dataset_args.test_restrict_sequence_id = 0 dataset_args.test_restrict_sequence_id = 0
@ -62,18 +69,80 @@ class TestExperiment(unittest.TestCase):
dataset_args.dataset_JsonIndexDataset_args.image_width = 80 dataset_args.dataset_JsonIndexDataset_args.image_width = 80
dataloader_args.dataset_length_train = 1 dataloader_args.dataset_length_train = 1
dataloader_args.dataset_length_val = 1 dataloader_args.dataset_length_val = 1
cfg.solver_args.max_epochs = 2 cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.multistep_lr_milestones = [
0,
1,
]
experiment.run_training(cfg) if DEBUG:
experiment.dump_cfg(cfg)
with intercept_logs(
logger_name="projects.implicitron_trainer.impl.training_loop",
regexp="LR change!",
) as intercepted_logs:
experiment_runner = experiment.Experiment(**cfg)
experiment_runner.run()
# Make sure the LR decreased 10-fold on the 0th and 1st epochs.
self.assertEqual(intercepted_logs[0].split()[-1], "5e-06")
def test_exponential_lr(self):
# Test making minimal changes to the dataclass defaults.
if not interactive_testing_requested():
return
cfg = OmegaConf.structured(experiment.Experiment)
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_class_type = (
"JsonIndexDatasetMapProvider"
)
dataset_args = (
cfg.data_source_ImplicitronDataSource_args.dataset_map_provider_JsonIndexDatasetMapProvider_args
)
dataloader_args = (
cfg.data_source_ImplicitronDataSource_args.data_loader_map_provider_SequenceDataLoaderMapProvider_args
)
dataset_args.category = "skateboard"
dataset_args.test_restrict_sequence_id = 0
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5
dataset_args.dataset_JsonIndexDataset_args.image_height = 80
dataset_args.dataset_JsonIndexDataset_args.image_width = 80
dataloader_args.dataset_length_train = 1
dataloader_args.dataset_length_val = 1
cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 2
cfg.training_loop_ImplicitronTrainingLoop_args.store_checkpoints = False
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.lr_policy = "Exponential"
cfg.optimizer_factory_ImplicitronOptimizerFactory_args.exponential_lr_step_size = (
2
)
if DEBUG:
experiment.dump_cfg(cfg)
with intercept_logs(
logger_name="projects.implicitron_trainer.impl.training_loop",
regexp="LR change!",
) as intercepted_logs:
experiment_runner = experiment.Experiment(**cfg)
experiment_runner.run()
# Make sure we followed the exponential lr schedule with gamma=0.1,
# exponential_lr_step_size=2 -- so after two epochs, should
# decrease lr 10x to 5e-5.
self.assertEqual(intercepted_logs[0].split()[-1], "0.00015811388300841897")
self.assertEqual(intercepted_logs[1].split()[-1], "5e-05")
def test_yaml_contents(self): def test_yaml_contents(self):
cfg = OmegaConf.structured(experiment.ExperimentConfig) # Check that the default config values, defined by Experiment and its
# members, are what we expect them to be.
cfg = OmegaConf.structured(experiment.Experiment)
yaml = OmegaConf.to_yaml(cfg, sort_keys=False) yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
if DEBUG: if DEBUG:
(DATA_DIR / "experiment.yaml").write_text(yaml) (DATA_DIR / "experiment.yaml").write_text(yaml)
self.assertEqual(yaml, (DATA_DIR / "experiment.yaml").read_text()) self.assertEqual(yaml, (DATA_DIR / "experiment.yaml").read_text())
def test_load_configs(self): def test_load_configs(self):
# Check that all the pre-prepared configs are valid.
config_files = [] config_files = []
for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"): for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"):
@ -89,3 +158,17 @@ class TestExperiment(unittest.TestCase):
with self.subTest(file.name): with self.subTest(file.name):
with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)): with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
compose(file.name) compose(file.name)
class TestNerfRepro(unittest.TestCase):
@unittest.skip("This test reproduces full NERF training.")
def test_nerf_blender(self):
# Train vanilla NERF.
# Set env vars BLENDER_DATASET_ROOT and BLENDER_SINGLESEQ_CLASS first!
if not interactive_testing_requested():
return
with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
cfg = compose(config_name="repro_singleseq_nerf_blender", overrides=[])
experiment_runner = experiment.Experiment(**cfg)
experiment.dump_cfg(cfg)
experiment_runner.run()
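
Note on the values asserted in test_exponential_lr: they are consistent with a per-epoch decay of the form base_lr * gamma ** (epoch / step_size). The following is a minimal sketch of that arithmetic only. The helper name `exponential_lr` is hypothetical, and the base learning rate of 5e-4 is an assumption inferred from the asserted values; the scheduler actually built by ImplicitronOptimizerFactory is not shown in this part of the diff and may be implemented differently.

```python
import math


def exponential_lr(base_lr: float, gamma: float, step_size: int, epoch: int) -> float:
    """Hypothetical helper: learning rate after `epoch` scheduler steps."""
    return base_lr * gamma ** (epoch / step_size)


BASE_LR = 5e-4  # assumption: inferred from the values the test asserts
GAMMA = 0.1  # from the test comment
STEP_SIZE = 2  # exponential_lr_step_size set in the test

lr_after_1 = exponential_lr(BASE_LR, GAMMA, STEP_SIZE, epoch=1)
lr_after_2 = exponential_lr(BASE_LR, GAMMA, STEP_SIZE, epoch=2)

# Compare with a tolerance rather than by string, since the exact
# floating-point digits depend on how the product is evaluated.
assert math.isclose(lr_after_1, 0.00015811388300841897, rel_tol=1e-9)
assert math.isclose(lr_after_2, 5e-05, rel_tol=1e-9)
```

The test itself compares the last whitespace-separated token of the log line as a string, which is stricter than the tolerance-based check used here.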


@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import re
from typing import List
@contextlib.contextmanager
def intercept_logs(logger_name: str, regexp: str):
# Intercept logs that match a regexp, from a given logger.
intercepted_messages = []
logger = logging.getLogger(logger_name)
class LoggerInterceptor(logging.Filter):
def filter(self, record):
message = record.getMessage()
if re.search(regexp, message):
intercepted_messages.append(message)
return True
interceptor = LoggerInterceptor()
logger.addFilter(interceptor)
try:
yield intercepted_messages
finally:
logger.removeFilter(interceptor)
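
Usage sketch for the helper above, runnable without the Implicitron trainer. The helper is repeated so the snippet is self-contained; the logger name "demo.training_loop" and the message text are hypothetical stand-ins, not the strings the real training loop emits.

```python
import contextlib
import logging
import re


@contextlib.contextmanager
def intercept_logs(logger_name: str, regexp: str):
    # Collect messages from `logger_name` that match `regexp`.
    intercepted_messages = []
    logger = logging.getLogger(logger_name)

    class LoggerInterceptor(logging.Filter):
        def filter(self, record):
            message = record.getMessage()
            if re.search(regexp, message):
                intercepted_messages.append(message)
            return True  # never drop the record, only observe it

    interceptor = LoggerInterceptor()
    logger.addFilter(interceptor)
    try:
        yield intercepted_messages
    finally:
        logger.removeFilter(interceptor)


demo_logger = logging.getLogger("demo.training_loop")  # hypothetical name
# Filters only see records that pass the logger's level check.
demo_logger.setLevel(logging.INFO)

with intercept_logs("demo.training_loop", "LR change!") as messages:
    demo_logger.info("LR change! 0.0005 -> 5e-05")  # hypothetical message
    demo_logger.info("epoch finished")

# Same parsing approach as the tests: take the last token of the message.
last_token = messages[0].split()[-1]
```

Two things to keep in mind with this approach: a filter attached to a logger only sees records logged directly on that logger (not on its children), and records below the logger's effective level never reach the filter at all.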


@ -22,7 +22,6 @@ import numpy as np
import torch import torch
import torch.nn.functional as Fu import torch.nn.functional as Fu
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
from pytorch3d.implicitron.dataset.utils import is_train_frame from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base_model import EvaluationMode from pytorch3d.implicitron.models.base_model import EvaluationMode
@ -37,7 +36,7 @@ from pytorch3d.implicitron.tools.vis_utils import (
) )
from tqdm import tqdm from tqdm import tqdm
from .experiment import init_model from .experiment import Experiment
def render_sequence( def render_sequence(
@ -344,13 +343,14 @@ def export_scenes(
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx) os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
# Load the previously trained model # Load the previously trained model
model, _, _ = init_model(cfg=config, force_load=True, load_model_only=True) experiment = Experiment(config)
model = experiment.model_factory(force_load=True, load_model_only=True)
model.cuda() model.cuda()
model.eval() model.eval()
# Setup the dataset # Setup the dataset
datasource = ImplicitronDataSource(**config.data_source_args) data_source = experiment.data_source
dataset_map = datasource.dataset_map_provider.get_dataset_map() dataset_map, _ = data_source.get_datasets_and_dataloaders()
dataset = dataset_map[split] dataset = dataset_map[split]
if dataset is None: if dataset is None:
raise ValueError(f"{split} dataset not provided") raise ValueError(f"{split} dataset not provided")


@ -40,6 +40,9 @@ class DataSourceBase(ReplaceableBase):
""" """
raise NotImplementedError() raise NotImplementedError()
def get_task(self) -> Task:
raise NotImplementedError()
@registry.register @registry.register
class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13] class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13]


@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
import lpips
import torch
import tqdm
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.base_model import EvaluationMode, ImplicitronModelBase
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.renderer.cameras import CamerasBase
from torch.utils.data import DataLoader
logger = logging.getLogger(__name__)
class EvaluatorBase(ReplaceableBase):
"""
Evaluate a trained model on given data. Returns a dict of loss/objective
names and their values.
"""
def run(
self, model: ImplicitronModelBase, dataloader: DataLoader, **kwargs
) -> Dict[str, Any]:
"""
Evaluate the results of Implicitron training.
"""
raise NotImplementedError()
@registry.register
class ImplicitronEvaluator(EvaluatorBase):
"""
Evaluate the results of Implicitron training.
Members:
camera_difficulty_bin_breaks: low/medium vals to divide camera difficulties into
[0-eps, low, medium, 1+eps].
"""
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
def __post_init__(self):
run_auto_creation(self)
def run(
self,
model: ImplicitronModelBase,
dataloader: DataLoader,
task: Task,
all_train_cameras: Optional[CamerasBase],
device: torch.device,
dump_to_json: bool = False,
exp_dir: Optional[str] = None,
epoch: Optional[int] = None,
**kwargs,
) -> Dict[str, Any]:
"""
Evaluate the results of Implicitron training. Optionally, dump results to
exp_dir/results_test.json.
Args:
model: A (trained) model to evaluate.
dataloader: A test dataloader.
task: Type of the novel-view synthesis task we're working on.
all_train_cameras: Camera instances we used for training.
device: A torch device.
dump_to_json: If True, will dump the results to a json file.
exp_dir: Root experiment directory.
epoch: Evaluation epoch (to be stored in the results dict).
Returns:
A dictionary of results.
"""
lpips_model = lpips.LPIPS(net="vgg")
lpips_model = lpips_model.to(device)
model.eval()
per_batch_eval_results = []
logger.info("Evaluating model ...")
for frame_data in tqdm.tqdm(dataloader):
frame_data = frame_data.to(device)
# mask out the unknown images so that the model does not see them
frame_data_for_eval = _get_eval_frame_data(frame_data)
with torch.no_grad():
preds = model(
**{
**frame_data_for_eval,
"evaluation_mode": EvaluationMode.EVALUATION,
}
)
implicitron_render = copy.deepcopy(preds["implicitron_render"])
per_batch_eval_results.append(
evaluate.eval_batch(
frame_data,
implicitron_render,
bg_color="black",
lpips_model=lpips_model,
source_cameras=all_train_cameras,
)
)
_, category_result = evaluate.summarize_nvs_eval_results(
per_batch_eval_results, task, self.camera_difficulty_bin_breaks
)
results = category_result["results"]
if dump_to_json:
_dump_to_json(epoch, exp_dir, results)
return category_result["results"]
def _dump_to_json(
epoch: Optional[int], exp_dir: Optional[str], results: List[Dict[str, Any]]
) -> None:
if epoch is not None:
for r in results:
r["eval_epoch"] = int(epoch)
logger.info("Evaluation results")
evaluate.pretty_print_nvs_metrics(results)
if exp_dir is None:
raise ValueError("Cannot save results to json without a specified save path.")
with open(os.path.join(exp_dir, "results_test.json"), "w") as f:
json.dump(results, f)
def _get_eval_frame_data(frame_data: Any) -> Any:
"""
Masks the unknown image data to make sure we cannot use it at model evaluation time.
"""
frame_data_for_eval = copy.deepcopy(frame_data)
is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
frame_data.image_rgb
)[:, None, None, None]
for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
value_masked = getattr(frame_data_for_eval, k).clone() * is_known
setattr(frame_data_for_eval, k, value_masked)
return frame_data_for_eval
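
The dump step of the evaluator can be exercised in isolation. Below is a minimal stdlib-only sketch of the same idea (tag each result with the evaluation epoch, then write exp_dir/results_test.json). The function name `dump_results` and the placeholder result entries are assumptions made for illustration, and unlike `_dump_to_json` above this variant validates `exp_dir` before touching the results and skips the pretty-printing.

```python
import json
import os
import tempfile
from typing import Any, Dict, List, Optional


def dump_results(
    epoch: Optional[int], exp_dir: Optional[str], results: List[Dict[str, Any]]
) -> str:
    """Hypothetical variant of the dump step; returns the written path."""
    if exp_dir is None:
        raise ValueError("Cannot save results to json without a specified save path.")
    if epoch is not None:
        for r in results:
            r["eval_epoch"] = int(epoch)  # mutates the caller's dicts, as above
    path = os.path.join(exp_dir, "results_test.json")
    with open(path, "w") as f:
        json.dump(results, f)
    return path


# Placeholder entries, not real evaluation output.
dummy_results = [
    {"metric": "placeholder_a", "value": 0.0},
    {"metric": "placeholder_b", "value": 1.0},
]

with tempfile.TemporaryDirectory() as tmp_dir:
    out_path = dump_results(3, tmp_dir, dummy_results)
    with open(out_path) as f:
        loaded = json.load(f)
```

Checking `exp_dir` first avoids a partially applied side effect (results tagged with `eval_epoch` but nothing written) when the path is missing.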


@ -37,10 +37,12 @@ class ImplicitronRender:
) )
class ImplicitronModelBase(ReplaceableBase): class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
""" """
Replaceable abstract base for all image generation / rendering models. Replaceable abstract base for all image generation / rendering models.
`forward()` method produces a render with a depth map. `forward()` method produces a render with a depth map. Derives from Module
so we can rely on basic functionality provided by torch for model
optimization.
""" """
def __init__(self) -> None: def __init__(self) -> None:


@ -16,10 +16,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
import tqdm import tqdm
from pytorch3d.implicitron.models.metrics import ( # noqa from pytorch3d.implicitron.models.metrics import (
RegularizationMetrics,
RegularizationMetricsBase, RegularizationMetricsBase,
ViewMetrics,
ViewMetricsBase, ViewMetricsBase,
) )
from pytorch3d.implicitron.tools import image_utils, vis_utils from pytorch3d.implicitron.tools import image_utils, vis_utils
@ -67,7 +65,7 @@ logger = logging.getLogger(__name__)
@registry.register @registry.register
class GenericModel(ImplicitronModelBase, torch.nn.Module): # pyre-ignore: 13 class GenericModel(ImplicitronModelBase): # pyre-ignore: 13
""" """
GenericModel is a wrapper for the neural implicit GenericModel is a wrapper for the neural implicit
rendering and reconstruction pipeline which consists rendering and reconstruction pipeline which consists


@ -22,7 +22,7 @@ from .renderer.base import EvaluationMode
@registry.register @registry.register
class ModelDBIR(ImplicitronModelBase, torch.nn.Module): class ModelDBIR(ImplicitronModelBase):
""" """
A simple depth-based image rendering model. A simple depth-based image rendering model.


@ -218,7 +218,7 @@ class AdaptiveRaySampler(AbstractMaskRaySampler):
def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]: def _get_min_max_depth_bounds(self, cameras: CamerasBase) -> Tuple[float, float]:
""" """
Returns the adaptivelly calculated near/far planes. Returns the adaptively calculated near/far planes.
""" """
min_depth, max_depth = camera_utils.get_min_max_depth_bounds( min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
cameras, self._scene_center, self.scene_extent cameras, self._scene_center, self.scene_extent


@ -74,6 +74,7 @@ class Stats(object):
""" """
stats logging object useful for gathering statistics of training a deep net in pytorch stats logging object useful for gathering statistics of training a deep net in pytorch
Example: Example:
```
# init stats structure that logs statistics 'objective' and 'top1e' # init stats structure that logs statistics 'objective' and 'top1e'
stats = Stats( ('objective','top1e') ) stats = Stats( ('objective','top1e') )
network = init_net() # init a pytorch module (=neural network) network = init_net() # init a pytorch module (=neural network)
@ -94,6 +95,7 @@ class Stats(object):
# stores the training plots into '/tmp/epoch_stats.pdf' # stores the training plots into '/tmp/epoch_stats.pdf'
# and plots into a visdom server running at localhost (if running) # and plots into a visdom server running at localhost (if running)
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf') stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
```
""" """
def __init__( def __init__(


@ -14,20 +14,22 @@ from visdom import Visdom
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_visdom_env(cfg): def get_visdom_env(visdom_env: str, exp_dir: str) -> str:
""" """
Parse out visdom environment name from the input config. Parse out visdom environment name from the input config.
Args: Args:
cfg: The global config file. visdom_env: Name of the visdom environment, could be an empty string.
exp_dir: Root experiment directory.
Returns: Returns:
visdom_env: The name of the visdom environment. visdom_env: The name of the visdom environment. If the given visdom_env is
empty, return the name of the bottom directory in exp_dir.
""" """
if len(cfg.visdom_env) == 0: if len(visdom_env) == 0:
visdom_env = cfg.exp_dir.split("/")[-1] visdom_env = exp_dir.split("/")[-1]
else: else:
visdom_env = cfg.visdom_env visdom_env = visdom_env
return visdom_env return visdom_env
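
A short sketch of how the new two-argument form behaves, including one edge case worth knowing about. `get_visdom_env` below mirrors the logic on the new side of the diff; `get_visdom_env_normalized` is a hypothetical variant, not part of this commit, and the example paths are made up.

```python
import os


def get_visdom_env(visdom_env: str, exp_dir: str) -> str:
    # Mirrors the new logic: fall back to the last path component of exp_dir.
    if len(visdom_env) == 0:
        return exp_dir.split("/")[-1]
    return visdom_env


def get_visdom_env_normalized(visdom_env: str, exp_dir: str) -> str:
    # Hypothetical variant: normalize the path first so a trailing
    # separator does not produce an empty environment name.
    if visdom_env:
        return visdom_env
    return os.path.basename(os.path.normpath(exp_dir))


explicit = get_visdom_env("my_env", "/tmp/exps/nerf_run")
fallback = get_visdom_env("", "/tmp/exps/nerf_run")
trailing = get_visdom_env("", "/tmp/exps/nerf_run/")
trailing_normalized = get_visdom_env_normalized("", "/tmp/exps/nerf_run/")
```

With a trailing slash the split-based fallback yields an empty string, so callers that build `exp_dir` by concatenation may want to strip or normalize it first.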