Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-01 03:12:49 +08:00
Add the OverfitModel
Summary: Introduces the OverfitModel for NeRF-style training with overfitting to one scene. It is a specific case of GenericModel, and it has been disentangled from GenericModel to ease usage.

## General modification

1. Modularize a minimal GenericModel to introduce OverfitModel.
2. Introduce OverfitModel and ensure through unit testing that it behaves like GenericModel.

## Modularization

The following methods have been extracted from GenericModel to allow modularity with OverfitModel:

- get_objective is now a call to weighted_sum_losses
- log_loss_weights
- preprocess_input

The generic methods have been moved to a utils.py file, which simplifies the code needed to introduce OverfitModel. Private methods like chunk_generator are now public and can be used by OverfitModel.

Reviewed By: shapovalov

Differential Revision: D43771992

fbshipit-source-id: 6102aeb21c7fdd56aa2ff9cd1dd23fd9fbf26315
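For reference, a minimal sketch of the extracted `weighted_sum_losses` helper, reconstructed from the inline `_get_objective` implementation removed from `generic_model.py` later in this diff; the exact signature in `pytorch3d/implicitron/models/utils.py` may differ slightly:

```python
import warnings
from typing import Dict, Optional

import torch


def weighted_sum_losses(
    preds: Dict[str, torch.Tensor], loss_weights: Dict[str, float]
) -> Optional[torch.Tensor]:
    # Compute the overall loss as the dot product of the individual
    # losses with the corresponding (non-zero) weights.
    losses_weighted = [
        preds[k] * float(w)
        for k, w in loss_weights.items()
        if k in preds and w != 0.0
    ]
    if len(losses_weighted) == 0:
        warnings.warn("No main objective found.")
        return None
    loss = sum(losses_weighted)
    assert torch.is_tensor(loss)
    return loss
```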
This commit is contained in:
parent 7d8b029aae
commit 813e941de5
@@ -248,7 +248,7 @@ The main object for this trainer loop is `Experiment`. It has four top-level rep
 * `data_source`: This is a `DataSourceBase` which defaults to `ImplicitronDataSource`.
   It constructs the data sets and dataloaders.
 * `model_factory`: This is a `ModelFactoryBase` which defaults to `ImplicitronModelFactory`.
-  It constructs the model, which is usually an instance of implicitron's main `GenericModel` class, and can load its weights from a checkpoint.
+  It constructs the model, which is usually an instance of `OverfitModel` (for NeRF-style training with overfitting to one scene) or `GenericModel` (that is able to generalize to multiple scenes by NeRFormer-style conditioning on other scene views), and can load its weights from a checkpoint.
 * `optimizer_factory`: This is an `OptimizerFactoryBase` which defaults to `ImplicitronOptimizerFactory`.
   It constructs the optimizer and can load its weights from a checkpoint.
 * `training_loop`: This is a `TrainingLoopBase` which defaults to `ImplicitronTrainingLoop` and defines the main training loop.
@@ -292,6 +292,43 @@ model_GenericModel_args: GenericModel
         ╘== ReductionFeatureAggregator
 ```
+
+Here is the class structure of OverfitModel:
+
+```
+model_OverfitModel_args: OverfitModel
+└-- raysampler_*_args: RaySampler
+    ╘== AdaptiveRaysampler
+    ╘== NearFarRaysampler
+└-- renderer_*_args: BaseRenderer
+    ╘== MultiPassEmissionAbsorptionRenderer
+    ╘== LSTMRenderer
+    ╘== SignedDistanceFunctionRenderer
+        └-- ray_tracer_args: RayTracing
+        └-- ray_normal_coloring_network_args: RayNormalColoringNetwork
+└-- implicit_function_*_args: ImplicitFunctionBase
+    ╘== NeuralRadianceFieldImplicitFunction
+    ╘== SRNImplicitFunction
+        └-- raymarch_function_args: SRNRaymarchFunction
+        └-- pixel_generator_args: SRNPixelGenerator
+    ╘== SRNHyperNetImplicitFunction
+        └-- hypernet_args: SRNRaymarchHyperNet
+        └-- pixel_generator_args: SRNPixelGenerator
+    ╘== IdrFeatureField
+└-- coarse_implicit_function_*_args: ImplicitFunctionBase
+    ╘== NeuralRadianceFieldImplicitFunction
+    ╘== SRNImplicitFunction
+        └-- raymarch_function_args: SRNRaymarchFunction
+        └-- pixel_generator_args: SRNPixelGenerator
+    ╘== SRNHyperNetImplicitFunction
+        └-- hypernet_args: SRNRaymarchHyperNet
+        └-- pixel_generator_args: SRNPixelGenerator
+    ╘== IdrFeatureField
+```
+
+OverfitModel has been introduced to create a simple class that disentangles NeRFs following the overfit pattern
+from the GenericModel.
+
 Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
 `tests/experiment.yaml` shows every possible option if you have no user-defined classes.
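To illustrate the intended usage, here is a minimal sketch of instantiating `OverfitModel` outside the trainer, assuming the standard implicitron `Configurable` workflow (`expand_args_fields` before construction); the keyword arguments shown simply override the defaults documented in the class annotations:

```python
from pytorch3d.implicitron.models.overfit_model import OverfitModel
from pytorch3d.implicitron.tools.config import expand_args_fields

# Populate the dataclass fields of the Configurable class before use.
expand_args_fields(OverfitModel)
model = OverfitModel(render_image_width=400, render_image_height=400)
print(sum(p.numel() for p in model.parameters()))
```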
projects/implicitron_trainer/configs/overfit_base.yaml (new file, 79 lines)
@@ -0,0 +1,79 @@
+defaults:
+- default_config
+- _self_
+exp_dir: ./data/exps/overfit_base/
+training_loop_ImplicitronTrainingLoop_args:
+  visdom_port: 8097
+  visualize_interval: 0
+  max_epochs: 1000
+data_source_ImplicitronDataSource_args:
+  data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
+  dataset_map_provider_class_type: JsonIndexDatasetMapProvider
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
+    dataset_length_train: 1000
+    dataset_length_val: 1
+    num_workers: 8
+  dataset_map_provider_JsonIndexDatasetMapProvider_args:
+    dataset_root: ${oc.env:CO3D_DATASET_ROOT}
+    n_frames_per_sequence: -1
+    test_on_train: true
+    test_restrict_sequence_id: 0
+    dataset_JsonIndexDataset_args:
+      load_point_clouds: false
+      mask_depths: false
+      mask_images: false
+model_factory_ImplicitronModelFactory_args:
+  model_class_type: "OverfitModel"
+  model_OverfitModel_args:
+    loss_weights:
+      loss_mask_bce: 1.0
+      loss_prev_stage_mask_bce: 1.0
+      loss_autodecoder_norm: 0.01
+      loss_rgb_mse: 1.0
+      loss_prev_stage_rgb_mse: 1.0
+    output_rasterized_mc: false
+    chunk_size_grid: 102400
+    render_image_height: 400
+    render_image_width: 400
+    share_implicit_function_across_passes: false
+    implicit_function_class_type: "NeuralRadianceFieldImplicitFunction"
+    implicit_function_NeuralRadianceFieldImplicitFunction_args:
+      n_harmonic_functions_xyz: 10
+      n_harmonic_functions_dir: 4
+      n_hidden_neurons_xyz: 256
+      n_hidden_neurons_dir: 128
+      n_layers_xyz: 8
+      append_xyz:
+      - 5
+    coarse_implicit_function_class_type: "NeuralRadianceFieldImplicitFunction"
+    coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args:
+      n_harmonic_functions_xyz: 10
+      n_harmonic_functions_dir: 4
+      n_hidden_neurons_xyz: 256
+      n_hidden_neurons_dir: 128
+      n_layers_xyz: 8
+      append_xyz:
+      - 5
+    raysampler_AdaptiveRaySampler_args:
+      n_rays_per_image_sampled_from_mask: 1024
+      scene_extent: 8.0
+      n_pts_per_ray_training: 64
+      n_pts_per_ray_evaluation: 64
+      stratified_point_sampling_training: true
+      stratified_point_sampling_evaluation: false
+    renderer_MultiPassEmissionAbsorptionRenderer_args:
+      n_pts_per_ray_fine_training: 64
+      n_pts_per_ray_fine_evaluation: 64
+      append_coarse_samples_to_fine: true
+      density_noise_std_train: 1.0
+optimizer_factory_ImplicitronOptimizerFactory_args:
+  breed: Adam
+  weight_decay: 0.0
+  lr_policy: MultiStepLR
+  multistep_lr_milestones: []
+  lr: 0.0005
+  gamma: 0.1
+  momentum: 0.9
+  betas:
+  - 0.9
+  - 0.999
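As a usage sketch, this config can be composed programmatically; this assumes Hydra ≥ 1.2's compose API and a local checkout providing the configs directory (the path below is hypothetical):

```python
from hydra import compose, initialize_config_dir

# Hypothetical local path to the trainer's config directory.
CONFIG_DIR = "/path/to/pytorch3d/projects/implicitron_trainer/configs"

with initialize_config_dir(config_dir=CONFIG_DIR, version_base=None):
    cfg = compose(config_name="overfit_base")

# The model factory is configured to build the new OverfitModel.
print(cfg.model_factory_ImplicitronModelFactory_args.model_class_type)
```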
@@ -0,0 +1,42 @@
+defaults:
+- overfit_base
+- _self_
+data_source_ImplicitronDataSource_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
+    batch_size: 1
+    dataset_length_train: 1000
+    dataset_length_val: 1
+    num_workers: 8
+  dataset_map_provider_JsonIndexDatasetMapProvider_args:
+    assert_single_seq: true
+    n_frames_per_sequence: -1
+    test_restrict_sequence_id: 0
+    test_on_train: false
+model_factory_ImplicitronModelFactory_args:
+  model_class_type: "OverfitModel"
+  model_OverfitModel_args:
+    render_image_height: 800
+    render_image_width: 800
+    log_vars:
+    - loss_rgb_psnr_fg
+    - loss_rgb_psnr
+    - loss_eikonal
+    - loss_prev_stage_rgb_psnr
+    - loss_mask_bce
+    - loss_prev_stage_mask_bce
+    - loss_rgb_mse
+    - loss_prev_stage_rgb_mse
+    - loss_depth_abs
+    - loss_depth_abs_fg
+    - loss_kl
+    - loss_mask_neg_iou
+    - objective
+    - epoch
+    - sec/it
+optimizer_factory_ImplicitronOptimizerFactory_args:
+  lr: 0.0005
+  multistep_lr_milestones:
+  - 200
+  - 300
+training_loop_ImplicitronTrainingLoop_args:
+  max_epochs: 400
@@ -0,0 +1,56 @@
+defaults:
+- overfit_singleseq_base
+- _self_
+exp_dir: "./data/overfit_nerf_blender_repro/${oc.env:BLENDER_SINGLESEQ_CLASS}"
+data_source_ImplicitronDataSource_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
+    dataset_length_train: 100
+  dataset_map_provider_class_type: BlenderDatasetMapProvider
+  dataset_map_provider_BlenderDatasetMapProvider_args:
+    base_dir: ${oc.env:BLENDER_DATASET_ROOT}/${oc.env:BLENDER_SINGLESEQ_CLASS}
+    n_known_frames_for_test: null
+    object_name: ${oc.env:BLENDER_SINGLESEQ_CLASS}
+    path_manager_factory_class_type: PathManagerFactory
+    path_manager_factory_PathManagerFactory_args:
+      silence_logs: true
+
+model_factory_ImplicitronModelFactory_args:
+  model_class_type: "OverfitModel"
+  model_OverfitModel_args:
+    mask_images: false
+    raysampler_class_type: AdaptiveRaySampler
+    raysampler_AdaptiveRaySampler_args:
+      n_pts_per_ray_training: 64
+      n_pts_per_ray_evaluation: 64
+      n_rays_per_image_sampled_from_mask: 4096
+      stratified_point_sampling_training: true
+      stratified_point_sampling_evaluation: false
+      scene_extent: 2.0
+      scene_center:
+      - 0.0
+      - 0.0
+      - 0.0
+    renderer_MultiPassEmissionAbsorptionRenderer_args:
+      density_noise_std_train: 0.0
+      n_pts_per_ray_fine_training: 128
+      n_pts_per_ray_fine_evaluation: 128
+      raymarcher_EmissionAbsorptionRaymarcher_args:
+        blend_output: false
+    loss_weights:
+      loss_rgb_mse: 1.0
+      loss_prev_stage_rgb_mse: 1.0
+      loss_mask_bce: 0.0
+      loss_prev_stage_mask_bce: 0.0
+      loss_autodecoder_norm: 0.00
+
+optimizer_factory_ImplicitronOptimizerFactory_args:
+  exponential_lr_step_size: 3001
+  lr_policy: LinearExponential
+  linear_exponential_lr_milestone: 200
+
+training_loop_ImplicitronTrainingLoop_args:
+  max_epochs: 6000
+  metric_print_interval: 10
+  store_checkpoints_purge: 3
+  test_when_finished: true
+  validation_interval: 100
@@ -59,7 +59,7 @@ from pytorch3d.implicitron.dataset.data_source import (
     DataSourceBase,
     ImplicitronDataSource,
 )
-from pytorch3d.implicitron.models.generic_model import ImplicitronModelBase
+from pytorch3d.implicitron.models.base_model import ImplicitronModelBase

 from pytorch3d.implicitron.models.renderer.multipass_ea import (
     MultiPassEmissionAbsorptionRenderer,
@@ -561,6 +561,623 @@ model_factory_ImplicitronModelFactory_args:
           use_xavier_init: true
     view_metrics_ViewMetrics_args: {}
     regularization_metrics_RegularizationMetrics_args: {}
+  model_OverfitModel_args:
+    log_vars:
+    - loss_rgb_psnr_fg
+    - loss_rgb_psnr
+    - loss_rgb_mse
+    - loss_rgb_huber
+    - loss_depth_abs
+    - loss_depth_abs_fg
+    - loss_mask_neg_iou
+    - loss_mask_bce
+    - loss_mask_beta_prior
+    - loss_eikonal
+    - loss_density_tv
+    - loss_depth_neg_penalty
+    - loss_autodecoder_norm
+    - loss_prev_stage_rgb_mse
+    - loss_prev_stage_rgb_psnr_fg
+    - loss_prev_stage_rgb_psnr
+    - loss_prev_stage_mask_bce
+    - objective
+    - epoch
+    - sec/it
+    mask_images: true
+    mask_depths: true
+    render_image_width: 400
+    render_image_height: 400
+    mask_threshold: 0.5
+    output_rasterized_mc: false
+    bg_color:
+    - 0.0
+    - 0.0
+    - 0.0
+    chunk_size_grid: 4096
+    render_features_dimensions: 3
+    tqdm_trigger_threshold: 16
+    n_train_target_views: 1
+    sampling_mode_training: mask_sample
+    sampling_mode_evaluation: full_grid
+    global_encoder_class_type: null
+    raysampler_class_type: AdaptiveRaySampler
+    renderer_class_type: MultiPassEmissionAbsorptionRenderer
+    share_implicit_function_across_passes: false
+    implicit_function_class_type: NeuralRadianceFieldImplicitFunction
+    coarse_implicit_function_class_type: null
+    view_metrics_class_type: ViewMetrics
+    regularization_metrics_class_type: RegularizationMetrics
+    loss_weights:
+      loss_rgb_mse: 1.0
+      loss_prev_stage_rgb_mse: 1.0
+      loss_mask_bce: 0.0
+      loss_prev_stage_mask_bce: 0.0
+    global_encoder_HarmonicTimeEncoder_args:
+      n_harmonic_functions: 10
+      append_input: true
+      time_divisor: 1.0
+    global_encoder_SequenceAutodecoder_args:
+      autodecoder_args:
+        encoding_dim: 0
+        n_instances: 1
+        init_scale: 1.0
+        ignore_input: false
+    raysampler_AdaptiveRaySampler_args:
+      n_pts_per_ray_training: 64
+      n_pts_per_ray_evaluation: 64
+      n_rays_per_image_sampled_from_mask: 1024
+      n_rays_total_training: null
+      stratified_point_sampling_training: true
+      stratified_point_sampling_evaluation: false
+      scene_extent: 8.0
+      scene_center:
+      - 0.0
+      - 0.0
+      - 0.0
+    raysampler_NearFarRaySampler_args:
+      n_pts_per_ray_training: 64
+      n_pts_per_ray_evaluation: 64
+      n_rays_per_image_sampled_from_mask: 1024
+      n_rays_total_training: null
+      stratified_point_sampling_training: true
+      stratified_point_sampling_evaluation: false
+      min_depth: 0.1
+      max_depth: 8.0
+    renderer_LSTMRenderer_args:
+      num_raymarch_steps: 10
+      init_depth: 17.0
+      init_depth_noise_std: 0.0005
+      hidden_size: 16
+      n_feature_channels: 256
+      bg_color: null
+      verbose: false
+    renderer_MultiPassEmissionAbsorptionRenderer_args:
+      raymarcher_class_type: EmissionAbsorptionRaymarcher
+      n_pts_per_ray_fine_training: 64
+      n_pts_per_ray_fine_evaluation: 64
+      stratified_sampling_coarse_training: true
+      stratified_sampling_coarse_evaluation: false
+      append_coarse_samples_to_fine: true
+      density_noise_std_train: 0.0
+      return_weights: false
+      raymarcher_CumsumRaymarcher_args:
+        surface_thickness: 1
+        bg_color:
+        - 0.0
+        replicate_last_interval: false
+        background_opacity: 0.0
+        density_relu: true
+        blend_output: false
+      raymarcher_EmissionAbsorptionRaymarcher_args:
+        surface_thickness: 1
+        bg_color:
+        - 0.0
+        replicate_last_interval: false
+        background_opacity: 10000000000.0
+        density_relu: true
+        blend_output: false
+    renderer_SignedDistanceFunctionRenderer_args:
+      ray_normal_coloring_network_args:
+        feature_vector_size: 3
+        mode: idr
+        d_in: 9
+        d_out: 3
+        dims:
+        - 512
+        - 512
+        - 512
+        - 512
+        weight_norm: true
+        n_harmonic_functions_dir: 0
+        pooled_feature_dim: 0
+      bg_color:
+      - 0.0
+      soft_mask_alpha: 50.0
+      ray_tracer_args:
+        sdf_threshold: 5.0e-05
+        line_search_step: 0.5
+        line_step_iters: 1
+        sphere_tracing_iters: 10
+        n_steps: 100
+        n_secant_steps: 8
+    implicit_function_IdrFeatureField_args:
+      d_in: 3
+      d_out: 1
+      dims:
+      - 512
+      - 512
+      - 512
+      - 512
+      - 512
+      - 512
+      - 512
+      - 512
+      geometric_init: true
+      bias: 1.0
+      skip_in: []
+      weight_norm: true
+      n_harmonic_functions_xyz: 0
+      pooled_feature_dim: 0
+    implicit_function_NeRFormerImplicitFunction_args:
+      n_harmonic_functions_xyz: 10
+      n_harmonic_functions_dir: 4
+      n_hidden_neurons_dir: 128
+      input_xyz: true
+      xyz_ray_dir_in_camera_coords: false
+      transformer_dim_down_factor: 2.0
+      n_hidden_neurons_xyz: 80
+      n_layers_xyz: 2
+      append_xyz:
+      - 1
+    implicit_function_NeuralRadianceFieldImplicitFunction_args:
+      n_harmonic_functions_xyz: 10
+      n_harmonic_functions_dir: 4
+      n_hidden_neurons_dir: 128
+      input_xyz: true
+      xyz_ray_dir_in_camera_coords: false
+      transformer_dim_down_factor: 1.0
+      n_hidden_neurons_xyz: 256
+      n_layers_xyz: 8
+      append_xyz:
+      - 5
+    implicit_function_SRNHyperNetImplicitFunction_args:
+      latent_dim_hypernet: 0
+      hypernet_args:
+        n_harmonic_functions: 3
+        n_hidden_units: 256
+        n_layers: 2
+        n_hidden_units_hypernet: 256
+        n_layers_hypernet: 1
+        in_features: 3
+        out_features: 256
+        xyz_in_camera_coords: false
+      pixel_generator_args:
+        n_harmonic_functions: 4
+        n_hidden_units: 256
+        n_hidden_units_color: 128
+        n_layers: 2
+        in_features: 256
+        out_features: 3
+        ray_dir_in_camera_coords: false
+    implicit_function_SRNImplicitFunction_args:
+      raymarch_function_args:
+        n_harmonic_functions: 3
+        n_hidden_units: 256
+        n_layers: 2
+        in_features: 3
+        out_features: 256
+        xyz_in_camera_coords: false
+        raymarch_function: null
+      pixel_generator_args:
+        n_harmonic_functions: 4
+        n_hidden_units: 256
+        n_hidden_units_color: 128
+        n_layers: 2
+        in_features: 256
+        out_features: 3
+        ray_dir_in_camera_coords: false
+    implicit_function_VoxelGridImplicitFunction_args:
+      harmonic_embedder_xyz_density_args:
+        n_harmonic_functions: 6
+        omega_0: 1.0
+        logspace: true
+        append_input: true
+      harmonic_embedder_xyz_color_args:
+        n_harmonic_functions: 6
+        omega_0: 1.0
+        logspace: true
+        append_input: true
+      harmonic_embedder_dir_color_args:
+        n_harmonic_functions: 6
+        omega_0: 1.0
+        logspace: true
+        append_input: true
+      decoder_density_class_type: MLPDecoder
+      decoder_color_class_type: MLPDecoder
+      use_multiple_streams: true
+      xyz_ray_dir_in_camera_coords: false
+      scaffold_calculating_epochs: []
+      scaffold_resolution:
+      - 128
+      - 128
+      - 128
+      scaffold_empty_space_threshold: 0.001
+      scaffold_occupancy_chunk_size: -1
+      scaffold_max_pool_kernel_size: 3
+      scaffold_filter_points: true
+      volume_cropping_epochs: []
+      voxel_grid_density_args:
+        voxel_grid_class_type: FullResolutionVoxelGrid
+        extents:
+        - 2.0
+        - 2.0
+        - 2.0
+        translation:
+        - 0.0
+        - 0.0
+        - 0.0
+        init_std: 0.1
+        init_mean: 0.0
+        hold_voxel_grid_as_parameters: true
+        param_groups: {}
+        voxel_grid_CPFactorizedVoxelGrid_args:
+          align_corners: true
+          padding: zeros
+          mode: bilinear
+          n_features: 1
+          resolution_changes:
+            0:
+            - 128
+            - 128
+            - 128
+          n_components: 24
+          basis_matrix: true
+        voxel_grid_FullResolutionVoxelGrid_args:
+          align_corners: true
+          padding: zeros
+          mode: bilinear
+          n_features: 1
+          resolution_changes:
+            0:
+            - 128
+            - 128
+            - 128
+        voxel_grid_VMFactorizedVoxelGrid_args:
+          align_corners: true
+          padding: zeros
+          mode: bilinear
+          n_features: 1
+          resolution_changes:
+            0:
+            - 128
+            - 128
+            - 128
+          n_components: null
+          distribution_of_components: null
+          basis_matrix: true
+      voxel_grid_color_args:
+        voxel_grid_class_type: FullResolutionVoxelGrid
+        extents:
+        - 2.0
+        - 2.0
+        - 2.0
+        translation:
+        - 0.0
+        - 0.0
+        - 0.0
+        init_std: 0.1
+        init_mean: 0.0
+        hold_voxel_grid_as_parameters: true
+        param_groups: {}
+        voxel_grid_CPFactorizedVoxelGrid_args:
+          align_corners: true
+          padding: zeros
+          mode: bilinear
+          n_features: 1
+          resolution_changes:
+            0:
+            - 128
+            - 128
+            - 128
+          n_components: 24
+          basis_matrix: true
+        voxel_grid_FullResolutionVoxelGrid_args:
+          align_corners: true
+          padding: zeros
+          mode: bilinear
+          n_features: 1
+          resolution_changes:
+            0:
+            - 128
+            - 128
+            - 128
+        voxel_grid_VMFactorizedVoxelGrid_args:
+          align_corners: true
+          padding: zeros
+          mode: bilinear
+          n_features: 1
+          resolution_changes:
+            0:
+            - 128
+            - 128
+            - 128
+          n_components: null
+          distribution_of_components: null
+          basis_matrix: true
+      decoder_density_ElementwiseDecoder_args:
+        scale: 1.0
+        shift: 0.0
+        operation: IDENTITY
+      decoder_density_MLPDecoder_args:
+        param_groups: {}
+        network_args:
+          n_layers: 8
+          output_dim: 256
+          skip_dim: 39
+          hidden_dim: 256
+          input_skips:
+          - 5
+          skip_affine_trans: false
+          last_layer_bias_init: null
+          last_activation: RELU
+          use_xavier_init: true
+      decoder_color_ElementwiseDecoder_args:
+        scale: 1.0
+        shift: 0.0
+        operation: IDENTITY
+      decoder_color_MLPDecoder_args:
+        param_groups: {}
+        network_args:
+          n_layers: 8
+          output_dim: 256
+          skip_dim: 39
+          hidden_dim: 256
+          input_skips:
+          - 5
+          skip_affine_trans: false
+          last_layer_bias_init: null
+          last_activation: RELU
+          use_xavier_init: true
+    coarse_implicit_function_IdrFeatureField_args:
+      d_in: 3
+      d_out: 1
+      dims:
+      - 512
+      - 512
+      - 512
+      - 512
+      - 512
+      - 512
+      - 512
+      - 512
+      geometric_init: true
+      bias: 1.0
+      skip_in: []
+      weight_norm: true
+      n_harmonic_functions_xyz: 0
+      pooled_feature_dim: 0
+    coarse_implicit_function_NeRFormerImplicitFunction_args:
+      n_harmonic_functions_xyz: 10
+      n_harmonic_functions_dir: 4
+      n_hidden_neurons_dir: 128
+      input_xyz: true
+      xyz_ray_dir_in_camera_coords: false
+      transformer_dim_down_factor: 2.0
+      n_hidden_neurons_xyz: 80
+      n_layers_xyz: 2
+      append_xyz:
+      - 1
+    coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args:
+      n_harmonic_functions_xyz: 10
+      n_harmonic_functions_dir: 4
+      n_hidden_neurons_dir: 128
+      input_xyz: true
+      xyz_ray_dir_in_camera_coords: false
+      transformer_dim_down_factor: 1.0
+      n_hidden_neurons_xyz: 256
+      n_layers_xyz: 8
+      append_xyz:
+      - 5
+    coarse_implicit_function_SRNHyperNetImplicitFunction_args:
+      latent_dim_hypernet: 0
+      hypernet_args:
+        n_harmonic_functions: 3
+        n_hidden_units: 256
+        n_layers: 2
+        n_hidden_units_hypernet: 256
+        n_layers_hypernet: 1
+        in_features: 3
+        out_features: 256
+        xyz_in_camera_coords: false
+      pixel_generator_args:
+        n_harmonic_functions: 4
+        n_hidden_units: 256
+        n_hidden_units_color: 128
+        n_layers: 2
+        in_features: 256
+        out_features: 3
+        ray_dir_in_camera_coords: false
+    coarse_implicit_function_SRNImplicitFunction_args:
+      raymarch_function_args:
+        n_harmonic_functions: 3
+        n_hidden_units: 256
+        n_layers: 2
+        in_features: 3
+        out_features: 256
+        xyz_in_camera_coords: false
+        raymarch_function: null
+      pixel_generator_args:
+        n_harmonic_functions: 4
+        n_hidden_units: 256
+        n_hidden_units_color: 128
+        n_layers: 2
+        in_features: 256
+        out_features: 3
+        ray_dir_in_camera_coords: false
+    coarse_implicit_function_VoxelGridImplicitFunction_args:
+      harmonic_embedder_xyz_density_args:
+        n_harmonic_functions: 6
+        omega_0: 1.0
+        logspace: true
+        append_input: true
+      harmonic_embedder_xyz_color_args:
+        n_harmonic_functions: 6
+        omega_0: 1.0
+        logspace: true
+        append_input: true
+      harmonic_embedder_dir_color_args:
+        n_harmonic_functions: 6
+        omega_0: 1.0
+        logspace: true
+        append_input: true
+      decoder_density_class_type: MLPDecoder
+      decoder_color_class_type: MLPDecoder
+      use_multiple_streams: true
+      xyz_ray_dir_in_camera_coords: false
+      scaffold_calculating_epochs: []
+      scaffold_resolution:
+      - 128
+      - 128
+      - 128
+      scaffold_empty_space_threshold: 0.001
+      scaffold_occupancy_chunk_size: -1
+      scaffold_max_pool_kernel_size: 3
+      scaffold_filter_points: true
+      volume_cropping_epochs: []
+      voxel_grid_density_args:
+        voxel_grid_class_type: FullResolutionVoxelGrid
+        extents:
+        - 2.0
+        - 2.0
+        - 2.0
+        translation:
+        - 0.0
+        - 0.0
+        - 0.0
+        init_std: 0.1
+        init_mean: 0.0
+        hold_voxel_grid_as_parameters: true
+        param_groups: {}
+        voxel_grid_CPFactorizedVoxelGrid_args:
+          align_corners: true
+          padding: zeros
+          mode: bilinear
+          n_features: 1
+          resolution_changes:
+            0:
+            - 128
+            - 128
+            - 128
+          n_components: 24
+          basis_matrix: true
+        voxel_grid_FullResolutionVoxelGrid_args:
+          align_corners: true
+          padding: zeros
+          mode: bilinear
+          n_features: 1
+          resolution_changes:
+            0:
+            - 128
+            - 128
+            - 128
+        voxel_grid_VMFactorizedVoxelGrid_args:
+          align_corners: true
+          padding: zeros
+          mode: bilinear
+          n_features: 1
+          resolution_changes:
+            0:
+            - 128
+            - 128
+            - 128
+          n_components: null
+          distribution_of_components: null
+          basis_matrix: true
+      voxel_grid_color_args:
+        voxel_grid_class_type: FullResolutionVoxelGrid
+        extents:
+        - 2.0
+        - 2.0
+        - 2.0
+        translation:
+        - 0.0
+        - 0.0
+        - 0.0
+        init_std: 0.1
+        init_mean: 0.0
+        hold_voxel_grid_as_parameters: true
+        param_groups: {}
+        voxel_grid_CPFactorizedVoxelGrid_args:
+          align_corners: true
+          padding: zeros
+          mode: bilinear
+          n_features: 1
+          resolution_changes:
+            0:
+            - 128
+            - 128
+            - 128
+          n_components: 24
+          basis_matrix: true
+        voxel_grid_FullResolutionVoxelGrid_args:
+          align_corners: true
+          padding: zeros
+          mode: bilinear
+          n_features: 1
+          resolution_changes:
+            0:
+            - 128
+            - 128
+            - 128
+        voxel_grid_VMFactorizedVoxelGrid_args:
+          align_corners: true
+          padding: zeros
+          mode: bilinear
+          n_features: 1
+          resolution_changes:
+            0:
+            - 128
+            - 128
+            - 128
+          n_components: null
+          distribution_of_components: null
+          basis_matrix: true
+      decoder_density_ElementwiseDecoder_args:
+        scale: 1.0
+        shift: 0.0
+        operation: IDENTITY
+      decoder_density_MLPDecoder_args:
+        param_groups: {}
+        network_args:
+          n_layers: 8
+          output_dim: 256
+          skip_dim: 39
+          hidden_dim: 256
+          input_skips:
+          - 5
+          skip_affine_trans: false
+          last_layer_bias_init: null
+          last_activation: RELU
+          use_xavier_init: true
+      decoder_color_ElementwiseDecoder_args:
+        scale: 1.0
+        shift: 0.0
+        operation: IDENTITY
+      decoder_color_MLPDecoder_args:
+        param_groups: {}
+        network_args:
+          n_layers: 8
+          output_dim: 256
+          skip_dim: 39
+          hidden_dim: 256
+          input_skips:
+          - 5
+          skip_affine_trans: false
+          last_layer_bias_init: null
+          last_activation: RELU
+          use_xavier_init: true
+    view_metrics_ViewMetrics_args: {}
+    regularization_metrics_RegularizationMetrics_args: {}
 optimizer_factory_ImplicitronOptimizerFactory_args:
   betas:
   - 0.9
@@ -141,7 +141,11 @@ class TestExperiment(unittest.TestCase):
         # Check that all the pre-prepared configs are valid.
         config_files = []

-        for pattern in ("repro_singleseq*.yaml", "repro_multiseq*.yaml"):
+        for pattern in (
+            "repro_singleseq*.yaml",
+            "repro_multiseq*.yaml",
+            "overfit_singleseq*.yaml",
+        ):
             config_files.extend(
                 [
                     f
@@ -3,3 +3,8 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+
+# Allows to register the models
+# see: pytorch3d.implicitron.tools.config.registry:register
+from pytorch3d.implicitron.models.generic_model import GenericModel
+from pytorch3d.implicitron.models.overfit_model import OverfitModel
|
@ -8,11 +8,11 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
|
||||||
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
from pytorch3d.implicitron.tools.config import ReplaceableBase
|
||||||
from pytorch3d.renderer.cameras import CamerasBase
|
from pytorch3d.renderer.cameras import CamerasBase
|
||||||
|
|
||||||
from .renderer.base import EvaluationMode
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ImplicitronRender:
|
class ImplicitronRender:
|
||||||
|
@ -9,14 +9,11 @@
|
|||||||
# which are part of implicitron. They ensure that the registry is prepopulated.
|
# which are part of implicitron. They ensure that the registry is prepopulated.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import warnings
|
|
||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import tqdm
|
|
||||||
from omegaconf import DictConfig
|
from omegaconf import DictConfig
|
||||||
from pytorch3d.common.compat import prod
|
|
||||||
|
|
||||||
from pytorch3d.implicitron.models.base_model import (
|
from pytorch3d.implicitron.models.base_model import (
|
||||||
ImplicitronModelBase,
|
ImplicitronModelBase,
|
||||||
@@ -33,11 +30,9 @@ from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
 )
 from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (  # noqa
     NeRFormerImplicitFunction,
-    NeuralRadianceFieldImplicitFunction,
 )
 from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (  # noqa
     SRNHyperNetImplicitFunction,
-    SRNImplicitFunction,
 )
 from pytorch3d.implicitron.models.implicit_function.voxel_grid_implicit_function import (  # noqa
     VoxelGridImplicitFunction,
@@ -63,8 +58,16 @@ from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
 from pytorch3d.implicitron.models.renderer.sdf_renderer import (  # noqa
     SignedDistanceFunctionRenderer,
 )
+
+from pytorch3d.implicitron.models.utils import (
+    apply_chunked,
+    chunk_generator,
+    log_loss_weights,
+    preprocess_input,
+    weighted_sum_losses,
+)
 from pytorch3d.implicitron.models.view_pooler.view_pooler import ViewPooler
-from pytorch3d.implicitron.tools import image_utils, vis_utils
+from pytorch3d.implicitron.tools import vis_utils
 from pytorch3d.implicitron.tools.config import (
     expand_args_fields,
     registry,
@@ -72,7 +75,6 @@ from pytorch3d.implicitron.tools.config import (
 )

 from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle
-from pytorch3d.implicitron.tools.utils import cat_dataclass
 from pytorch3d.renderer import utils as rend_utils
 from pytorch3d.renderer.cameras import CamerasBase

@@ -323,7 +325,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13

         self._implicit_functions = self._construct_implicit_functions()

-        self.log_loss_weights()
+        log_loss_weights(self.loss_weights, logger)

     def forward(
         self,
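The call above targets a helper extracted into `pytorch3d/implicitron/models/utils.py`; here is a sketch reconstructed from the method removed from GenericModel further below (the argument order is assumed from the call site):

```python
import logging
from typing import Dict


def log_loss_weights(loss_weights: Dict[str, float], logger: logging.Logger) -> None:
    # Print a table of the loss weights, one line per loss term.
    loss_weights_message = (
        "-------\nloss_weights:\n"
        + "\n".join(f"{k:40s}: {w:1.2e}" for k, w in loss_weights.items())
        + "-------"
    )
    logger.info(loss_weights_message)
```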
@@ -367,8 +369,14 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
             preds: A dictionary containing all outputs of the forward pass including the
                 rendered images, depths, masks, losses and other metrics.
         """
-        image_rgb, fg_probability, depth_map = self._preprocess_input(
-            image_rgb, fg_probability, depth_map
+        image_rgb, fg_probability, depth_map = preprocess_input(
+            image_rgb,
+            fg_probability,
+            depth_map,
+            self.mask_images,
+            self.mask_depths,
+            self.mask_threshold,
+            self.bg_color,
         )

         # Obtain the batch size from the camera as this is the only required input.
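A sketch of the extracted `preprocess_input` utility, assembled from the `_preprocess_input` method removed further below; the argument order follows the call site above, but the exact signature in `utils.py` may differ (the unbatched-input check and warnings of the original are elided here for brevity):

```python
from typing import Optional, Tuple

import torch
from pytorch3d.implicitron.tools import image_utils


def preprocess_input(
    image_rgb: Optional[torch.Tensor],
    fg_probability: Optional[torch.Tensor],
    depth_map: Optional[torch.Tensor],
    mask_images: bool,
    mask_depths: bool,
    mask_threshold: float,
    bg_color: Tuple[float, float, float],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    # Mask the RGB/depth inputs with the (optionally thresholded) foreground mask.
    fg_mask = fg_probability
    if fg_mask is not None and mask_threshold > 0.0:
        # Threshold the soft masks before applying them.
        fg_mask = (fg_mask >= mask_threshold).type_as(fg_mask)
    if mask_images and fg_mask is not None and image_rgb is not None:
        # Replace the image background with bg_color.
        image_rgb = image_utils.mask_background(
            image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(bg_color)
        )
    if mask_depths and fg_mask is not None and depth_map is not None:
        # Depths should be masked only with thresholded masks.
        depth_map = depth_map * fg_mask
    return image_rgb, fg_mask, depth_map
```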
@@ -453,12 +461,12 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
         for func in self._implicit_functions:
             func.bind_args(**custom_args)

-        chunked_renderer_inputs = {}
+        inputs_to_be_chunked = {}
         if fg_probability is not None and self.renderer.requires_object_mask():
             sampled_fb_prob = rend_utils.ndc_grid_sample(
                 fg_probability[:n_targets], ray_bundle.xys, mode="nearest"
             )
-            chunked_renderer_inputs["object_mask"] = sampled_fb_prob > 0.5
+            inputs_to_be_chunked["object_mask"] = sampled_fb_prob > 0.5

         # (5)-(6) Implicit function evaluation and Rendering
         rendered = self._render(
@@ -466,7 +474,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
             sampling_mode=sampling_mode,
             evaluation_mode=evaluation_mode,
             implicit_functions=self._implicit_functions,
-            chunked_inputs=chunked_renderer_inputs,
+            inputs_to_be_chunked=inputs_to_be_chunked,
         )

         # Unbind the custom arguments to prevent pytorch from storing
@@ -530,30 +538,18 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
                 raise AssertionError("Unreachable state")

         # (7) Compute losses
-        # finally get the optimization objective using self.loss_weights
         objective = self._get_objective(preds)
         if objective is not None:
             preds["objective"] = objective

         return preds

-    def _get_objective(self, preds) -> Optional[torch.Tensor]:
+    def _get_objective(self, preds: Dict[str, torch.Tensor]) -> Optional[torch.Tensor]:
         """
         A helper function to compute the overall loss as the dot product
         of individual loss functions with the corresponding weights.
         """
-        losses_weighted = [
-            preds[k] * float(w)
-            for k, w in self.loss_weights.items()
-            if (k in preds and w != 0.0)
-        ]
-        if len(losses_weighted) == 0:
-            warnings.warn("No main objective found.")
-            return None
-        loss = sum(losses_weighted)
-        assert torch.is_tensor(loss)
-        # pyre-fixme[7]: Expected `Optional[Tensor]` but got `int`.
-        return loss
+        return weighted_sum_losses(preds, self.loss_weights)

     def visualize(
         self,
@@ -585,7 +581,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
         self,
         *,
         ray_bundle: ImplicitronRayBundle,
-        chunked_inputs: Dict[str, torch.Tensor],
+        inputs_to_be_chunked: Dict[str, torch.Tensor],
         sampling_mode: RenderSamplingMode,
         **kwargs,
     ) -> RendererOutput:
@@ -593,7 +589,7 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
         Args:
             ray_bundle: A `ImplicitronRayBundle` object containing the parametrizations of the
                 sampled rendering rays.
-            chunked_inputs: A collection of tensor of shape `(B, _, H, W)`. E.g.
+            inputs_to_be_chunked: A collection of tensor of shape `(B, _, H, W)`. E.g.
                 SignedDistanceFunctionRenderer requires "object_mask", shape
                 (B, 1, H, W), the silhouette of the object in the image. When
                 chunking, they are passed to the renderer as shape
@@ -605,30 +601,27 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
             An instance of RendererOutput
         """
         if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0:
-            return _apply_chunked(
+            return apply_chunked(
                 self.renderer,
-                _chunk_generator(
+                chunk_generator(
                     self.chunk_size_grid,
                     ray_bundle,
-                    chunked_inputs,
+                    inputs_to_be_chunked,
                     self.tqdm_trigger_threshold,
                     **kwargs,
                 ),
-                lambda batch: _tensor_collator(batch, ray_bundle.lengths.shape[:-1]),
+                lambda batch: torch.cat(batch, dim=1).reshape(
+                    *ray_bundle.lengths.shape[:-1], -1
+                ),
             )
         else:
             # pyre-fixme[29]: `BaseRenderer` is not a function.
             return self.renderer(
                 ray_bundle=ray_bundle,
-                **chunked_inputs,
+                **inputs_to_be_chunked,
                 **kwargs,
             )

-    def _get_global_encoder_encoding_dim(self) -> int:
-        if self.global_encoder is None:
-            return 0
-        return self.global_encoder.get_encoding_dim()
-
     def _get_viewpooled_feature_dim(self) -> int:
         if self.view_pooler is None:
             return 0
@@ -720,30 +713,29 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
         function(s) are initialized.
         """
         extra_args = {}
+        global_encoder_dim = (
+            0 if self.global_encoder is None else self.global_encoder.get_encoding_dim()
+        )
+        viewpooled_feature_dim = self._get_viewpooled_feature_dim()
+
         if self.implicit_function_class_type in (
             "NeuralRadianceFieldImplicitFunction",
             "NeRFormerImplicitFunction",
         ):
-            extra_args["latent_dim"] = (
-                self._get_viewpooled_feature_dim()
-                + self._get_global_encoder_encoding_dim()
-            )
+            extra_args["latent_dim"] = viewpooled_feature_dim + global_encoder_dim
             extra_args["color_dim"] = self.render_features_dimensions

         if self.implicit_function_class_type == "IdrFeatureField":
             extra_args["feature_vector_size"] = self.render_features_dimensions
-            extra_args["encoding_dim"] = self._get_global_encoder_encoding_dim()
+            extra_args["encoding_dim"] = global_encoder_dim

         if self.implicit_function_class_type == "SRNImplicitFunction":
-            extra_args["latent_dim"] = (
-                self._get_viewpooled_feature_dim()
-                + self._get_global_encoder_encoding_dim()
-            )
+            extra_args["latent_dim"] = viewpooled_feature_dim + global_encoder_dim

         # srn_hypernet preprocessing
         if self.implicit_function_class_type == "SRNHyperNetImplicitFunction":
-            extra_args["latent_dim"] = self._get_viewpooled_feature_dim()
-            extra_args["latent_dim_hypernet"] = self._get_global_encoder_encoding_dim()
+            extra_args["latent_dim"] = viewpooled_feature_dim
+            extra_args["latent_dim_hypernet"] = global_encoder_dim

         # check that for srn, srn_hypernet, idr we have self.num_passes=1
         implicit_function_type = registry.get(
@@ -770,147 +762,3 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
             for _ in range(self.num_passes)
         ]
         return torch.nn.ModuleList(implicit_functions_list)
-
-    def log_loss_weights(self) -> None:
-        """
-        Print a table of the loss weights.
-        """
-        loss_weights_message = (
-            "-------\nloss_weights:\n"
-            + "\n".join(f"{k:40s}: {w:1.2e}" for k, w in self.loss_weights.items())
-            + "-------"
-        )
-        logger.info(loss_weights_message)
-
-    def _preprocess_input(
-        self,
-        image_rgb: Optional[torch.Tensor],
-        fg_probability: Optional[torch.Tensor],
-        depth_map: Optional[torch.Tensor],
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
-        """
-        Helper function to preprocess the input images and optional depth maps
-        to apply masking if required.
-
-        Args:
-            image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images
-                corresponding to the source viewpoints from which features will be extracted
-            fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch
-                of foreground masks with values in [0, 1].
-            depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
-
-        Returns:
-            Modified image_rgb, fg_mask, depth_map
-        """
-        if image_rgb is not None and image_rgb.ndim == 3:
-            # The FrameData object is used for both frames and batches of frames,
-            # and a user might get this error if those were confused.
-            # Perhaps a user has a FrameData `fd` representing a single frame and
-            # wrote something like `model(**fd)` instead of
-            # `model(**fd.collate([fd]))`.
-            raise ValueError(
-                "Model received unbatched inputs. "
-                + "Perhaps they came from a FrameData which had not been collated."
-            )
-
-        fg_mask = fg_probability
-        if fg_mask is not None and self.mask_threshold > 0.0:
-            # threshold masks
-            warnings.warn("Thresholding masks!")
-            fg_mask = (fg_mask >= self.mask_threshold).type_as(fg_mask)
-
-        if self.mask_images and fg_mask is not None and image_rgb is not None:
-            # mask the image
-            warnings.warn("Masking images!")
-            image_rgb = image_utils.mask_background(
-                image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(self.bg_color)
-            )
-
-        if self.mask_depths and fg_mask is not None and depth_map is not None:
-            # mask the depths
-            assert (
-                self.mask_threshold > 0.0
-            ), "Depths should be masked only with thresholded masks"
-            warnings.warn("Masking depths!")
-            depth_map = depth_map * fg_mask
-
-        return image_rgb, fg_mask, depth_map
-
-
-def _apply_chunked(func, chunk_generator, tensor_collator):
-    """
-    Helper function to apply a function on a sequence of
-    chunked inputs yielded by a generator and collate
-    the result.
-    """
-    processed_chunks = [
-        func(*chunk_args, **chunk_kwargs)
-        for chunk_args, chunk_kwargs in chunk_generator
-    ]
-
-    return cat_dataclass(processed_chunks, tensor_collator)
-
-
-def _tensor_collator(batch, new_dims) -> torch.Tensor:
-    """
-    Helper function to reshape the batch to the desired shape
-    """
-    return torch.cat(batch, dim=1).reshape(*new_dims, -1)
-
-
-def _chunk_generator(
-    chunk_size: int,
-    ray_bundle: ImplicitronRayBundle,
-    chunked_inputs: Dict[str, torch.Tensor],
-    tqdm_trigger_threshold: int,
-    *args,
-    **kwargs,
-):
-    """
-    Helper function which yields chunks of rays from the
-    input ray_bundle, to be used when the number of rays is
-    large and will not fit in memory for rendering.
-    """
-    (
-        batch_size,
-        *spatial_dim,
-        n_pts_per_ray,
-    ) = ray_bundle.lengths.shape  # B x ... x n_pts_per_ray
-    if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
-        raise ValueError(
-            f"chunk_size_grid ({chunk_size}) should be divisible "
-            f"by n_pts_per_ray ({n_pts_per_ray})"
-        )
-
-    n_rays = prod(spatial_dim)
-    # special handling for raytracing-based methods
-    n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
-    chunk_size_in_rays = -(-n_rays // n_chunks)
-
-    iter = range(0, n_rays, chunk_size_in_rays)
-    if len(iter) >= tqdm_trigger_threshold:
-        iter = tqdm.tqdm(iter)
-
-    def _safe_slice(
-        tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
-    ) -> Any:
-        return tensor[start_idx:end_idx] if tensor is not None else None
-
-    for start_idx in iter:
-        end_idx = min(start_idx + chunk_size_in_rays, n_rays)
-        ray_bundle_chunk = ImplicitronRayBundle(
-            origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
-            directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
-                :, start_idx:end_idx
-            ],
-            lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[
-                :, start_idx:end_idx
-            ],
-            xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
-            camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
-            camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
-        )
-        extra_args = kwargs.copy()
-        for k, v in chunked_inputs.items():
-            extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx]
-        yield [ray_bundle_chunk, *args], extra_args
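Since `chunk_generator` and `apply_chunked` are now public in `pytorch3d.implicitron.models.utils`, callers outside GenericModel can reuse the chunked-rendering path. A minimal usage sketch mirroring the `_render` call site above (the function name `render_chunked` is hypothetical, and the positional argument order follows the removed private helpers):

```python
from typing import Dict

import torch
from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
from pytorch3d.implicitron.models.utils import apply_chunked, chunk_generator


def render_chunked(
    renderer,
    ray_bundle: ImplicitronRayBundle,
    inputs_to_be_chunked: Dict[str, torch.Tensor],
    chunk_size_grid: int = 102400,
    tqdm_trigger_threshold: int = 16,
    **kwargs,
):
    # Split the rays into chunks that fit in memory, run the renderer on
    # each chunk, and collate the per-chunk outputs back into one result.
    return apply_chunked(
        renderer,
        chunk_generator(
            chunk_size_grid,
            ray_bundle,
            inputs_to_be_chunked,
            tqdm_trigger_threshold,
            **kwargs,
        ),
        lambda batch: torch.cat(batch, dim=1).reshape(
            *ray_bundle.lengths.shape[:-1], -1
        ),
    )
```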
pytorch3d/implicitron/models/overfit_model.py (new file, 639 lines)
@@ -0,0 +1,639 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Note: The #noqa comments below are for unused imports of pluggable implementations
# which are part of implicitron. They ensure that the registry is prepopulated.

import functools
import logging
from dataclasses import field
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from omegaconf import DictConfig

from pytorch3d.implicitron.models.base_model import (
    ImplicitronModelBase,
    ImplicitronRender,
)
from pytorch3d.implicitron.models.global_encoder.global_encoder import GlobalEncoderBase
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
from pytorch3d.implicitron.models.metrics import (
    RegularizationMetricsBase,
    ViewMetricsBase,
)

from pytorch3d.implicitron.models.renderer.base import (
    BaseRenderer,
    EvaluationMode,
    ImplicitronRayBundle,
    RendererOutput,
    RenderSamplingMode,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import RaySamplerBase
from pytorch3d.implicitron.models.utils import (
    apply_chunked,
    chunk_generator,
    log_loss_weights,
    preprocess_input,
    weighted_sum_losses,
)
from pytorch3d.implicitron.tools import vis_utils
from pytorch3d.implicitron.tools.config import (
    expand_args_fields,
    registry,
    run_auto_creation,
)

from pytorch3d.implicitron.tools.rasterize_mc import rasterize_sparse_ray_bundle
from pytorch3d.renderer import utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase


if TYPE_CHECKING:
    from visdom import Visdom

logger = logging.getLogger(__name__)

IMPLICIT_FUNCTION_ARGS_TO_REMOVE: List[str] = [
    "feature_vector_size",
    "encoding_dim",
    "latent_dim",
    "color_dim",
]


@registry.register
class OverfitModel(ImplicitronModelBase):  # pyre-ignore: 13
    """
    OverfitModel is a wrapper for the neural implicit
    rendering and reconstruction pipeline, which consists
    of the following sequence of 4 steps:

        (1) Ray Sampling
        ------------------
        Rays are sampled from an image grid based on the target view(s).
                │
                ▼
        (2) Implicit Function Evaluation
        ------------------
        Evaluate the implicit function(s) at the sampled ray points
        (also optionally pass in a global encoding from global_encoder).
                │
                ▼
        (3) Rendering
        ------------------
        Render the image into the target cameras by raymarching along
        the sampled rays and aggregating the colors and densities
        output by the implicit function in (2).
                │
                ▼
        (4) Loss Computation
        ------------------
        Compute losses based on the predicted target image(s).


    The `forward` function of OverfitModel executes
    this sequence of steps. Currently, steps 1, 2, 3
    can be customized by initializing a subclass of the appropriate
    base class and adding the newly created module to the registry.
    Please see https://github.com/facebookresearch/pytorch3d/blob/main/projects/implicitron_trainer/README.md#custom-plugins
    for more details on how to create and register a custom component.

    In the config .yaml files for experiments, the parameters below are
    contained in the
    `model_factory_ImplicitronModelFactory_args.model_OverfitModel_args`
    node. As OverfitModel derives from ReplaceableBase, the input arguments are
    parsed by the run_auto_creation function to initialize the
    necessary member modules. Please see implicitron_trainer/README.md
    for more details on this process.

    Args:
        mask_images: Whether or not to mask the RGB image background given the
            foreground mask (the `fg_probability` argument of `GenericModel.forward`)
        mask_depths: Whether or not to mask the depth image background given the
            foreground mask (the `fg_probability` argument of `GenericModel.forward`)
        render_image_width: Width of the output image to render
        render_image_height: Height of the output image to render
        mask_threshold: If greater than 0.0, the foreground mask is
            thresholded by this value before being applied to the RGB/Depth images
        output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
            splatting onto an image grid. Default: False.
        bg_color: RGB values for setting the background color of input image
            if mask_images=True. Defaults to (0.0, 0.0, 0.0). Each renderer has its own
            way to determine the background color of its output, unrelated to this.
        chunk_size_grid: The total number of points which can be rendered
            per chunk. This is used to compute the number of rays used
            per chunk when the chunked version of the renderer is used (in order
            to fit rendering on all rays in memory)
        render_features_dimensions: The number of output features to render.
            Defaults to 3, corresponding to RGB images.
        sampling_mode_training: The sampling method to use during training. Must be
            a value from the RenderSamplingMode Enum.
        sampling_mode_evaluation: Same as above but for evaluation.
        global_encoder_class_type: The name of the class to use for global_encoder,
            which must be available in the registry. Or `None` to disable the global encoder.
        global_encoder: An instance of `GlobalEncoder`. This is used to generate an encoding
            of the image (referred to as the global_code) that can be used to model aspects of
            the scene such as multiple objects or morphing objects. It is up to the implicit
            function definition how to use it, but the most typical way is to broadcast and
            concatenate to the other inputs for the implicit function.
        raysampler_class_type: The name of the raysampler class which is available
            in the global registry.
        raysampler: An instance of RaySampler which is used to emit
            rays from the target view(s).
        renderer_class_type: The name of the renderer class which is available in the global
            registry.
        renderer: A renderer class which inherits from BaseRenderer. This is used to
            generate the images from the target view(s).
        share_implicit_function_across_passes: If set to True,
            coarse_implicit_function is automatically set to implicit_function
            (coarse_implicit_function = implicit_function). The shared
            implicit function is then run sequentially during the rendering passes.
        implicit_function_class_type: The type of implicit function to use which
            is available in the global registry.
        implicit_function: An instance of ImplicitFunctionBase.
        coarse_implicit_function_class_type: The type of implicit function to use which
            is available in the global registry.
        coarse_implicit_function: An instance of ImplicitFunctionBase.
            If set while `share_implicit_function_across_passes` is False,
            coarse_implicit_function is instantiated as a separate module. It
            is then used as the second pass during the rendering.
            If set to None, we only do a single pass with implicit_function.
        view_metrics: An instance of ViewMetricsBase used to compute loss terms which
            are independent of the model's parameters.
        view_metrics_class_type: The type of view metrics to use, must be available in
            the global registry.
        regularization_metrics: An instance of RegularizationMetricsBase used to compute
            regularization terms which can depend on the model's parameters.
        regularization_metrics_class_type: The type of regularization metrics to use,
            must be available in the global registry.
        loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
            for `ViewMetrics` class for available loss functions.
        log_vars: A list of variable names which should be logged.
            The names should correspond to a subset of the keys of the
            dict `preds` output by the `forward` function.
    """  # noqa: B950

    mask_images: bool = True
    mask_depths: bool = True
    render_image_width: int = 400
    render_image_height: int = 400
    mask_threshold: float = 0.5
    output_rasterized_mc: bool = False
    bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
    chunk_size_grid: int = 4096
    render_features_dimensions: int = 3
    tqdm_trigger_threshold: int = 16

    n_train_target_views: int = 1
    sampling_mode_training: str = "mask_sample"
    sampling_mode_evaluation: str = "full_grid"

    # ---- global encoder settings
    global_encoder_class_type: Optional[str] = None
    global_encoder: Optional[GlobalEncoderBase]

    # ---- raysampler
    raysampler_class_type: str = "AdaptiveRaySampler"
    raysampler: RaySamplerBase

    # ---- renderer configs
    renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
    renderer: BaseRenderer

    # ---- implicit function settings
    share_implicit_function_across_passes: bool = False
    implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
    implicit_function: ImplicitFunctionBase
    coarse_implicit_function_class_type: Optional[str] = None
    coarse_implicit_function: Optional[ImplicitFunctionBase]

    # ----- metrics
    view_metrics: ViewMetricsBase
    view_metrics_class_type: str = "ViewMetrics"

    regularization_metrics: RegularizationMetricsBase
    regularization_metrics_class_type: str = "RegularizationMetrics"

    # ---- loss weights
    loss_weights: Dict[str, float] = field(
        default_factory=lambda: {
            "loss_rgb_mse": 1.0,
            "loss_prev_stage_rgb_mse": 1.0,
            "loss_mask_bce": 0.0,
            "loss_prev_stage_mask_bce": 0.0,
        }
    )

    # ---- variables to be logged (logger automatically ignores if not computed)
    log_vars: List[str] = field(
        default_factory=lambda: [
            "loss_rgb_psnr_fg",
            "loss_rgb_psnr",
            "loss_rgb_mse",
            "loss_rgb_huber",
            "loss_depth_abs",
            "loss_depth_abs_fg",
            "loss_mask_neg_iou",
            "loss_mask_bce",
            "loss_mask_beta_prior",
            "loss_eikonal",
            "loss_density_tv",
            "loss_depth_neg_penalty",
            "loss_autodecoder_norm",
            # metrics that are only logged in 2+-stage renderers
            "loss_prev_stage_rgb_mse",
            "loss_prev_stage_rgb_psnr_fg",
            "loss_prev_stage_rgb_psnr",
            "loss_prev_stage_mask_bce",
            # basic metrics
            "objective",
            "epoch",
            "sec/it",
        ]
    )

    def __post_init__(self):
        # The attributes will be filled by run_auto_creation
        run_auto_creation(self)
        log_loss_weights(self.loss_weights, logger)
        # We need to set it here since run_auto_creation
        # will create coarse_implicit_function before implicit_function
        if self.share_implicit_function_across_passes:
            self.coarse_implicit_function = self.implicit_function

    def forward(
        self,
        *,  # force keyword-only arguments
        image_rgb: Optional[torch.Tensor],
        camera: CamerasBase,
        fg_probability: Optional[torch.Tensor] = None,
        mask_crop: Optional[torch.Tensor] = None,
        depth_map: Optional[torch.Tensor] = None,
        sequence_name: Optional[List[str]] = None,
        frame_timestamp: Optional[torch.Tensor] = None,
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Args:
            image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images;
                the first `min(B, n_train_target_views)` images are considered targets and
                are used to supervise the renders; the rest correspond to the source
                viewpoints from which features will be extracted.
            camera: An instance of CamerasBase containing a batch of `B` cameras corresponding
                to the viewpoints of target images, from which the rays will be sampled,
                and source images, which will be used for intersecting with target rays.
            fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
                foreground masks.
            mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
                regions in the input images (i.e. regions that do not correspond
                to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
                "mask_sample", rays will be sampled in the non-zero regions.
            depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
            sequence_name: A list of `B` strings corresponding to the sequence names
                from which images `image_rgb` were extracted. They are used to match
                target frames with relevant source frames.
            frame_timestamp: Optionally a tensor of shape `(B,)` containing a batch
                of frame timestamps.
            evaluation_mode: one of EvaluationMode.TRAINING or
                EvaluationMode.EVALUATION which determines the settings used for
                rendering.

        Returns:
            preds: A dictionary containing all outputs of the forward pass including the
                rendered images, depths, masks, losses and other metrics.
        """
        image_rgb, fg_probability, depth_map = preprocess_input(
            image_rgb,
            fg_probability,
            depth_map,
            self.mask_images,
            self.mask_depths,
            self.mask_threshold,
            self.bg_color,
        )

        # Determine the used ray sampling mode.
        sampling_mode = RenderSamplingMode(
            self.sampling_mode_training
            if evaluation_mode == EvaluationMode.TRAINING
            else self.sampling_mode_evaluation
        )

        # (1) Sample rendering rays with the ray sampler.
        # pyre-ignore[29]
        ray_bundle: ImplicitronRayBundle = self.raysampler(
            camera,
            evaluation_mode,
            mask=mask_crop
            if mask_crop is not None and sampling_mode == RenderSamplingMode.MASK_SAMPLE
            else None,
        )

        inputs_to_be_chunked = {}
        if fg_probability is not None and self.renderer.requires_object_mask():
            sampled_fb_prob = rend_utils.ndc_grid_sample(
                fg_probability, ray_bundle.xys, mode="nearest"
            )
            inputs_to_be_chunked["object_mask"] = sampled_fb_prob > 0.5

        # (2)-(3) Implicit function evaluation and Rendering
        implicit_functions: List[Union[Callable, ImplicitFunctionBase]] = [
            self.implicit_function
        ]
        if self.coarse_implicit_function is not None:
            implicit_functions += [self.coarse_implicit_function]

        if self.global_encoder is not None:
            global_code = self.global_encoder(  # pyre-fixme[29]
                sequence_name=sequence_name,
                frame_timestamp=frame_timestamp,
            )
            implicit_functions = [
                functools.partial(implicit_function, global_code=global_code)
                if isinstance(implicit_function, Callable)
                else functools.partial(
                    implicit_function.forward, global_code=global_code
                )
                for implicit_function in implicit_functions
            ]
        rendered = self._render(
            ray_bundle=ray_bundle,
            sampling_mode=sampling_mode,
            evaluation_mode=evaluation_mode,
            implicit_functions=implicit_functions,
            inputs_to_be_chunked=inputs_to_be_chunked,
        )

        # A dict to store losses as well as rendering results.
        preds: Dict[str, Any] = self.view_metrics(
            results={},
            raymarched=rendered,
            ray_bundle=ray_bundle,
            image_rgb=image_rgb,
            depth_map=depth_map,
            fg_probability=fg_probability,
            mask_crop=mask_crop,
        )

        preds.update(
            self.regularization_metrics(
                results=preds,
                model=self,
            )
        )

        if sampling_mode == RenderSamplingMode.MASK_SAMPLE:
            if self.output_rasterized_mc:
                # Visualize the monte-carlo pixel renders by splatting onto
                # an image grid.
                (
                    preds["images_render"],
                    preds["depths_render"],
                    preds["masks_render"],
                ) = rasterize_sparse_ray_bundle(
                    ray_bundle,
                    rendered.features,
                    (self.render_image_height, self.render_image_width),
                    rendered.depths,
                    masks=rendered.masks,
                )
        elif sampling_mode == RenderSamplingMode.FULL_GRID:
            preds["images_render"] = rendered.features.permute(0, 3, 1, 2)
            preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2)
            preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2)

            preds["implicitron_render"] = ImplicitronRender(
                image_render=preds["images_render"],
                depth_render=preds["depths_render"],
                mask_render=preds["masks_render"],
            )
        else:
            raise AssertionError("Unreachable state")

        # (4) Compute losses
        # finally get the optimization objective using self.loss_weights
        objective = self._get_objective(preds)
        if objective is not None:
            preds["objective"] = objective

        return preds

    def _get_objective(self, preds: Dict[str, torch.Tensor]) -> Optional[torch.Tensor]:
        """
        A helper function to compute the overall loss as the dot product
        of individual loss functions with the corresponding weights.
        """
        return weighted_sum_losses(preds, self.loss_weights)

    def visualize(
        self,
        viz: Optional["Visdom"],
        visdom_env_imgs: str,
        preds: Dict[str, Any],
        prefix: str,
    ) -> None:
        """
        Helper function to visualize the predictions generated
        in the forward pass.

        Args:
            viz: Visdom connection object
            visdom_env_imgs: name of visdom environment for the images.
            preds: predictions dict like returned by forward()
            prefix: prepended to the names of images
        """
        if viz is None or not viz.check_connection():
            logger.info("no visdom server! -> skipping batch vis")
            return

        idx_image = 0
        title = f"{prefix}_im{idx_image}"

        vis_utils.visualize_basics(viz, preds, visdom_env_imgs, title=title)

    def _render(
        self,
        *,
        ray_bundle: ImplicitronRayBundle,
        inputs_to_be_chunked: Dict[str, torch.Tensor],
        sampling_mode: RenderSamplingMode,
        **kwargs,
    ) -> RendererOutput:
        """
        Args:
            ray_bundle: An `ImplicitronRayBundle` object containing the parametrizations of the
                sampled rendering rays.
            inputs_to_be_chunked: A collection of tensors of shape `(B, _, H, W)`. E.g.
                SignedDistanceFunctionRenderer requires "object_mask", shape
                (B, 1, H, W), the silhouette of the object in the image. When
                chunking, they are passed to the renderer as shape
                `(B, _, chunksize)`.
            sampling_mode: The sampling method to use. Must be a value from the
                RenderSamplingMode Enum.

        Returns:
            An instance of RendererOutput
        """
        if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0:
            return apply_chunked(
                self.renderer,
                chunk_generator(
                    self.chunk_size_grid,
                    ray_bundle,
                    inputs_to_be_chunked,
                    self.tqdm_trigger_threshold,
                    **kwargs,
                ),
                lambda batch: torch.cat(batch, dim=1).reshape(
                    *ray_bundle.lengths.shape[:-1], -1
                ),
            )
        else:
            # pyre-fixme[29]: `BaseRenderer` is not a function.
            return self.renderer(
                ray_bundle=ray_bundle,
                **inputs_to_be_chunked,
                **kwargs,
            )

    @classmethod
    def raysampler_tweak_args(cls, type, args: DictConfig) -> None:
        """
        We don't expose certain fields of the raysampler because we want to set
        them from our own members.
        """
        del args["sampling_mode_training"]
        del args["sampling_mode_evaluation"]
        del args["image_width"]
        del args["image_height"]

    def create_raysampler(self):
        extra_args = {
            "sampling_mode_training": self.sampling_mode_training,
            "sampling_mode_evaluation": self.sampling_mode_evaluation,
            "image_width": self.render_image_width,
            "image_height": self.render_image_height,
        }
        raysampler_args = getattr(
            self, "raysampler_" + self.raysampler_class_type + "_args"
        )
        self.raysampler = registry.get(RaySamplerBase, self.raysampler_class_type)(
            **raysampler_args, **extra_args
        )

    @classmethod
    def renderer_tweak_args(cls, type, args: DictConfig) -> None:
        """
        We don't expose certain fields of the renderer because we want to set
        them based on other inputs.
        """
        args.pop("render_features_dimensions", None)
        args.pop("object_bounding_sphere", None)

    def create_renderer(self):
        extra_args = {}

        if self.renderer_class_type == "SignedDistanceFunctionRenderer":
            extra_args["render_features_dimensions"] = self.render_features_dimensions
            if not hasattr(self.raysampler, "scene_extent"):
                raise ValueError(
                    "SignedDistanceFunctionRenderer requires"
                    + " a raysampler that defines the 'scene_extent' field"
                    + " (this field is supported by, e.g., the adaptive raysampler - "
                    + " self.raysampler_class_type='AdaptiveRaySampler')."
                )
            extra_args["object_bounding_sphere"] = self.raysampler.scene_extent

        renderer_args = getattr(self, "renderer_" + self.renderer_class_type + "_args")
        self.renderer = registry.get(BaseRenderer, self.renderer_class_type)(
            **renderer_args, **extra_args
        )

    @classmethod
    def implicit_function_tweak_args(cls, type, args: DictConfig) -> None:
        """
        We don't expose certain implicit_function fields because we want to set
        them based on other inputs.
        """
        for arg in IMPLICIT_FUNCTION_ARGS_TO_REMOVE:
            args.pop(arg, None)

    @classmethod
    def coarse_implicit_function_tweak_args(cls, type, args: DictConfig) -> None:
        """
        We don't expose certain implicit_function fields because we want to set
        them based on other inputs.
        """
        for arg in IMPLICIT_FUNCTION_ARGS_TO_REMOVE:
            args.pop(arg, None)

    def _create_extra_args_for_implicit_function(self) -> Dict[str, Any]:
        extra_args = {}
        global_encoder_dim = (
            0 if self.global_encoder is None else self.global_encoder.get_encoding_dim()
        )
        if self.implicit_function_class_type in (
            "NeuralRadianceFieldImplicitFunction",
            "NeRFormerImplicitFunction",
        ):
            extra_args["latent_dim"] = global_encoder_dim
            extra_args["color_dim"] = self.render_features_dimensions

        if self.implicit_function_class_type == "IdrFeatureField":
            extra_args["feature_work_size"] = global_encoder_dim
            extra_args["feature_vector_size"] = self.render_features_dimensions

        if self.implicit_function_class_type == "SRNImplicitFunction":
            extra_args["latent_dim"] = global_encoder_dim
        return extra_args

    def create_implicit_function(self) -> None:
        implicit_function_type = registry.get(
            ImplicitFunctionBase, self.implicit_function_class_type
        )
        expand_args_fields(implicit_function_type)

        config_name = f"implicit_function_{self.implicit_function_class_type}_args"
        config = getattr(self, config_name, None)
        if config is None:
            raise ValueError(f"{config_name} not present")

        extra_args = self._create_extra_args_for_implicit_function()
        self.implicit_function = implicit_function_type(**config, **extra_args)

    def create_coarse_implicit_function(self) -> None:
        # If coarse_implicit_function_class_type has been defined
        # then we init a module based on its arguments
        if (
            self.coarse_implicit_function_class_type is not None
            and not self.share_implicit_function_across_passes
        ):
            config_name = "coarse_implicit_function_{0}_args".format(
                self.coarse_implicit_function_class_type
            )
            config = getattr(self, config_name, {})

            implicit_function_type = registry.get(
                ImplicitFunctionBase,
                # pyre-ignore: the enclosing `if` guarantees this is not None.
                self.coarse_implicit_function_class_type,
            )
            expand_args_fields(implicit_function_type)

            extra_args = self._create_extra_args_for_implicit_function()
            self.coarse_implicit_function = implicit_function_type(
                **config, **extra_args
            )
        elif self.share_implicit_function_across_passes:
            # Since coarse_implicit_function is initialised before
            # implicit_function, we handle this case in __post_init__.
            pass
        else:
            self.coarse_implicit_function = None
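To make the new class concrete, here is a minimal usage sketch (not part of the diff), assuming a CPU run with the default components and random inputs shaped like those in the unit tests further below; `expand_args_fields` must run before direct instantiation:

```
import torch

from pytorch3d.implicitron.models.overfit_model import OverfitModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras

expand_args_fields(OverfitModel)
model = OverfitModel(render_image_height=80, render_image_width=80)

R, T = look_at_view_transform(azim=torch.rand(1) * 360)
preds = model(
    image_rgb=torch.rand(1, 3, 80, 80),
    camera=PerspectiveCameras(R=R, T=T),
    fg_probability=torch.randint(high=2, size=(1, 1, 80, 80)).float(),
    mask_crop=torch.ones(1, 1, 80, 80),
    sequence_name=["sequence"],
    evaluation_mode=EvaluationMode.TRAINING,
)
# `objective` is the weighted sum of the configured losses (see loss_weights).
preds["objective"].backward()
```

Run with `EvaluationMode.EVALUATION` instead, the same call renders the full image grid and `preds["images_render"]` has shape `(1, 3, 80, 80)`, as the unit test below checks.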
195
pytorch3d/implicitron/models/utils.py
Normal file
@@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# Note: The #noqa comments below are for unused imports of pluggable implementations
# which are part of implicitron. They ensure that the registry is prepopulated.

import warnings
from logging import Logger
from typing import Any, Dict, Optional, Tuple

import torch
import tqdm
from pytorch3d.common.compat import prod

from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle

from pytorch3d.implicitron.tools import image_utils

from pytorch3d.implicitron.tools.utils import cat_dataclass


def preprocess_input(
    image_rgb: Optional[torch.Tensor],
    fg_probability: Optional[torch.Tensor],
    depth_map: Optional[torch.Tensor],
    mask_images: bool,
    mask_depths: bool,
    mask_threshold: float,
    bg_color: Tuple[float, float, float],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Helper function to preprocess the input images and optional depth maps
    to apply masking if required.

    Args:
        image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images
            corresponding to the source viewpoints from which features will be extracted
        fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch
            of foreground masks with values in [0, 1].
        depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
        mask_images: Whether or not to mask the RGB image background given the
            foreground mask (the `fg_probability` argument of `GenericModel.forward`)
        mask_depths: Whether or not to mask the depth image background given the
            foreground mask (the `fg_probability` argument of `GenericModel.forward`)
        mask_threshold: If greater than 0.0, the foreground mask is
            thresholded by this value before being applied to the RGB/Depth images
        bg_color: RGB values for setting the background color of input image
            if mask_images=True. Defaults to (0.0, 0.0, 0.0). Each renderer has its own
            way to determine the background color of its output, unrelated to this.

    Returns:
        Modified image_rgb, fg_mask, depth_map
    """
    if image_rgb is not None and image_rgb.ndim == 3:
        # The FrameData object is used for both frames and batches of frames,
        # and a user might get this error if those were confused.
        # Perhaps a user has a FrameData `fd` representing a single frame and
        # wrote something like `model(**fd)` instead of
        # `model(**fd.collate([fd]))`.
        raise ValueError(
            "Model received unbatched inputs. "
            + "Perhaps they came from a FrameData which had not been collated."
        )

    fg_mask = fg_probability
    if fg_mask is not None and mask_threshold > 0.0:
        # threshold masks
        warnings.warn("Thresholding masks!")
        fg_mask = (fg_mask >= mask_threshold).type_as(fg_mask)

    if mask_images and fg_mask is not None and image_rgb is not None:
        # mask the image
        warnings.warn("Masking images!")
        image_rgb = image_utils.mask_background(
            image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(bg_color)
        )

    if mask_depths and fg_mask is not None and depth_map is not None:
        # mask the depths
        assert (
            mask_threshold > 0.0
        ), "Depths should be masked only with thresholded masks"
        warnings.warn("Masking depths!")
        depth_map = depth_map * fg_mask

    return image_rgb, fg_mask, depth_map


def log_loss_weights(loss_weights: Dict[str, float], logger: Logger) -> None:
    """
    Print a table of the loss weights.
    """
    loss_weights_message = (
        "-------\nloss_weights:\n"
        + "\n".join(f"{k:40s}: {w:1.2e}" for k, w in loss_weights.items())
        + "-------"
    )
    logger.info(loss_weights_message)


def weighted_sum_losses(
    preds: Dict[str, torch.Tensor], loss_weights: Dict[str, float]
) -> Optional[torch.Tensor]:
    """
    A helper function to compute the overall loss as the dot product
    of individual loss functions with the corresponding weights.
    """
    losses_weighted = [
        preds[k] * float(w)
        for k, w in loss_weights.items()
        if (k in preds and w != 0.0)
    ]
    if len(losses_weighted) == 0:
        warnings.warn("No main objective found.")
        return None
    loss = sum(losses_weighted)
    assert torch.is_tensor(loss)
    # pyre-fixme[7]: Expected `Optional[Tensor]` but got `int`.
    return loss


def apply_chunked(func, chunk_generator, tensor_collator):
    """
    Helper function to apply a function on a sequence of
    chunked inputs yielded by a generator and collate
    the result.
    """
    processed_chunks = [
        func(*chunk_args, **chunk_kwargs)
        for chunk_args, chunk_kwargs in chunk_generator
    ]

    return cat_dataclass(processed_chunks, tensor_collator)


def chunk_generator(
    chunk_size: int,
    ray_bundle: ImplicitronRayBundle,
    chunked_inputs: Dict[str, torch.Tensor],
    tqdm_trigger_threshold: int,
    *args,
    **kwargs,
):
    """
    Helper function which yields chunks of rays from the
    input ray_bundle, to be used when the number of rays is
    large and will not fit in memory for rendering.
    """
    (
        batch_size,
        *spatial_dim,
        n_pts_per_ray,
    ) = ray_bundle.lengths.shape  # B x ... x n_pts_per_ray
    if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
        raise ValueError(
            f"chunk_size_grid ({chunk_size}) should be divisible "
            f"by n_pts_per_ray ({n_pts_per_ray})"
        )

    n_rays = prod(spatial_dim)
    # special handling for raytracing-based methods
    n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
    chunk_size_in_rays = -(-n_rays // n_chunks)

    iter = range(0, n_rays, chunk_size_in_rays)
    if len(iter) >= tqdm_trigger_threshold:
        iter = tqdm.tqdm(iter)

    def _safe_slice(
        tensor: Optional[torch.Tensor], start_idx: int, end_idx: int
    ) -> Any:
        return tensor[start_idx:end_idx] if tensor is not None else None

    for start_idx in iter:
        end_idx = min(start_idx + chunk_size_in_rays, n_rays)
        ray_bundle_chunk = ImplicitronRayBundle(
            origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
            directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
                :, start_idx:end_idx
            ],
            lengths=ray_bundle.lengths.reshape(batch_size, n_rays, n_pts_per_ray)[
                :, start_idx:end_idx
            ],
            xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
            camera_ids=_safe_slice(ray_bundle.camera_ids, start_idx, end_idx),
            camera_counts=_safe_slice(ray_bundle.camera_counts, start_idx, end_idx),
        )
        extra_args = kwargs.copy()
        for k, v in chunked_inputs.items():
            extra_args[k] = v.flatten(2)[:, :, start_idx:end_idx]
        yield [ray_bundle_chunk, *args], extra_args
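`chunk_generator` and `apply_chunked` are designed to compose, which is exactly how `OverfitModel._render` uses them for full-grid rendering. A condensed sketch of that pattern (not part of the diff; `renderer`, `ray_bundle`, and `implicit_functions` are assumed to exist and follow the `BaseRenderer` calling convention):

```
import torch

from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.models.utils import apply_chunked, chunk_generator

output = apply_chunked(
    renderer,
    chunk_generator(
        4096,        # chunk_size: total points per chunk (cf. chunk_size_grid)
        ray_bundle,  # an ImplicitronRayBundle covering the full image grid
        {},          # chunked_inputs, e.g. {"object_mask": ...} for SDF renderers
        16,          # tqdm_trigger_threshold: show a progress bar for >= 16 chunks
        implicit_functions=implicit_functions,
        evaluation_mode=EvaluationMode.EVALUATION,
    ),
    # collate per-chunk outputs back to the spatial shape of the ray grid
    lambda batch: torch.cat(batch, dim=1).reshape(*ray_bundle.lengths.shape[:-1], -1),
)
```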
@@ -7,8 +7,9 @@
 import math
 from typing import Optional, Tuple
 
+import pytorch3d
+
 import torch
-from pytorch3d.implicitron.models.renderer.base import ImplicitronRayBundle
 from pytorch3d.ops import packed_to_padded
 from pytorch3d.renderer import PerspectiveCameras
 from pytorch3d.structures import Pointclouds
@@ -18,7 +19,7 @@ from .point_cloud_utils import render_point_cloud_pytorch3d
 
 @torch.no_grad()
 def rasterize_sparse_ray_bundle(
-    ray_bundle: ImplicitronRayBundle,
+    ray_bundle: "pytorch3d.implicitron.models.renderer.base.ImplicitronRayBundle",
     features: torch.Tensor,
     image_size_hw: Tuple[int, int],
     depth: torch.Tensor,
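The hunk above replaces a module-level import of `ImplicitronRayBundle` with a fully qualified string annotation. String (forward-reference) annotations are not evaluated at runtime, so the module no longer needs `renderer.base` at import time; a plausible motivation is avoiding a circular import now that the model modules themselves import `rasterize_sparse_ray_bundle`. A generic sketch of the pattern (module and type names hypothetical):

```
# light.py -- hypothetical module illustrating the pattern: the annotation
# below is a string, so heavy_pkg.heavy is never imported here at runtime
# and an import cycle heavy -> light -> heavy cannot form.
import heavy_pkg


def consume(obj: "heavy_pkg.heavy.HeavyType") -> None:
    ...
```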
5
tests/implicitron/models/__init__.py
Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
217
tests/implicitron/models/test_overfit_model.py
Normal file
@@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Any, Dict
from unittest.mock import patch

import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.overfit_model import OverfitModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras

DEVICE = torch.device("cuda:0")


def _generate_fake_inputs(N: int, H: int, W: int) -> Dict[str, Any]:
    R, T = look_at_view_transform(azim=torch.rand(N) * 360)
    return {
        "camera": PerspectiveCameras(R=R, T=T, device=DEVICE),
        "fg_probability": torch.randint(
            high=2, size=(N, 1, H, W), device=DEVICE
        ).float(),
        "depth_map": torch.rand((N, 1, H, W), device=DEVICE) + 0.1,
        "mask_crop": torch.randint(high=2, size=(N, 1, H, W), device=DEVICE).float(),
        "sequence_name": ["sequence"] * N,
        "image_rgb": torch.rand((N, 1, H, W), device=DEVICE),
    }


def mock_safe_multinomial(input: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Return deterministic indices to mock safe_multinomial.

    Args:
        input: tensor of shape [B, n] containing non-negative values;
            rows are interpreted as unnormalized event probabilities
            in categorical distributions.
        num_samples: number of samples to take.

    Returns:
        Tensor of shape [B, num_samples]
    """
    batch_size = input.shape[0]
    return torch.arange(num_samples).repeat(batch_size, 1).to(DEVICE)


class TestOverfitModel(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(42)

    def test_overfit_model_vs_generic_model_with_batch_size_one(self):
        """In this test we compare OverfitModel to GenericModel behavior.

        We use a NeRF setup (2 rendering passes).

        OverfitModel is a specific case of GenericModel. Hence, with the same inputs,
        they should provide the exact same results.
        """
        expand_args_fields(OverfitModel)
        expand_args_fields(GenericModel)
        batch_size, image_height, image_width = 1, 80, 80
        assert batch_size == 1
        overfit_model = OverfitModel(
            render_image_height=image_height,
            render_image_width=image_width,
            coarse_implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
            # To make the outputs of the two models comparable, we remove
            # randomization by deactivating stratified_point_sampling_training.
            raysampler_AdaptiveRaySampler_args={
                "stratified_point_sampling_training": False
            },
            global_encoder_class_type="SequenceAutodecoder",
            global_encoder_SequenceAutodecoder_args={
                "autodecoder_args": {
                    "n_instances": 1000,
                    "init_scale": 1.0,
                    "encoding_dim": 64,
                }
            },
        )
        generic_model = GenericModel(
            render_image_height=image_height,
            render_image_width=image_width,
            n_train_target_views=batch_size,
            num_passes=2,
            # To make the outputs of the two models comparable, we remove
            # randomization by deactivating stratified_point_sampling_training.
            raysampler_AdaptiveRaySampler_args={
                "stratified_point_sampling_training": False
            },
            global_encoder_class_type="SequenceAutodecoder",
            global_encoder_SequenceAutodecoder_args={
                "autodecoder_args": {
                    "n_instances": 1000,
                    "init_scale": 1.0,
                    "encoding_dim": 64,
                }
            },
        )

        # Check that they have the same number of parameters
        num_params_mvm = sum(p.numel() for p in overfit_model.parameters())
        num_params_gm = sum(p.numel() for p in generic_model.parameters())
        self.assertEqual(num_params_mvm, num_params_gm)

        # Adapt the parameter names from GenericModel to OverfitModel
        mapping_om_from_gm = {
            key.replace("_implicit_functions.0._fn", "implicit_function").replace(
                "_implicit_functions.1._fn", "coarse_implicit_function"
            ): val
            for key, val in generic_model.state_dict().items()
        }
        # Copy parameters from generic_model to overfit_model
        overfit_model.load_state_dict(mapping_om_from_gm)

        overfit_model.to(DEVICE)
        generic_model.to(DEVICE)
        inputs_ = _generate_fake_inputs(batch_size, image_height, image_width)

        # training forward pass
        overfit_model.train()
        generic_model.train()

        with patch(
            "pytorch3d.renderer.implicit.raysampling._safe_multinomial",
            side_effect=mock_safe_multinomial,
        ):
            train_preds_om = overfit_model(
                **inputs_,
                evaluation_mode=EvaluationMode.TRAINING,
            )
            train_preds_gm = generic_model(
                **inputs_,
                evaluation_mode=EvaluationMode.TRAINING,
            )

        self.assertTrue(len(train_preds_om) == len(train_preds_gm))

        self.assertTrue(train_preds_om["objective"].isfinite().item())
        # With randomization removed and identical weights,
        # the objectives should match exactly.
        self.assertTrue(
            torch.allclose(train_preds_om["objective"], train_preds_gm["objective"])
        )

        # Test that evaluation works
        overfit_model.eval()
        generic_model.eval()
        with torch.no_grad():
            eval_preds_om = overfit_model(
                **inputs_,
                evaluation_mode=EvaluationMode.EVALUATION,
            )
            eval_preds_gm = generic_model(
                **inputs_,
                evaluation_mode=EvaluationMode.EVALUATION,
            )

        self.assertEqual(
            eval_preds_om["images_render"].shape,
            (batch_size, 3, image_height, image_width),
        )
        self.assertTrue(
            torch.allclose(eval_preds_om["objective"], eval_preds_gm["objective"])
        )
        self.assertTrue(
            torch.allclose(
                eval_preds_om["images_render"], eval_preds_gm["images_render"]
            )
        )

    def test_overfit_model_check_share_weights(self):
        model = OverfitModel(share_implicit_function_across_passes=True)
        for p1, p2 in zip(
            model.implicit_function.parameters(),
            model.coarse_implicit_function.parameters(),
        ):
            self.assertEqual(id(p1), id(p2))

        model.to(DEVICE)
        inputs_ = _generate_fake_inputs(2, 80, 80)
        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)

    def test_overfit_model_check_no_share_weights(self):
        model = OverfitModel(
            share_implicit_function_across_passes=False,
            coarse_implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
            coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args={
                "transformer_dim_down_factor": 1.0,
                "n_hidden_neurons_xyz": 256,
                "n_layers_xyz": 8,
                "append_xyz": (5,),
            },
        )
        for p1, p2 in zip(
            model.implicit_function.parameters(),
            model.coarse_implicit_function.parameters(),
        ):
            self.assertNotEqual(id(p1), id(p2))

        model.to(DEVICE)
        inputs_ = _generate_fake_inputs(2, 80, 80)
        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)

    def test_overfit_model_coarse_implicit_function_is_none(self):
        model = OverfitModel(
            share_implicit_function_across_passes=False,
            coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args=None,
        )
        self.assertIsNone(model.coarse_implicit_function)
        model.to(DEVICE)
        inputs_ = _generate_fake_inputs(2, 80, 80)
        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)
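The key-renaming dictionary comprehension in the test above is the general trick for transferring weights between modules whose state dicts differ only in naming. A standalone sketch of the same idea (helper name hypothetical, not part of the diff):

```
from typing import Dict

import torch


def remap_state_dict(
    state_dict: Dict[str, torch.Tensor], renames: Dict[str, str]
) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: apply substring renames to every state_dict key."""
    remapped = {}
    for key, value in state_dict.items():
        for old, new in renames.items():
            key = key.replace(old, new)
        remapped[key] = value
    return remapped


# Usage mirroring the test above:
# overfit_model.load_state_dict(
#     remap_state_dict(
#         generic_model.state_dict(),
#         {
#             "_implicit_functions.0._fn": "implicit_function",
#             "_implicit_functions.1._fn": "coarse_implicit_function",
#         },
#     )
# )
```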
66
tests/implicitron/models/test_utils.py
Normal file
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import torch

from pytorch3d.implicitron.models.utils import preprocess_input, weighted_sum_losses


class TestUtils(unittest.TestCase):
    def test_prepare_inputs_wrong_num_dim(self):
        img = torch.randn(3, 3, 3)
        with self.assertRaises(ValueError) as context:
            img, fg_prob, depth_map = preprocess_input(
                img, None, None, True, True, 0.5, (0.0, 0.0, 0.0)
            )
        self.assertEqual(
            "Model received unbatched inputs. "
            + "Perhaps they came from a FrameData which had not been collated.",
            str(context.exception),
        )

    def test_prepare_inputs_mask_image_true(self):
        batch, channels, height, width = 2, 3, 10, 10
        img = torch.ones(batch, channels, height, width)
        # Create a mask on the lower triangular matrix
        fg_prob = torch.tril(torch.ones(batch, 1, height, width)) * 0.3

        out_img, out_fg_prob, out_depth_map = preprocess_input(
            img, fg_prob, None, True, False, 0.3, (0.0, 0.0, 0.0)
        )

        self.assertTrue(torch.equal(out_img, torch.tril(img)))
        self.assertTrue(torch.equal(out_fg_prob, (fg_prob >= 0.3).float()))
        self.assertIsNone(out_depth_map)

    def test_prepare_inputs_mask_depth_true(self):
        batch, channels, height, width = 2, 3, 10, 10
        img = torch.ones(batch, channels, height, width)
        depth_map = torch.randn(batch, channels, height, width)
        # Create a mask on the lower triangular matrix
        fg_prob = torch.tril(torch.ones(batch, 1, height, width)) * 0.3

        out_img, out_fg_prob, out_depth_map = preprocess_input(
            img, fg_prob, depth_map, False, True, 0.3, (0.0, 0.0, 0.0)
        )

        self.assertTrue(torch.equal(out_img, img))
        self.assertTrue(torch.equal(out_fg_prob, (fg_prob >= 0.3).float()))
        self.assertTrue(torch.equal(out_depth_map, torch.tril(depth_map)))

    def test_weighted_sum_losses(self):
        preds = {"a": torch.tensor(2), "b": torch.tensor(2)}
        weights = {"a": 2.0, "b": 0.0}
        loss = weighted_sum_losses(preds, weights)
        self.assertEqual(loss, 4.0)

    def test_weighted_sum_losses_raise_warning(self):
        preds = {"a": torch.tensor(2), "b": torch.tensor(2)}
        weights = {"c": 2.0, "d": 2.0}
        self.assertIsNone(weighted_sum_losses(preds, weights))