implicitron v0 (#1133)

Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
Jeremy Reizenstein 2022-03-21 20:20:10 +00:00 committed by GitHub
parent 0e377c6850
commit cdd2142dd5
90 changed files with 17075 additions and 0 deletions

View File

@ -0,0 +1,276 @@
# Introduction
Implicitron is a PyTorch3D-based framework for new-view synthesis via neural-network-based representations.
# License
Implicitron is distributed as part of PyTorch3D under the [BSD license](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE).
It includes code from the [SRN](http://github.com/vsitzmann/scene-representation-networks) and [IDR](http://github.com/lioryariv/idr) repositories.
See [LICENSE-3RD-PARTY](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE-3RD-PARTY) for their licenses.
# Installation
There are three ways to set up Implicitron, depending on the flexibility level required.
If you only want to train or evaluate models as they are implemented, changing only the parameters, you can simply install the package.
Implicitron also provides a flexible API that supports user-defined plug-ins;
if you want to re-implement some of the components without changing the high-level pipeline, you need to create a custom launcher script.
The most flexible option, though, is to clone the PyTorch3D repo and build it from source, which allows changing the code in arbitrary ways.
Below, we describe all three options in more detail.
## [Option 1] Running an executable from the package
This option allows you to use the code as is without changing the implementations.
Only configuration can be changed (see [Configuration system](#configuration-system)).
For this setup, install the dependencies and PyTorch3D from conda following [the guide](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#1-install-with-cuda-support-from-anaconda-cloud-on-linux-only). Then, install implicitron-specific dependencies:
```shell
pip install "hydra-core>=1.1" visdom lpips matplotlib
```
The runner executable is available as the `pytorch3d_implicitron_runner` shell command.
See the [Running](#running) section below for examples of training and evaluation commands.
## [Option 2] Supporting custom implementations
To plug in custom implementations, for example of the renderer or implicit-function protocols, you need to create your own runner script and import the plug-in implementations there.
First, install PyTorch3D and Implicitron dependencies as described in the previous section.
Then, implement the custom script; a good starting point is to copy `pytorch3d/projects/implicitron_trainer/experiment.py`.
See [Custom plugins](#custom-plugins) for more information on how to import implementations and enable them in the configs.
## [Option 3] Cloning PyTorch3D repo
This is the most flexible way to set up Implicitron as it allows changing the code directly.
It allows modifying the high-level rendering pipeline or implementing yet-unsupported loss functions.
Please follow the instructions to [install PyTorch3D from a local clone](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-from-a-local-clone).
Then, install Implicitron-specific dependencies:
```shell
pip install "hydra-core>=1.1" visdom lpips matplotlib
```
You are still encouraged to implement custom plugins as above where possible, as this makes the code easier to reuse.
The executable is located in `pytorch3d/projects/implicitron_trainer`.
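For example, assuming the same dataset and checkpoint placeholders as in the [Running](#running) section below, a training run from a local clone could look roughly like this (a sketch based on the `experiment.py` entry point; adjust paths to your checkout):
```shell
cd pytorch3d/projects/implicitron_trainer
./experiment.py --config-name repro_singleseq_nerf dataset_args.dataset_root=<DATASET_ROOT> dataset_args.category='skateboard' exp_dir=<CHECKPOINT_DIR>
```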
# Running
This section assumes that you use the executable provided by the installed package.
If you have a custom `experiment.py` script (as in Option 2 above), replace the executable with the path to your script.
## Training
To run training, pass a yaml config file, followed by a list of overridden arguments.
For example, to train NeRF on the first skateboard sequence from the CO3D dataset, you can run:
```shell
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir=<CHECKPOINT_DIR>
```
Here, `--config-path` points to the config folder relative to the `pytorch3d_implicitron_runner` location;
`--config-name` picks the config file (in this case, `repro_singleseq_nerf.yaml`);
`test_when_finished` will launch the evaluation script once training is finished.
Replace `<DATASET_ROOT>` with the location where the dataset in Implicitron format is stored
and `<CHECKPOINT_DIR>` with a directory where checkpoints will be dumped during training.
Other configuration parameters can be overridden in the same way.
See [Configuration system](#configuration-system) section for more information on this.
## Evaluation
To run evaluation on the latest checkpoint after (or during) training, simply add `eval_only=True` to your training command.
For example, to execute the evaluation for the NeRF skateboard experiment, you can run:
```shell
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root=<CO3D_DATASET_ROOT> dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 exp_dir=<CHECKPOINT_DIR> eval_only=True
```
Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`.
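For reference, the dumped metrics can be loaded back with a few lines of Python; the file name below (`results_test.json`) reflects what this version of the trainer writes and may change:
```python
import json
import os

exp_dir = "<CHECKPOINT_DIR>"  # the same experiment directory as above
with open(os.path.join(exp_dir, "results_test.json")) as f:
    results = json.load(f)
print(results)
```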
## Visualisation
The visualisation script produces a video of renders by a trained model along a pre-defined camera trajectory.
In order for it to work, `ffmpeg` needs to be installed:
```shell
conda install ffmpeg
```
Here is an example of calling the script:
```shell
projects/implicitron_trainer/visualize_reconstruction.py exp_dir=<CHECKPOINT_DIR> visdom_show_preds=True n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
```
The argument `n_eval_cameras` sets the number of rendering viewpoints sampled on a trajectory, which defaults to a circular fly-around;
`render_size` sets the size of the render passed to the model, which can be resized to `video_size` before writing.
Rendered videos of images, masks, and depth maps will be saved to `<CHECKPOINT_DIR>/vis`.
# Configuration system
We use Hydra and OmegaConf to parse the configs.
The config schema and default values are defined by the dataclasses implementing the modules.
More specifically, if a class derives from `Configurable`, its fields can be set in config yaml files or overridden in CLI.
For example, `GenericModel` has a field `render_image_width` with the default value 400.
If it is specified in the yaml config file or as a CLI override, the new value will be used.
Configurables can form hierarchies.
For example, `GenericModel` has a field `raysampler: RaySampler`, which is also Configurable.
In the config, inner parameters are propagated using the `_args` postfix; e.g. to change `raysampler.n_pts_per_ray_training` (the number of points sampled per ray), the node `raysampler_args.n_pts_per_ray_training` should be set.
The root of the hierarchy is defined by `ExperimentConfig` dataclass.
It has top-level fields such as `eval_only`, which we used above to run evaluation via a CLI override.
Additionally, it has non-leaf nodes like `generic_model_args`, which dispatches the config parameters to `GenericModel`. Thus, changing the model parameters may be achieved in two ways: either by editing the config file, e.g.
```yaml
generic_model_args:
  render_image_width: 800
  raysampler_args:
    n_pts_per_ray_training: 128
```
or, equivalently, by adding the following to `pytorch3d_implicitron_runner` arguments:
```shell
generic_model_args.render_image_width=800 generic_model_args.raysampler_args.n_pts_per_ray_training=128
```
See the documentation in `pytorch3d/implicitron/tools/config.py` for more details.
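To make this concrete, here is a minimal, hypothetical sketch of a `Configurable` hierarchy (the `Toy*` classes are illustrative only and not part of the library):
```python
from pytorch3d.implicitron.tools.config import (
    Configurable,
    get_default_args,
    run_auto_creation,
)


class ToyRaySampler(Configurable):  # illustrative, not part of the library
    n_pts_per_ray_training: int = 64


class ToyModel(Configurable):  # illustrative, not part of the library
    render_image_width: int = 400
    raysampler: ToyRaySampler

    def __post_init__(self):
        # builds self.raysampler from the raysampler_args node
        run_auto_creation(self)


# get_default_args processes the classes and returns their default config,
# including the nested raysampler_args node.
cfg = get_default_args(ToyModel)
cfg.raysampler_args.n_pts_per_ray_training = 128
model = ToyModel(**cfg)
assert model.raysampler.n_pts_per_ray_training == 128
```
This mirrors how `GenericModel`'s `raysampler` member becomes the `raysampler_args` node used in the example above.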
## Replaceable implementations
Sometimes changing the model parameters does not provide enough flexibility, and you want to provide a new implementation for a building block.
The configuration system also supports it!
Abstract classes like `BaseRenderer` derive from `ReplaceableBase` instead of `Configurable`.
This means that other Configurables can refer to them using the base type, while the specific implementation is chosen in the config using a `_class_type`-postfixed node.
In that case, the name of the `_args` node has to include the implementation type.
More specifically, to change the renderer settings, the config will look like this:
```yaml
generic_model_args:
  renderer_class_type: LSTMRenderer
  renderer_LSTMRenderer_args:
    num_raymarch_steps: 10
    hidden_size: 16
```
See the documentation in `pytorch3d/implicitron/tools/config.py` for more details on the configuration system.
## Custom plugins
If you have an idea for another implementation of a replaceable component, it can be plugged in without changing the core code.
For that, you need to set up Implicitron through option 2 or 3 above.
Let's say you want to implement a renderer that accumulates opacities similar to an X-ray machine.
First, create a module `x_ray_renderer.py` with a class deriving from `BaseRenderer`:
```python
import torch

from pytorch3d.implicitron.tools.config import registry
# NOTE: adjust these imports to where the base types live in your version of the code.
from pytorch3d.implicitron.models.base import EvaluationMode
from pytorch3d.implicitron.models.renderer.base import BaseRenderer, RendererOutput


@registry.register
class XRayRenderer(BaseRenderer, torch.nn.Module):
    n_pts_per_ray: int = 64

    # if there are other base classes, make sure to call `super().__init__()` explicitly
    def __post_init__(self):
        super().__init__()
        # custom initialization

    def forward(
        self,
        ray_bundle,
        implicit_functions=[],
        evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
        **kwargs,
    ) -> RendererOutput:
        ...
```
Please note the `@registry.register` decorator, which registers the plug-in as an implementation of `BaseRenderer`.
IMPORTANT: In order for it to run, the class (or its enclosing module) has to be imported in your launch script. Additionally, this has to be done before parsing the root configuration class `ExperimentConfig`.
Simply add `import x_ray_renderer` (making sure the module is importable from your launch script) at the beginning of `experiment.py`.
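For reference, a minimal custom launcher might look like the following sketch (file and module names are illustrative; it assumes your plug-in module and your copy of `experiment.py` sit in the same folder):
```python
#!/usr/bin/env python
# my_experiment.py -- hypothetical custom launcher.
# Importing the plug-in module runs its @registry.register decorators,
# so XRayRenderer is known before Hydra parses ExperimentConfig.
import x_ray_renderer  # noqa: F401

from experiment import experiment  # the stock Implicitron trainer entry point

if __name__ == "__main__":
    experiment()
```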
After that, you should be able to change the config with:
```yaml
generic_model_args:
  renderer_class_type: XRayRenderer
  renderer_XRayRenderer_args:
    n_pts_per_ray: 128
```
to replace the implementation and potentially override the parameters.
# Code and config structure
As described above, the config structure is parsed automatically from the module hierarchy.
In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `dataset_args` node.
Here is the class structure (single-line edges show aggregation, while double lines show available implementations):
```
generic_model_args: GenericModel
└-- sequence_autodecoder_args: Autodecoder
└-- raysampler_args: RaySampler
└-- renderer_*_args: BaseRenderer
╘== MultiPassEmissionAbsorptionRenderer
╘== LSTMRenderer
╘== SignedDistanceFunctionRenderer
└-- ray_tracer_args: RayTracing
└-- ray_normal_coloring_network_args: RayNormalColoringNetwork
└-- implicit_function_*_args: ImplicitFunctionBase
╘== NeuralRadianceFieldImplicitFunction
╘== SRNImplicitFunction
└-- raymarch_function_args: SRNRaymarchFunction
└-- pixel_generator_args: SRNPixelGenerator
╘== SRNHyperNetImplicitFunction
└-- hypernet_args: SRNRaymarchHyperNet
└-- pixel_generator_args: SRNPixelGenerator
╘== IdrFeatureField
└-- image_feature_extractor_args: ResNetFeatureExtractor
└-- view_sampler_args: ViewSampler
└-- feature_aggregator_*_args: FeatureAggregatorBase
╘== IdentityFeatureAggregator
╘== AngleWeightedIdentityFeatureAggregator
╘== AngleWeightedReductionFeatureAggregator
╘== ReductionFeatureAggregator
solver_args: init_optimizer
dataset_args: dataset_zoo
dataloader_args: dataloader_zoo
```
Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
# Reproducing CO3D experiments
Common Objects in 3D (CO3D) is a large-scale dataset of videos of rigid objects grouped into 50 common categories.
Implicitron provides implementations and config files to reproduce the results from [the paper](https://arxiv.org/abs/2109.00512).
Please follow [the link](https://github.com/facebookresearch/co3d#automatic-batch-download) for the instructions to download the dataset.
In training and evaluation scripts, use the download location as `<DATASET_ROOT>`.
It is also possible to define the environment variable `CO3D_DATASET_ROOT` instead of specifying `dataset_args.dataset_root` explicitly.
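For example, in a bash shell (the base configs read this variable via `${oc.env:CO3D_DATASET_ROOT}`):
```shell
export CO3D_DATASET_ROOT=<DATASET_ROOT>
```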
To reproduce the experiments from the paper, use the following configs. For single-sequence experiments:
| Method | config file |
|-----------------|-------------------------------------|
| NeRF | repro_singleseq_nerf.yaml |
| NeRF + WCE | repro_singleseq_nerf_wce.yaml |
| NerFormer | repro_singleseq_nerformer.yaml |
| IDR | repro_singleseq_idr.yaml |
| SRN | repro_singleseq_srn_noharm.yaml |
| SRN + γ | repro_singleseq_srn.yaml |
| SRN + WCE | repro_singleseq_srn_wce_noharm.yaml |
| SRN + WCE + γ   | repro_singleseq_srn_wce.yaml        |
For multi-sequence experiments (without generalisation to new sequences):
| Method | config file |
|-----------------|--------------------------------------------|
| NeRF + AD | repro_multiseq_nerf_ad.yaml |
| SRN + AD | repro_multiseq_srn_ad_hypernet_noharm.yaml |
| SRN + γ + AD | repro_multiseq_srn_ad_hypernet.yaml |
For multi-sequence experiments (with generalisation to new sequences):
| Method | config file |
|-----------------|--------------------------------------|
| NeRF + WCE | repro_multiseq_nerf_wce.yaml |
| NerFormer | repro_multiseq_nerformer.yaml |
| SRN + WCE | repro_multiseq_srn_wce_noharm.yaml |
| SRN + WCE + γ | repro_multiseq_srn_wce.yaml |
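For example, a multi-sequence NerFormer run could be launched analogously to the single-sequence command above (placeholders as before; the category override is shown for illustration):
```shell
pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_multiseq_nerformer dataset_args.dataset_root=<DATASET_ROOT> dataset_args.category='skateboard' exp_dir=<CHECKPOINT_DIR>
```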

View File

@ -0,0 +1,83 @@
defaults:
- default_config
- _self_
exp_dir: ./data/exps/base/
architecture: generic
visualize_interval: 0
visdom_port: 8097
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
dataset_args:
dataset_root: ${oc.env:CO3D_DATASET_ROOT}
load_point_clouds: false
mask_depths: false
mask_images: false
n_frames_per_sequence: -1
test_on_train: true
test_restrict_sequence_id: 0
generic_model_args:
loss_weights:
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 1.0
loss_autodecoder_norm: 0.01
loss_rgb_mse: 1.0
loss_prev_stage_rgb_mse: 1.0
output_rasterized_mc: false
chunk_size_grid: 102400
render_image_height: 400
render_image_width: 400
num_passes: 2
implicit_function_NeuralRadianceFieldImplicitFunction_args:
n_harmonic_functions_xyz: 10
n_harmonic_functions_dir: 4
n_hidden_neurons_xyz: 256
n_hidden_neurons_dir: 128
n_layers_xyz: 8
append_xyz:
- 5
latent_dim: 0
raysampler_args:
n_rays_per_image_sampled_from_mask: 1024
min_depth: 0.0
max_depth: 0.0
scene_extent: 8.0
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 64
n_pts_per_ray_fine_evaluation: 64
append_coarse_samples_to_fine: true
density_noise_std_train: 1.0
view_sampler_args:
masked_sampling: false
image_feature_extractor_args:
stages:
- 1
- 2
- 3
- 4
proj_dim: 16
image_rescale: 0.32
first_max_pool: false
solver_args:
breed: adam
lr: 0.0005
lr_policy: multistep
max_epochs: 2000
momentum: 0.9
weight_decay: 0.0

View File

@ -0,0 +1,16 @@
generic_model_args:
image_feature_extractor_args:
add_images: true
add_masks: true
first_max_pool: true
image_rescale: 0.375
l2_norm: true
name: resnet34
normalize_image: true
pretrained: true
stages:
- 1
- 2
- 3
- 4
proj_dim: 32

View File

@ -0,0 +1,16 @@
generic_model_args:
image_feature_extractor_args:
add_images: true
add_masks: true
first_max_pool: false
image_rescale: 0.375
l2_norm: true
name: resnet34
normalize_image: true
pretrained: true
stages:
- 1
- 2
- 3
- 4
proj_dim: 16

View File

@ -0,0 +1,16 @@
generic_model_args:
image_feature_extractor_args:
stages:
- 1
- 2
- 3
first_max_pool: false
proj_dim: -1
l2_norm: false
image_rescale: 0.375
name: resnet34
normalize_image: true
pretrained: true
feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
reduction_functions:
- AVG

View File

@ -0,0 +1,31 @@
defaults:
- repro_base.yaml
- _self_
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
dataset_args:
assert_single_seq: false
dataset_name: co3d_multisequence
load_point_clouds: false
mask_depths: false
mask_images: false
n_frames_per_sequence: -1
test_on_train: true
test_restrict_sequence_id: 0
solver_args:
max_epochs: 3000
milestones:
- 1000

View File

@ -0,0 +1,64 @@
defaults:
- repro_multiseq_base.yaml
- _self_
generic_model_args:
loss_weights:
loss_mask_bce: 100.0
loss_kl: 0.0
loss_rgb_mse: 1.0
loss_eikonal: 0.1
chunk_size_grid: 65536
num_passes: 1
output_rasterized_mc: true
sampling_mode_training: mask_sample
view_pool: false
sequence_autodecoder_args:
n_instances: 20000
init_scale: 1.0
encoding_dim: 256
implicit_function_IdrFeatureField_args:
n_harmonic_functions_xyz: 6
bias: 0.6
d_in: 3
d_out: 1
dims:
- 512
- 512
- 512
- 512
- 512
- 512
- 512
- 512
geometric_init: true
pooled_feature_dim: 0
skip_in:
- 6
weight_norm: true
renderer_SignedDistanceFunctionRenderer_args:
ray_tracer_args:
line_search_step: 0.5
line_step_iters: 3
n_secant_steps: 8
n_steps: 100
object_bounding_sphere: 8.0
sdf_threshold: 5.0e-05
ray_normal_coloring_network_args:
d_in: 9
d_out: 3
dims:
- 512
- 512
- 512
- 512
mode: idr
n_harmonic_functions_dir: 4
pooled_feature_dim: 0
weight_norm: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 1024
n_pts_per_ray_training: 0
n_pts_per_ray_evaluation: 0
scene_extent: 8.0
renderer_class_type: SignedDistanceFunctionRenderer
implicit_function_class_type: IdrFeatureField

View File

@ -0,0 +1,9 @@
defaults:
- repro_multiseq_base.yaml
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: false
sequence_autodecoder_args:
n_instances: 20000
encoding_dim: 256

View File

@ -0,0 +1,10 @@
defaults:
- repro_multiseq_base.yaml
- repro_feat_extractor_unnormed.yaml
- _self_
clip_grad: 1.0
generic_model_args:
chunk_size_grid: 16000
view_pool: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 850

View File

@ -0,0 +1,16 @@
defaults:
- repro_multiseq_base.yaml
- repro_feat_extractor_transformer.yaml
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32
n_pts_per_ray_evaluation: 32
renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16
implicit_function_class_type: NeRFormerImplicitFunction
feature_aggregator_class_type: IdentityFeatureAggregator

View File

@ -0,0 +1,16 @@
defaults:
- repro_multiseq_base.yaml
- repro_feat_extractor_transformer.yaml
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32
n_pts_per_ray_evaluation: 32
renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16
implicit_function_class_type: NeRFormerImplicitFunction
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator

View File

@ -0,0 +1,32 @@
defaults:
- repro_multiseq_base.yaml
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: false
n_train_target_views: -1
num_passes: 1
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.001
depth_neg_penalty: 10000.0
sequence_autodecoder_args:
encoding_dim: 256
n_instances: 20000
raysampler_args:
n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05
max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNHyperNetImplicitFunction
solver_args:
breed: adam
lr: 5.0e-05

View File

@ -0,0 +1,10 @@
defaults:
- repro_multiseq_srn_ad_hypernet.yaml
- _self_
generic_model_args:
num_passes: 1
implicit_function_SRNHyperNetImplicitFunction_args:
pixel_generator_args:
n_harmonic_functions: 0
hypernet_args:
n_harmonic_functions: 0

View File

@ -0,0 +1,30 @@
defaults:
- repro_multiseq_base.yaml
- repro_feat_extractor_normed.yaml
- _self_
generic_model_args:
chunk_size_grid: 32000
view_pool: true
num_passes: 1
n_train_target_views: -1
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.0
depth_neg_penalty: 10000.0
raysampler_args:
n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05
max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNImplicitFunction
solver_args:
breed: adam
lr: 5.0e-05

View File

@ -0,0 +1,10 @@
defaults:
- repro_multiseq_srn_wce.yaml
- _self_
generic_model_args:
num_passes: 1
implicit_function_SRNImplicitFunction_args:
pixel_generator_args:
n_harmonic_functions: 0
raymarch_function_args:
n_harmonic_functions: 0

View File

@ -0,0 +1,41 @@
defaults:
- repro_base
- _self_
dataloader_args:
batch_size: 1
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
dataset_args:
dataset_name: co3d_singlesequence
assert_single_seq: true
n_frames_per_sequence: -1
test_restrict_sequence_id: 0
test_on_train: false
generic_model_args:
render_image_height: 800
render_image_width: 800
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_eikonal
- loss_prev_stage_rgb_psnr
- loss_mask_bce
- loss_prev_stage_mask_bce
- loss_rgb_mse
- loss_prev_stage_rgb_mse
- loss_depth_abs
- loss_depth_abs_fg
- loss_kl
- loss_mask_neg_iou
- objective
- epoch
- sec/it
solver_args:
lr: 0.0005
max_epochs: 400
milestones:
- 200
- 300

View File

@ -0,0 +1,57 @@
defaults:
- repro_singleseq_base
- _self_
generic_model_args:
loss_weights:
loss_mask_bce: 100.0
loss_kl: 0.0
loss_rgb_mse: 1.0
loss_eikonal: 0.1
chunk_size_grid: 65536
num_passes: 1
view_pool: false
implicit_function_IdrFeatureField_args:
n_harmonic_functions_xyz: 6
bias: 0.6
d_in: 3
d_out: 1
dims:
- 512
- 512
- 512
- 512
- 512
- 512
- 512
- 512
geometric_init: true
pooled_feature_dim: 0
skip_in:
- 6
weight_norm: true
renderer_SignedDistanceFunctionRenderer_args:
ray_tracer_args:
line_search_step: 0.5
line_step_iters: 3
n_secant_steps: 8
n_steps: 100
object_bounding_sphere: 8.0
sdf_threshold: 5.0e-05
ray_normal_coloring_network_args:
d_in: 9
d_out: 3
dims:
- 512
- 512
- 512
- 512
mode: idr
n_harmonic_functions_dir: 4
pooled_feature_dim: 0
weight_norm: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 1024
n_pts_per_ray_training: 0
n_pts_per_ray_evaluation: 0
renderer_class_type: SignedDistanceFunctionRenderer
implicit_function_class_type: IdrFeatureField

View File

@ -0,0 +1,4 @@
defaults:
- repro_singleseq_base
- _self_
exp_dir: ./data/nerf_single_apple/

View File

@ -0,0 +1,9 @@
defaults:
- repro_singleseq_wce_base.yaml
- repro_feat_extractor_unnormed.yaml
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
raysampler_args:
n_rays_per_image_sampled_from_mask: 850

View File

@ -0,0 +1,16 @@
defaults:
- repro_singleseq_wce_base.yaml
- repro_feat_extractor_transformer.yaml
- _self_
generic_model_args:
chunk_size_grid: 16000
view_pool: true
implicit_function_class_type: NeRFormerImplicitFunction
raysampler_args:
n_rays_per_image_sampled_from_mask: 800
n_pts_per_ray_training: 32
n_pts_per_ray_evaluation: 32
renderer_MultiPassEmissionAbsorptionRenderer_args:
n_pts_per_ray_fine_training: 16
n_pts_per_ray_fine_evaluation: 16
feature_aggregator_class_type: IdentityFeatureAggregator

View File

@ -0,0 +1,28 @@
defaults:
- repro_singleseq_base.yaml
- _self_
generic_model_args:
num_passes: 1
chunk_size_grid: 32000
view_pool: false
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.0
depth_neg_penalty: 10000.0
raysampler_args:
n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05
max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNImplicitFunction
solver_args:
breed: adam
lr: 5.0e-05

View File

@ -0,0 +1,10 @@
defaults:
- repro_singleseq_srn.yaml
- _self_
generic_model_args:
num_passes: 1
implicit_function_SRNImplicitFunction_args:
pixel_generator_args:
n_harmonic_functions: 0
raymarch_function_args:
n_harmonic_functions: 0

View File

@ -0,0 +1,29 @@
defaults:
- repro_singleseq_wce_base
- repro_feat_extractor_normed.yaml
- _self_
generic_model_args:
num_passes: 1
chunk_size_grid: 32000
view_pool: true
loss_weights:
loss_rgb_mse: 200.0
loss_prev_stage_rgb_mse: 0.0
loss_mask_bce: 1.0
loss_prev_stage_mask_bce: 0.0
loss_autodecoder_norm: 0.0
depth_neg_penalty: 10000.0
raysampler_args:
n_rays_per_image_sampled_from_mask: 2048
min_depth: 0.05
max_depth: 0.05
scene_extent: 0.0
n_pts_per_ray_training: 1
n_pts_per_ray_evaluation: 1
stratified_point_sampling_training: false
stratified_point_sampling_evaluation: false
renderer_class_type: LSTMRenderer
implicit_function_class_type: SRNImplicitFunction
solver_args:
breed: adam
lr: 5.0e-05

View File

@ -0,0 +1,10 @@
defaults:
- repro_singleseq_srn_wce.yaml
- _self_
generic_model_args:
num_passes: 1
implicit_function_SRNImplicitFunction_args:
pixel_generator_args:
n_harmonic_functions: 0
raymarch_function_args:
n_harmonic_functions: 0

View File

@ -0,0 +1,18 @@
defaults:
- repro_singleseq_base
- _self_
dataloader_args:
batch_size: 10
dataset_len: 1000
dataset_len_val: 1
num_workers: 8
images_per_seq_options:
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10

View File

@ -0,0 +1,714 @@
#!/usr/bin/env python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
""""
This file is the entry point for launching experiments with Implicitron.
Main functions
---------------
- `run_training` is the wrapper for the train, val, test loops
and checkpointing
- `trainvalidate` is the inner loop which runs the model forward/backward
pass, visualizations and metric printing
Launch Training
---------------
Experiment config .yaml files are located in the
`projects/implicitron_trainer/configs` folder. To launch
an experiment, specify the name of the file. Specific config values can
also be overridden from the command line, for example:
```
./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84
```
To run an experiment on a specific GPU, specify the `gpu_idx` key
in the config file / CLI. To run on a different device, specify the
device in `run_training`.
Outputs
--------
The outputs of the experiment are saved and logged in multiple ways:
- Checkpoints:
Model, optimizer and stats are stored in the directory
named by the `exp_dir` key from the config file / CLI parameters.
- Stats
Stats are logged and plotted to the file "train_stats.pdf" in the
same directory. The stats are also saved as part of the checkpoint file.
- Visualizations
Predictions are plotted to a visdom server running at the
port specified by the `visdom_server` and `visdom_port` keys in the
config file.
"""
import copy
import json
import logging
import os
import random
import time
import warnings
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple
import hydra
import lpips
import numpy as np
import torch
import tqdm
from omegaconf import DictConfig, OmegaConf
from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
ImplicitronDataset,
FrameData,
)
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
get_default_args_field,
remove_unused_components,
)
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase
logger = logging.getLogger(__name__)
if version.parse(hydra.__version__) < version.Version("1.1"):
raise ValueError(
f"Hydra version {hydra.__version__} is too old."
" (Implicitron requires version 1.1 or later.)"
)
try:
# only makes sense in FAIR cluster
import pytorch3d.implicitron.fair_cluster.slurm # noqa: F401
except ModuleNotFoundError:
pass
def init_model(
cfg: DictConfig,
force_load: bool = False,
clear_stats: bool = False,
load_model_only: bool = False,
) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
"""
Returns an instance of `GenericModel`.
If `cfg.resume` is set or `force_load` is true,
attempts to load the last checkpoint from `cfg.exp_dir`. Failure to do so
will return the model with initial weights, unless `force_load` is passed,
in which case a FileNotFoundError is raised.
Args:
force_load: If true, force load model from checkpoint even if
cfg.resume is false.
clear_stats: If true, clear the stats object loaded from checkpoint
load_model_only: If true, load only the model weights from checkpoint
and do not load the state of the optimizer and stats.
Returns:
model: The model with optionally loaded weights from checkpoint
stats: The stats structure (optionally loaded from checkpoint)
optimizer_state: The optimizer state dict containing
`state` and `param_groups` keys (optionally loaded from checkpoint)
Raise:
FileNotFoundError if `force_load` is passed but checkpoint is not found.
"""
# Initialize the model
if cfg.architecture == "generic":
model = GenericModel(**cfg.generic_model_args)
else:
raise ValueError(f"No such arch {cfg.architecture}.")
# Determine the network outputs that should be logged
if hasattr(model, "log_vars"):
log_vars = copy.deepcopy(list(model.log_vars))
else:
log_vars = ["objective"]
visdom_env_charts = vis_utils.get_visdom_env(cfg) + "_charts"
# Init the stats struct
stats = Stats(
log_vars,
visdom_env=visdom_env_charts,
verbose=False,
visdom_server=cfg.visdom_server,
visdom_port=cfg.visdom_port,
)
# Retrieve the last checkpoint
if cfg.resume_epoch > 0:
model_path = model_io.get_checkpoint(cfg.exp_dir, cfg.resume_epoch)
else:
model_path = model_io.find_last_checkpoint(cfg.exp_dir)
optimizer_state = None
if model_path is not None:
logger.info("found previous model %s" % model_path)
if force_load or cfg.resume:
logger.info(" -> resuming")
if load_model_only:
model_state_dict = torch.load(model_io.get_model_path(model_path))
stats_load, optimizer_state = None, None
else:
model_state_dict, stats_load, optimizer_state = model_io.load_model(
model_path
)
# Determine if stats should be reset
if not clear_stats:
if stats_load is None:
logger.info("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
last_epoch = model_io.parse_epoch_from_model_path(model_path)
logger.info(f"Estimated resume epoch = {last_epoch}")
# Reset the stats struct
for _ in range(last_epoch + 1):
stats.new_epoch()
assert last_epoch == stats.epoch
else:
stats = stats_load
# Update stats properties in case it was reset on load
stats.visdom_env = visdom_env_charts
stats.visdom_server = cfg.visdom_server
stats.visdom_port = cfg.visdom_port
stats.plot_file = os.path.join(cfg.exp_dir, "train_stats.pdf")
stats.synchronize_logged_vars(log_vars)
else:
logger.info(" -> clearing stats")
try:
# TODO: fix on creation of the buffers
# after the hack above, this will not pass in most cases
# ... but this is fine for now
model.load_state_dict(model_state_dict, strict=True)
except RuntimeError as e:
logger.error(e)
logger.info("Cant load state dict in strict mode! -> trying non-strict")
model.load_state_dict(model_state_dict, strict=False)
model.log_vars = log_vars
else:
logger.info(" -> but not resuming -> starting from scratch")
elif force_load:
raise FileNotFoundError(f"Cannot find a checkpoint in {cfg.exp_dir}!")
return model, stats, optimizer_state
def init_optimizer(
model: GenericModel,
optimizer_state: Optional[Dict[str, Any]],
last_epoch: int,
breed: str = "adam",
weight_decay: float = 0.0,
lr_policy: str = "multistep",
lr: float = 0.0005,
gamma: float = 0.1,
momentum: float = 0.9,
betas: Tuple[float, ...] = (0.9, 0.999),
milestones: tuple = (),
max_epochs: int = 1000,
):
"""
Initialize the optimizer (optionally from checkpoint state)
and the learning rate scheduler.
Args:
model: The model with optionally loaded weights
optimizer_state: The state dict for the optimizer. If None
it has not been loaded from checkpoint
last_epoch: If the model was loaded from checkpoint this will be the
number of the last epoch that was saved
breed: The type of optimizer to use e.g. adam
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
lr_policy: The policy to use for learning rate. Currently, only "multistep"
is supported.
lr: The value for the initial learning rate
gamma: Multiplicative factor of learning rate decay
momentum: Momentum factor for SGD optimizer
betas: Coefficients used for computing running averages of gradient and its square
in the Adam optimizer
milestones: List of increasing epoch indices at which the learning rate is
modified
max_epochs: The maximum number of epochs to run the optimizer for
Returns:
optimizer: Optimizer module, optionally loaded from checkpoint
scheduler: Learning rate scheduler module
Raise:
ValueError if `breed` or `lr_policy` are not supported.
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
p_groups = model._get_param_groups(lr, wd=weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
p_groups = [{"params": allprm, "lr": lr}]
# Initialize the optimizer
if breed == "sgd":
optimizer = torch.optim.SGD(
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
)
elif breed == "adagrad":
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
elif breed == "adam":
optimizer = torch.optim.Adam(
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
)
else:
raise ValueError("no such solver type %s" % breed)
logger.info(" -> solver type = %s" % breed)
# Load state from checkpoint
if optimizer_state is not None:
logger.info(" -> setting loaded optimizer state")
optimizer.load_state_dict(optimizer_state)
# Initialize the learning rate scheduler
if lr_policy == "multistep":
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=milestones,
gamma=gamma,
)
else:
raise ValueError("no such lr policy %s" % lr_policy)
# When loading from checkpoint, this will make sure that the
# lr is correctly set even after returning
for _ in range(last_epoch):
scheduler.step()
# Add the max epochs here
scheduler.max_epochs = max_epochs
optimizer.zero_grad()
return optimizer, scheduler
def trainvalidate(
model,
stats,
epoch,
loader,
optimizer,
validation,
bp_var: str = "objective",
metric_print_interval: int = 5,
visualize_interval: int = 100,
visdom_env_root: str = "trainvalidate",
clip_grad: float = 0.0,
device: str = "cuda:0",
**kwargs,
) -> None:
"""
This is the main loop for training and evaluation including:
model forward pass, loss computation, backward pass and visualization.
Args:
model: The model module optionally loaded from checkpoint
stats: The stats struct, also optionally loaded from checkpoint
epoch: The index of the current epoch
loader: The dataloader to use for the loop
optimizer: The optimizer module optionally loaded from checkpoint
validation: If true, run the loop with the model in eval mode
and skip the backward pass
bp_var: The name of the key in the model output `preds` dict which
should be used as the loss for the backward pass.
metric_print_interval: The batch interval at which the stats should be
logged.
visualize_interval: The batch interval at which the visualizations
should be plotted
visdom_env_root: The name of the visdom environment to use for plotting
clip_grad: Optionally clip the gradient norms.
If set to a value <=0.0, no clipping
device: The device on which to run the model.
Returns:
None
"""
if validation:
model.eval()
trainmode = "val"
else:
model.train()
trainmode = "train"
t_start = time.time()
# get the visdom env name
visdom_env_imgs = visdom_env_root + "_images_" + trainmode
viz = vis_utils.get_visdom_connection(
server=stats.visdom_server,
port=stats.visdom_port,
)
# Iterate through the batches
n_batches = len(loader)
for it, batch in enumerate(loader):
last_iter = it == n_batches - 1
# move to gpu where possible (in place)
net_input = batch.to(device)
# run the forward pass
if not validation:
optimizer.zero_grad()
preds = model(**{**net_input, "evaluation_mode": EvaluationMode.TRAINING})
else:
with torch.no_grad():
preds = model(
**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
)
# make sure we dont overwrite something
assert all(k not in preds for k in net_input.keys())
# merge everything into one big dict
preds.update(net_input)
# update the stats logger
stats.update(preds, time_start=t_start, stat_set=trainmode)
assert stats.it[trainmode] == it, "inconsistent stat iteration number!"
# print textual status update
if it % metric_print_interval == 0 or last_iter:
stats.print(stat_set=trainmode, max_it=n_batches)
# visualize results
if visualize_interval > 0 and it % visualize_interval == 0:
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
model.visualize(
viz,
visdom_env_imgs,
preds,
prefix,
)
# optimizer step
if not validation:
loss = preds[bp_var]
assert torch.isfinite(loss).all(), "Non-finite loss!"
# backprop
loss.backward()
if clip_grad > 0.0:
# Optionally clip the gradient norms.
total_norm = torch.nn.utils.clip_grad_norm_(
model.parameters(), clip_grad
)
if total_norm > clip_grad:
logger.info(
f"Clipping gradient: {total_norm}"
+ f" with coef {clip_grad / total_norm}."
)
optimizer.step()
def run_training(cfg: DictConfig, device: str = "cpu"):
"""
Entry point to run the training and validation loops
based on the specified config file.
"""
# set the debug mode
if cfg.detect_anomaly:
logger.info("Anomaly detection!")
torch.autograd.set_detect_anomaly(cfg.detect_anomaly)
# create the output folder
os.makedirs(cfg.exp_dir, exist_ok=True)
_seed_all_random_engines(cfg.seed)
remove_unused_components(cfg)
# dump the exp config to the exp dir
try:
cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml")
OmegaConf.save(config=cfg, f=cfg_filename)
except PermissionError:
warnings.warn("Cant dump config due to insufficient permissions!")
# setup datasets
datasets = dataset_zoo(**cfg.dataset_args)
cfg.dataloader_args["dataset_name"] = cfg.dataset_args["dataset_name"]
dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args)
# init the model
model, stats, optimizer_state = init_model(cfg)
start_epoch = stats.epoch + 1
# move model to gpu
model.to(device)
# only run evaluation on the test dataloader
if cfg.eval_only:
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
return
# init the optimizer
optimizer, scheduler = init_optimizer(
model,
optimizer_state=optimizer_state,
last_epoch=start_epoch,
**cfg.solver_args,
)
# check the scheduler and stats have been initialized correctly
assert scheduler.last_epoch == stats.epoch + 1
assert scheduler.last_epoch == start_epoch
past_scheduler_lrs = []
# loop through epochs
for epoch in range(start_epoch, cfg.solver_args.max_epochs):
# automatic new_epoch and plotting of stats at every epoch start
with stats:
# Make sure to re-seed random generators to ensure reproducibility
# even after restart.
_seed_all_random_engines(cfg.seed + epoch)
cur_lr = float(scheduler.get_last_lr()[-1])
logger.info(f"scheduler lr = {cur_lr:1.2e}")
past_scheduler_lrs.append(cur_lr)
# train loop
trainvalidate(
model,
stats,
epoch,
dataloaders["train"],
optimizer,
False,
visdom_env_root=vis_utils.get_visdom_env(cfg),
device=device,
**cfg,
)
# val loop (optional)
if "val" in dataloaders and epoch % cfg.validation_interval == 0:
trainvalidate(
model,
stats,
epoch,
dataloaders["val"],
optimizer,
True,
visdom_env_root=vis_utils.get_visdom_env(cfg),
device=device,
**cfg,
)
# eval loop (optional)
if (
"test" in dataloaders
and cfg.test_interval > 0
and epoch % cfg.test_interval == 0
):
run_eval(cfg, model, stats, dataloaders["test"], device=device)
assert stats.epoch == epoch, "inconsistent stats!"
# delete previous models if required
# save model
if cfg.store_checkpoints:
if cfg.store_checkpoints_purge > 0:
for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
model_io.purge_epoch(cfg.exp_dir, prev_epoch)
outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
model_io.safe_save_model(model, stats, outfile, optimizer=optimizer)
scheduler.step()
new_lr = float(scheduler.get_last_lr()[-1])
if new_lr != cur_lr:
logger.info(f"LR change! {cur_lr} -> {new_lr}")
if cfg.test_when_finished:
_eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device)
def _eval_and_dump(cfg, datasets, dataloaders, model, stats, device):
"""
Run the evaluation loop with the test data loader and
save the predictions to the `exp_dir`.
"""
if "test" not in dataloaders:
raise ValueError('Dataloaders have to contain the "test" entry for eval!')
eval_task = cfg.dataset_args["dataset_name"].split("_")[-1]
all_source_cameras = (
_get_all_source_cameras(datasets["train"])
if eval_task == "singlesequence"
else None
)
results = run_eval(
cfg, model, all_source_cameras, dataloaders["test"], eval_task, device=device
)
# add the evaluation epoch to the results
for r in results:
r["eval_epoch"] = int(stats.epoch)
logger.info("Evaluation results")
evaluate.pretty_print_nvs_metrics(results)
with open(os.path.join(cfg.exp_dir, "results_test.json"), "w") as f:
json.dump(results, f)
def _get_eval_frame_data(frame_data):
"""
Masks the unknown image data to make sure we cannot use it at model evaluation time.
"""
frame_data_for_eval = copy.deepcopy(frame_data)
is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as(
frame_data.image_rgb
)[:, None, None, None]
for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"):
value_masked = getattr(frame_data_for_eval, k).clone() * is_known
setattr(frame_data_for_eval, k, value_masked)
return frame_data_for_eval
def run_eval(cfg, model, all_source_cameras, loader, task, device):
"""
Run the evaluation loop on the test dataloader
"""
lpips_model = lpips.LPIPS(net="vgg")
lpips_model = lpips_model.to(device)
model.eval()
per_batch_eval_results = []
logger.info("Evaluating model ...")
for frame_data in tqdm.tqdm(loader):
frame_data = frame_data.to(device)
# mask out the unknown images so that the model does not see them
frame_data_for_eval = _get_eval_frame_data(frame_data)
with torch.no_grad():
preds = model(
**{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
)
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
per_batch_eval_results.append(
evaluate.eval_batch(
frame_data,
nvs_prediction,
bg_color="black",
lpips_model=lpips_model,
source_cameras=all_source_cameras,
)
)
_, category_result = evaluate.summarize_nvs_eval_results(
per_batch_eval_results, task
)
return category_result["results"]
def _get_all_source_cameras(
dataset: ImplicitronDataset,
num_workers: int = 8,
) -> CamerasBase:
"""
Load and return all the source cameras in the training dataset
"""
all_frame_data = next(
iter(
torch.utils.data.DataLoader(
dataset,
shuffle=False,
batch_size=len(dataset),
num_workers=num_workers,
collate_fn=FrameData.collate,
)
)
)
is_source = ds_utils.is_known_frame(all_frame_data.frame_type)
source_cameras = all_frame_data.camera[torch.where(is_source)[0]]
return source_cameras
def _seed_all_random_engines(seed: int):
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
@dataclass(eq=False)
class ExperimentConfig:
generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer)
dataset_args: DictConfig = get_default_args_field(dataset_zoo)
dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
architecture: str = "generic"
detect_anomaly: bool = False
eval_only: bool = False
exp_dir: str = "./data/default_experiment/"
exp_idx: int = 0
gpu_idx: int = 0
metric_print_interval: int = 5
resume: bool = True
resume_epoch: int = -1
seed: int = 0
store_checkpoints: bool = True
store_checkpoints_purge: int = 1
test_interval: int = -1
test_when_finished: bool = False
validation_interval: int = 1
visdom_env: str = ""
visdom_port: int = 8097
visdom_server: str = "http://127.0.0.1"
visualize_interval: int = 1000
clip_grad: float = 0.0
hydra: dict = field(
default_factory=lambda: {
"run": {"dir": "."}, # Make hydra not change the working dir.
"output_subdir": None, # disable storing the .hydra logs
}
)
cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config", node=ExperimentConfig)
@hydra.main(config_path="./configs/", config_name="default_config")
def experiment(cfg: DictConfig) -> None:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
# Set the device
device = "cpu"
if torch.cuda.is_available() and cfg.gpu_idx < torch.cuda.device_count():
device = f"cuda:{cfg.gpu_idx}"
logger.info(f"Running experiment on device: {device}")
run_training(cfg, device)
if __name__ == "__main__":
experiment()

View File

@ -0,0 +1,382 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Script to visualize a previously trained model. Example call:
projects/implicitron_trainer/visualize_reconstruction.py
exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097
n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]"
"""
import math
import os
import random
import sys
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as Fu
from experiment import init_model
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
)
from pytorch3d.implicitron.dataset.utils import is_train_frame
from pytorch3d.implicitron.models.base import EvaluationMode
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.implicitron.tools.eval_video_trajectory import (
generate_eval_video_cameras,
)
from pytorch3d.implicitron.tools.video_writer import VideoWriter
from pytorch3d.implicitron.tools.vis_utils import (
get_visdom_connection,
make_depth_image,
)
from tqdm import tqdm
def render_sequence(
dataset: ImplicitronDataset,
sequence_name: str,
model: torch.nn.Module,
video_path,
n_eval_cameras=40,
fps=20,
max_angle=2 * math.pi,
trajectory_type="circular_lsq_fit",
trajectory_scale=1.1,
scene_center=(0.0, 0.0, 0.0),
up=(0.0, -1.0, 0.0),
traj_offset=0.0,
n_source_views=9,
viz_env="debug",
visdom_show_preds=False,
visdom_server="http://127.0.0.1",
visdom_port=8097,
num_workers=10,
seed=None,
video_resize=None,
):
if seed is None:
seed = hash(sequence_name)
print(f"Loading all data of sequence '{sequence_name}'.")
seq_idx = dataset.seq_to_idx[sequence_name]
train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers)
assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name)
sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test"
print(f"Sequence set = {sequence_set_name}.")
train_cameras = train_data.camera
time = torch.linspace(0, max_angle, n_eval_cameras + 1)[:n_eval_cameras]
test_cameras = generate_eval_video_cameras(
train_cameras,
time=time,
n_eval_cams=n_eval_cameras,
trajectory_type=trajectory_type,
trajectory_scale=trajectory_scale,
scene_center=scene_center,
up=up,
focal_length=None,
principal_point=torch.zeros(n_eval_cameras, 2),
traj_offset_canonical=[0.0, 0.0, traj_offset],
)
# sample the source views reproducibly
with torch.random.fork_rng():
torch.manual_seed(seed)
source_views_i = torch.randperm(len(seq_idx))[:n_source_views]
# add the first dummy view that will get replaced with the target camera
source_views_i = Fu.pad(source_views_i, [1, 0])
source_views = [seq_idx[i] for i in source_views_i.tolist()]
batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers)
assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name)
preds_total = []
for n in tqdm(range(n_eval_cameras), total=n_eval_cameras):
# set the first batch camera to the target camera
for k in ("R", "T", "focal_length", "principal_point"):
getattr(batch.camera, k)[0] = getattr(test_cameras[n], k)
# Move to cuda
net_input = batch.cuda()
with torch.no_grad():
preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION})
# make sure we dont overwrite something
assert all(k not in preds for k in net_input.keys())
preds.update(net_input) # merge everything into one big dict
# Render the predictions to images
rendered_pred = images_from_preds(preds)
preds_total.append(rendered_pred)
# show the preds every 5% of the export iterations
if visdom_show_preds and (
n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1
):
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
show_predictions(
preds_total,
sequence_name=batch.sequence_name[0],
viz=viz,
viz_env=viz_env,
)
print(f"Exporting videos for sequence {sequence_name} ...")
generate_prediction_videos(
preds_total,
sequence_name=batch.sequence_name[0],
viz=viz,
viz_env=viz_env,
fps=fps,
video_path=video_path,
resize=video_resize,
)
def _load_whole_dataset(dataset, idx, num_workers=10):
load_all_dataloader = torch.utils.data.DataLoader(
torch.utils.data.Subset(dataset, idx),
batch_size=len(idx),
num_workers=num_workers,
shuffle=False,
collate_fn=FrameData.collate,
)
return next(iter(load_all_dataloader))
def images_from_preds(preds):
imout = {}
for k in (
"image_rgb",
"images_render",
"fg_probability",
"masks_render",
"depths_render",
"depth_map",
"_all_source_images",
):
if k == "_all_source_images" and "image_rgb" in preds:
src_ims = preds["image_rgb"][1:].cpu().detach().clone()
v = _stack_images(src_ims, None)[None]
else:
if k not in preds or preds[k] is None:
print(f"cant show {k}")
continue
v = preds[k].cpu().detach().clone()
if k.startswith("depth"):
mask_resize = Fu.interpolate(
preds["masks_render"],
size=preds[k].shape[2:],
mode="nearest",
)
v = make_depth_image(preds[k], mask_resize)
if v.shape[1] == 1:
v = v.repeat(1, 3, 1, 1)
imout[k] = v.detach().cpu()
return imout
def _stack_images(ims, size):
ba = ims.shape[0]
H = int(np.ceil(np.sqrt(ba)))
W = H
n_add = H * W - ba
if n_add > 0:
ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1)))
ims = ims.view(H, W, *ims.shape[1:])
cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1)
if size is not None:
cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0]
return cated.clamp(0.0, 1.0)
def show_predictions(
preds,
sequence_name,
viz,
viz_env="visualizer",
predicted_keys=(
"images_render",
"masks_render",
"depths_render",
"_all_source_images",
),
n_samples=10,
one_image_width=200,
):
"""Given a list of predictions visualize them into a single image using visdom."""
assert isinstance(preds, list)
pred_all = []
# Randomly choose a subset of the rendered images, sort by order in the sequence
n_samples = min(n_samples, len(preds))
pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
for predi in pred_idx:
# Make the concatenation for the same camera vertically
pred_all.append(
torch.cat(
[
torch.nn.functional.interpolate(
preds[predi][k].cpu(),
scale_factor=one_image_width / preds[predi][k].shape[3],
mode="bilinear",
).clamp(0.0, 1.0)
for k in predicted_keys
],
dim=2,
)
)
# Concatenate the images horizontally
pred_all_cat = torch.cat(pred_all, dim=3)[0]
viz.image(
pred_all_cat,
win="show_predictions",
env=viz_env,
opts={"title": f"pred_{sequence_name}"},
)
def generate_prediction_videos(
preds,
sequence_name,
viz,
viz_env="visualizer",
predicted_keys=(
"images_render",
"masks_render",
"depths_render",
"_all_source_images",
),
fps=20,
video_path="/tmp/video",
resize=None,
):
"""Given a list of predictions create and visualize rotating videos of the
objects using visdom.
"""
assert isinstance(preds, list)
# make sure the target video directory exists
os.makedirs(os.path.dirname(video_path), exist_ok=True)
# init a video writer for each predicted key
vws = {}
for k in predicted_keys:
vws[k] = VideoWriter(out_path=f"{video_path}_{sequence_name}_{k}.mp4", fps=fps)
for rendered_pred in tqdm(preds):
for k in predicted_keys:
vws[k].write_frame(
rendered_pred[k][0].detach().cpu().numpy(),
resize=resize,
)
for k in predicted_keys:
vws[k].get_video(quiet=True)
print(f"Generated {vws[k].out_path}.")
viz.video(
videofile=vws[k].out_path,
env=viz_env,
win=k, # we reuse the same window otherwise visdom dies
opts={"title": sequence_name + " " + k},
)
def export_scenes(
exp_dir: str = "",
restrict_sequence_name: Optional[str] = None,
output_directory: Optional[str] = None,
render_size: Tuple[int, int] = (512, 512),
video_size: Optional[Tuple[int, int]] = None,
split: str = "train", # train | test
n_source_views: int = 9,
n_eval_cameras: int = 40,
visdom_server="http://127.0.0.1",
visdom_port=8097,
visdom_show_preds: bool = False,
visdom_env: Optional[str] = None,
gpu_idx: int = 0,
):
# In case an output directory is specified use it. If no output_directory
# is specified create a vis folder inside the experiment directory
if output_directory is None:
output_directory = os.path.join(exp_dir, "vis")
if not os.path.exists(output_directory):
os.makedirs(output_directory)
# Set the random seeds
torch.manual_seed(0)
np.random.seed(0)
# Get the config from the experiment_directory,
# and overwrite relevant fields
config = _get_config_from_experiment_directory(exp_dir)
config.gpu_idx = gpu_idx
config.exp_dir = exp_dir
# important so that the CO3D dataset gets loaded in full
config.dataset_args.test_on_train = False
# Set the rendering image size
config.generic_model_args.render_image_width = render_size[0]
config.generic_model_args.render_image_height = render_size[1]
if restrict_sequence_name is not None:
config.dataset_args.restrict_sequence_name = restrict_sequence_name
# Set up the CUDA env for the visualization
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx)
# Load the previously trained model
model, _, _ = init_model(config, force_load=True, load_model_only=True)
model.cuda()
model.eval()
# Setup the dataset
dataset = dataset_zoo(**config.dataset_args)[split]
# iterate over the sequences in the dataset
for sequence_name in dataset.seq_to_idx.keys():
with torch.no_grad():
render_sequence(
dataset,
sequence_name,
model,
video_path="{}/video".format(output_directory),
n_source_views=n_source_views,
visdom_show_preds=visdom_show_preds,
n_eval_cameras=n_eval_cameras,
visdom_server=visdom_server,
visdom_port=visdom_port,
viz_env=f"visualizer_{config.visdom_env}"
if visdom_env is None
else visdom_env,
video_resize=video_size,
)
def _get_config_from_experiment_directory(experiment_directory):
cfg_file = os.path.join(experiment_directory, "expconfig.yaml")
config = OmegaConf.load(cfg_file)
return config
def main(argv):
# automatically parses arguments of export_scenes
cfg = OmegaConf.create(get_default_args(export_scenes))
cfg.update(OmegaConf.from_cli())
with torch.no_grad():
export_scenes(**cfg)
if __name__ == "__main__":
main(sys.argv)

View File

@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, Sequence
import torch
from .implicitron_dataset import FrameData, ImplicitronDatasetBase
from .scene_batch_sampler import SceneBatchSampler
def dataloader_zoo(
datasets: Dict[str, ImplicitronDatasetBase],
dataset_name: str = "co3d_singlesequence",
batch_size: int = 1,
num_workers: int = 0,
dataset_len: int = 1000,
dataset_len_val: int = 1,
images_per_seq_options: Sequence[int] = (2,),
sample_consecutive_frames: bool = False,
consecutive_frames_max_gap: int = 0,
consecutive_frames_max_gap_seconds: float = 0.1,
) -> Dict[str, torch.utils.data.DataLoader]:
"""
Returns a set of dataloaders for a given set of datasets.
Args:
datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs.
dataset_name: The name of the returned dataset.
batch_size: The size of the batch of the dataloader.
num_workers: The number of data-loading threads.
dataset_len: The number of batches in a training epoch.
dataset_len_val: The number of batches in a validation epoch.
images_per_seq_options: Possible numbers of images sampled per sequence.
sample_consecutive_frames: if True, will sample a contiguous interval of frames
in the sequence. It first sorts the frames by timestamps when available,
otherwise by frame numbers, finds the connected segments within the sequence
of sufficient length, then samples a random pivot element among them and
ideally uses it as a middle of the temporal window, shifting the borders
where necessary. This strategy mitigates the bias against shorter segments
and their boundaries.
consecutive_frames_max_gap: if a number > 0, then used to define the maximum
difference in frame_number of neighbouring frames when forming connected
segments; if both this and consecutive_frames_max_gap_seconds are 0s,
the whole sequence is considered a segment regardless of frame numbers.
consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
maximum difference in frame_timestamp of neighbouring frames when forming
connected segments; if both this and consecutive_frames_max_gap are 0s,
the whole sequence is considered a segment regardless of frame timestamps.
Returns:
dataloaders: A dictionary containing the
`"dataset_subset_name": torch_dataloader_object` key, value pairs.
"""
if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]:
raise ValueError(f"Unsupported dataset: {dataset_name}")
dataloaders = {}
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
for dataset_set, dataset in datasets.items():
num_samples = {
"train": dataset_len,
"val": dataset_len_val,
"test": None,
}[dataset_set]
if dataset_set == "test":
batch_sampler = dataset.get_eval_batches()
else:
assert num_samples is not None
num_samples = len(dataset) if num_samples <= 0 else num_samples
batch_sampler = SceneBatchSampler(
dataset,
batch_size,
num_batches=num_samples,
images_per_seq_options=images_per_seq_options,
sample_consecutive_frames=sample_consecutive_frames,
consecutive_frames_max_gap=consecutive_frames_max_gap,
)
dataloaders[dataset_set] = torch.utils.data.DataLoader(
dataset,
num_workers=num_workers,
batch_sampler=batch_sampler,
collate_fn=FrameData.collate,
)
else:
raise ValueError(f"Unsupported dataset: {dataset_name}")
return dataloaders
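A hypothetical end-to-end sketch combining `dataset_zoo` (defined in the next file) with `dataloader_zoo`. The dataset root and category are placeholders, and the import paths assume the `pytorch3d.implicitron.dataset` package layout used throughout this commit:

```python
from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo

datasets = dataset_zoo(
    dataset_name="co3d_singlesequence",
    dataset_root="<CO3D_ROOT>",   # placeholder
    category="skateboard",        # placeholder
    test_restrict_sequence_id=0,
)
dataloaders = dataloader_zoo(
    datasets,
    dataset_name="co3d_singlesequence",
    batch_size=10,
    images_per_seq_options=(10,),
)
for frame_data in dataloaders["train"]:
    print(frame_data.sequence_name, frame_data.image_rgb.shape)
    break
```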

View File

@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import json
import os
from typing import Any, Dict, List, Optional, Sequence
from iopath.common.file_io import PathManager
from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase
from .utils import (
DATASET_TYPE_KNOWN,
DATASET_TYPE_TEST,
DATASET_TYPE_TRAIN,
DATASET_TYPE_UNKNOWN,
)
# TODO from dataset.dataset_configs import DATASET_CONFIGS
DATASET_CONFIGS: Dict[str, Dict[str, Any]] = {
"default": {
"box_crop": True,
"box_crop_context": 0.3,
"image_width": 800,
"image_height": 800,
"remove_empty_masks": True,
}
}
# fmt: off
CO3D_CATEGORIES: List[str] = list(reversed([
"baseballbat", "banana", "bicycle", "microwave", "tv",
"cellphone", "toilet", "hairdryer", "couch", "kite", "pizza",
"umbrella", "wineglass", "laptop",
"hotdog", "stopsign", "frisbee", "baseballglove",
"cup", "parkingmeter", "backpack", "toyplane", "toybus",
"handbag", "chair", "keyboard", "car", "motorcycle",
"carrot", "bottle", "sandwich", "remote", "bowl", "skateboard",
"toaster", "mouse", "toytrain", "book", "toytruck",
"orange", "broccoli", "plant", "teddybear",
"suitcase", "bench", "ball", "cake",
"vase", "hydrant", "apple", "donut",
]))
# fmt: on
_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
def dataset_zoo(
dataset_name: str = "co3d_singlesequence",
dataset_root: str = _CO3D_DATASET_ROOT,
category: str = "DEFAULT",
limit_to: int = -1,
limit_sequences_to: int = -1,
n_frames_per_sequence: int = -1,
test_on_train: bool = False,
load_point_clouds: bool = False,
mask_images: bool = False,
mask_depths: bool = False,
restrict_sequence_name: Sequence[str] = (),
test_restrict_sequence_id: int = -1,
assert_single_seq: bool = False,
only_test_set: bool = False,
aux_dataset_kwargs: dict = DATASET_CONFIGS["default"],
path_manager: Optional[PathManager] = None,
) -> Dict[str, ImplicitronDatasetBase]:
"""
Generates the training / validation and testing dataset objects.
Args:
dataset_name: The name of the returned dataset.
dataset_root: The root folder of the dataset.
category: The object category of the dataset.
limit_to: Limit the dataset to the first #limit_to frames.
limit_sequences_to: Limit the dataset to the first
#limit_sequences_to sequences.
n_frames_per_sequence: Randomly sample #n_frames_per_sequence frames
in each sequence.
test_on_train: Construct validation and test datasets from
the training subset.
load_point_clouds: Enable returning scene point clouds from the dataset.
mask_images: Mask the loaded images with segmentation masks.
mask_depths: Mask the loaded depths with segmentation masks.
restrict_sequence_name: Restrict the dataset sequences to the ones
present in the given list of names.
test_restrict_sequence_id: The ID of the loaded sequence.
Active for dataset_name='co3d_singlesequence'.
assert_single_seq: Assert that only frames from a single sequence
are present in all generated datasets.
only_test_set: Load only the test set.
aux_dataset_kwargs: Specifies additional arguments to the
ImplicitronDataset constructor call.
Returns:
datasets: A dictionary containing the
`"dataset_subset_name": torch_dataset_object` key, value pairs.
"""
datasets = {}
# TODO:
# - implement loading multiple categories
if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]:
# This maps the common names of the dataset subsets ("train"/"val"/"test")
# to the names of the subsets in the CO3D dataset.
set_names_mapping = _get_co3d_set_names_mapping(
dataset_name,
test_on_train,
only_test_set,
)
# load the evaluation batches
task = dataset_name.split("_")[-1]
batch_indices_path = os.path.join(
dataset_root,
category,
f"eval_batches_{task}.json",
)
if not os.path.isfile(batch_indices_path):
# The batch indices file does not exist.
# Most probably the user has not specified the root folder.
raise ValueError("Please specify a correct dataset_root folder.")
with open(batch_indices_path, "r") as f:
eval_batch_index = json.load(f)
if task == "singlesequence":
assert (
test_restrict_sequence_id is not None and test_restrict_sequence_id >= 0
), (
"Please specify an integer id 'test_restrict_sequence_id'"
+ " of the sequence considered for 'singlesequence'"
+ " training and evaluation."
)
assert len(restrict_sequence_name) == 0, (
"For the 'singlesequence' task, the restrict_sequence_name has"
" to be unset while test_restrict_sequence_id has to be set to an"
" integer defining the order of the evaluation sequence."
)
# a sort-stable set() equivalent:
eval_batches_sequence_names = list(
{b[0][0]: None for b in eval_batch_index}.keys()
)
eval_sequence_name = eval_batches_sequence_names[test_restrict_sequence_id]
eval_batch_index = [
b for b in eval_batch_index if b[0][0] == eval_sequence_name
]
# overwrite the restrict_sequence_name
restrict_sequence_name = [eval_sequence_name]
for dataset, subsets in set_names_mapping.items():
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
assert os.path.isfile(frame_file)
sequence_file = os.path.join(
dataset_root, category, "sequence_annotations.jgz"
)
assert os.path.isfile(sequence_file)
subset_lists_file = os.path.join(dataset_root, category, "set_lists.json")
assert os.path.isfile(subset_lists_file)
# TODO: maybe directly in param list
params = {
**copy.deepcopy(aux_dataset_kwargs),
"frame_annotations_file": frame_file,
"sequence_annotations_file": sequence_file,
"subset_lists_file": subset_lists_file,
"dataset_root": dataset_root,
"limit_to": limit_to,
"limit_sequences_to": limit_sequences_to,
"n_frames_per_sequence": n_frames_per_sequence
if dataset == "train"
else -1,
"subsets": subsets,
"load_point_clouds": load_point_clouds,
"mask_images": mask_images,
"mask_depths": mask_depths,
"pick_sequence": restrict_sequence_name,
"path_manager": path_manager,
}
datasets[dataset] = ImplicitronDataset(**params)
if dataset == "test":
if len(restrict_sequence_name) > 0:
eval_batch_index = [
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
]
datasets[dataset].eval_batches = datasets[
dataset
].seq_frame_index_to_dataset_index(eval_batch_index)
if assert_single_seq:
# check there's only one sequence in all datasets
assert (
len(
{
e["frame_annotation"].sequence_name
for dset in datasets.values()
for e in dset.frame_annots
}
)
<= 1
), "Multiple sequences loaded but expected one"
else:
raise ValueError(f"Unsupported dataset: {dataset_name}")
if test_on_train:
datasets["val"] = datasets["train"]
datasets["test"] = datasets["train"]
return datasets
def _get_co3d_set_names_mapping(
dataset_name: str,
test_on_train: bool,
only_test: bool,
) -> Dict[str, List[str]]:
"""
Returns the mapping of the common dataset subset names ("train"/"val"/"test")
to the names of the corresponding subsets in the CO3D dataset
("test_known"/"test_unseen"/"train_known"/"train_unseen").
"""
single_seq = dataset_name == "co3d_singlesequence"
if only_test:
set_names_mapping = {}
else:
set_names_mapping = {
"train": [
(DATASET_TYPE_TEST if single_seq else DATASET_TYPE_TRAIN)
+ "_"
+ DATASET_TYPE_KNOWN
]
}
if not test_on_train:
prefixes = [DATASET_TYPE_TEST]
if not single_seq:
prefixes.append(DATASET_TYPE_TRAIN)
set_names_mapping.update(
{
dset: [
p + "_" + t
for p in prefixes
for t in [DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN]
]
for dset in ["val", "test"]
}
)
return set_names_mapping
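For reference, the mapping this helper produces for the two CO3D tasks, derived from the `DATASET_TYPE_*` constants in `.utils` shown later in this commit (the import path is assumed):

```python
from pytorch3d.implicitron.dataset.dataset_zoo import _get_co3d_set_names_mapping

print(_get_co3d_set_names_mapping("co3d_multisequence", test_on_train=False, only_test=False))
# {'train': ['train_known'],
#  'val': ['test_known', 'test_unseen', 'train_known', 'train_unseen'],
#  'test': ['test_known', 'test_unseen', 'train_known', 'train_unseen']}

print(_get_co3d_set_names_mapping("co3d_singlesequence", test_on_train=False, only_test=False))
# {'train': ['test_known'],
#  'val': ['test_known', 'test_unseen'],
#  'test': ['test_known', 'test_unseen']}
```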

View File

@ -0,0 +1,988 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
import gzip
import hashlib
import json
import os
import random
import warnings
from collections import defaultdict
from dataclasses import dataclass, field, fields
from itertools import islice
from pathlib import Path
from typing import (
ClassVar,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
TypedDict,
Union,
)
import numpy as np
import torch
from iopath.common.file_io import PathManager
from PIL import Image
from pytorch3d.io import IO
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch
from . import types
@dataclass
class FrameData:
"""
A type of the elements returned by indexing the dataset object.
It can represent both individual frames and batches thereof;
in this documentation, the sizes of tensors refer to single frames;
add the first batch dimension for the collation result.
Args:
frame_number: The number of the frame within its sequence.
0-based continuous integers.
frame_timestamp: The time elapsed since the start of a sequence in sec.
sequence_name: The unique name of the frame's sequence.
sequence_category: The object category of the sequence.
image_size_hw: The size of the image in pixels; (height, width) tuple.
image_path: The qualified path to the loaded image (with dataset_root).
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
of the frame; elements are floats in [0, 1].
mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image
regions. Regions can be invalid (mask_crop[i,j]=0) in case they
are a result of zero-padding of the image after cropping around
the object bounding box; elements are floats in {0.0, 1.0}.
depth_path: The qualified path to the frame's depth map.
depth_map: A float Tensor of shape `(1, H, W)` holding the depth map
of the frame; values correspond to distances from the camera;
use `depth_mask` and `mask_crop` to filter for valid pixels.
depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the
depth map that are valid for evaluation, they have been checked for
consistency across views; elements are floats in {0.0, 1.0}.
mask_path: A qualified path to the foreground probability mask.
fg_probability: A Tensor of `(1, H, W)` denoting the probability of the
pixels belonging to the captured object; elements are floats
in [0, 1].
bbox_xywh: The bounding box capturing the object in the
format (x0, y0, width, height).
camera: A PyTorch3D camera object corresponding the frame's viewpoint,
corrected for cropping if it happened.
camera_quality_score: The score proportional to the confidence of the
frame's camera estimation (the higher the more accurate).
point_cloud_quality_score: The score proportional to the accuracy of the
frame's sequence point cloud (the higher the more accurate).
sequence_point_cloud_path: The path to the sequence's point cloud.
sequence_point_cloud: A PyTorch3D Pointclouds object holding the
point cloud corresponding to the frame's sequence. When the object
represents a batch of frames, point clouds may be deduplicated;
see `sequence_point_cloud_idx`.
sequence_point_cloud_idx: Integer indices mapping frame indices to the
corresponding point clouds in `sequence_point_cloud`; to get the
corresponding point cloud to `image_rgb[i]`, use
`sequence_point_cloud[sequence_point_cloud_idx[i]]`.
frame_type: The type of the loaded frame specified in
`subset_lists_file`, if provided.
meta: A dict for storing additional frame information.
"""
frame_number: Optional[torch.LongTensor]
frame_timestamp: Optional[torch.Tensor]
sequence_name: Union[str, List[str]]
sequence_category: Union[str, List[str]]
image_size_hw: Optional[torch.Tensor] = None
image_path: Union[str, List[str], None] = None
image_rgb: Optional[torch.Tensor] = None
# masks out padding added due to cropping the square bit
mask_crop: Optional[torch.Tensor] = None
depth_path: Union[str, List[str], None] = None
depth_map: Optional[torch.Tensor] = None
depth_mask: Optional[torch.Tensor] = None
mask_path: Union[str, List[str], None] = None
fg_probability: Optional[torch.Tensor] = None
bbox_xywh: Optional[torch.Tensor] = None
camera: Optional[PerspectiveCameras] = None
camera_quality_score: Optional[torch.Tensor] = None
point_cloud_quality_score: Optional[torch.Tensor] = None
sequence_point_cloud_path: Union[str, List[str], None] = None
sequence_point_cloud: Optional[Pointclouds] = None
sequence_point_cloud_idx: Optional[torch.Tensor] = None
frame_type: Union[str, List[str], None] = None # seen | unseen
meta: dict = field(default_factory=lambda: {})
def to(self, *args, **kwargs):
new_params = {}
for f in fields(self):
value = getattr(self, f.name)
if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)):
new_params[f.name] = value.to(*args, **kwargs)
else:
new_params[f.name] = value
return type(self)(**new_params)
def cpu(self):
return self.to(device=torch.device("cpu"))
def cuda(self):
return self.to(device=torch.device("cuda"))
# the following functions make sure **frame_data can be passed to functions
def keys(self):
for f in fields(self):
yield f.name
def __getitem__(self, key):
return getattr(self, key)
@classmethod
def collate(cls, batch):
"""
Given a list objects `batch` of class `cls`, collates them into a batched
representation suitable for processing with deep networks.
"""
elem = batch[0]
if isinstance(elem, cls):
pointcloud_ids = [id(el.sequence_point_cloud) for el in batch]
id_to_idx = defaultdict(list)
for i, pc_id in enumerate(pointcloud_ids):
id_to_idx[pc_id].append(i)
sequence_point_cloud = []
sequence_point_cloud_idx = -np.ones((len(batch),))
for i, ind in enumerate(id_to_idx.values()):
sequence_point_cloud_idx[ind] = i
sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud)
assert (sequence_point_cloud_idx >= 0).all()
override_fields = {
"sequence_point_cloud": sequence_point_cloud,
"sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(),
}
# note that the pre-collate value of sequence_point_cloud_idx is unused
collated = {}
for f in fields(elem):
list_values = override_fields.get(
f.name, [getattr(d, f.name) for d in batch]
)
collated[f.name] = (
cls.collate(list_values)
if all(list_value is not None for list_value in list_values)
else None
)
return cls(**collated)
elif isinstance(elem, Pointclouds):
return join_pointclouds_as_batch(batch)
elif isinstance(elem, CamerasBase):
# TODO: don't store K; enforce working in NDC space
return join_cameras_as_batch(batch)
else:
return torch.utils.data._utils.collate.default_collate(batch)
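A minimal sketch of how `FrameData` collation and dict-style unpacking behave, assuming the import path used in the rest of this commit; all field values are synthetic:

```python
import torch
from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData

frames = [
    FrameData(
        frame_number=torch.tensor(i, dtype=torch.long),
        frame_timestamp=torch.tensor(float(i)),
        sequence_name="seq0",
        sequence_category="apple",
        image_rgb=torch.rand(3, 8, 8),
    )
    for i in range(2)
]
batch = FrameData.collate(frames)
print(batch.image_rgb.shape)   # torch.Size([2, 3, 8, 8])
print(list(batch.keys())[:4])  # keys()/__getitem__ allow passing **batch to functions
```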
@dataclass(eq=False)
class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]):
"""
Base class to describe a dataset to be used with Implicitron.
The dataset is made up of frames, and the frames are grouped into sequences.
Each sequence has a name (a string).
(A sequence could be a video, or a set of images of one scene.)
Each dataset provides a __getitem__ which returns an instance of FrameData,
describing one frame in one sequence.
Members:
seq_to_idx: For each sequence, the indices of its frames.
"""
seq_to_idx: Dict[str, List[int]] = field(init=False)
def __len__(self) -> int:
raise NotImplementedError
def get_frame_numbers_and_timestamps(
self, idxs: Sequence[int]
) -> List[Tuple[int, float]]:
"""
If the sequences in the dataset are videos rather than
unordered views, then the dataset should override this method to
return the index and timestamp in their videos of the frames whose
indices are given in `idxs`. In addition,
the values in seq_to_idx should be in ascending order.
If timestamps are absent, they should be replaced with a constant.
This is used for letting SceneBatchSampler identify consecutive
frames.
Args:
idxs: frame indices in self
Returns:
a list with one tuple per index, each containing
- frame index in video
- timestamp of frame in video
"""
raise ValueError("This dataset does not contain videos.")
def get_eval_batches(self) -> Optional[List[List[int]]]:
return None
class FrameAnnotsEntry(TypedDict):
subset: Optional[str]
frame_annotation: types.FrameAnnotation
@dataclass(eq=False)
class ImplicitronDataset(ImplicitronDatasetBase):
"""
A class for the Common Objects in 3D (CO3D) dataset.
Args:
frame_annotations_file: A zipped json file containing metadata of the
frames in the dataset, serialized List[types.FrameAnnotation].
sequence_annotations_file: A zipped json file containing metadata of the
sequences in the dataset, serialized List[types.SequenceAnnotation].
subset_lists_file: A json file containing the lists of frames
corresponding to different subsets (e.g. train/val/test) of the dataset;
format: {subset: (sequence_name, frame_id, file_path)}.
subsets: Restrict frames/sequences only to the given list of subsets
as defined in subset_lists_file (see above).
limit_to: Limit the dataset to the first #limit_to frames (after other
filters have been applied).
limit_sequences_to: Limit the dataset to the first
#limit_sequences_to sequences (after other sequence filters have been
applied but before frame-based filters).
pick_sequence: A list of sequence names to restrict the dataset to.
exclude_sequence: A list of the names of the sequences to exclude.
limit_category_to: Restrict the dataset to the given list of categories.
dataset_root: The root folder of the dataset; all the paths in jsons are
specified relative to this root (but not json paths themselves).
load_images: Enable loading the frame RGB data.
load_depths: Enable loading the frame depth maps.
load_depth_masks: Enable loading the frame depth map masks denoting the
depth values used for evaluation (the points consistent across views).
load_masks: Enable loading frame foreground masks.
load_point_clouds: Enable loading sequence-level point clouds.
max_points: Cap on the number of loaded points in the point cloud;
if reached, they are randomly sampled without replacement.
mask_images: Whether to mask the images with the loaded foreground masks;
0 value is used for background.
mask_depths: Whether to mask the depth maps with the loaded foreground
masks; 0 value is used for background.
image_height: The height of the returned images, masks, and depth maps;
aspect ratio is preserved during cropping/resizing.
image_width: The width of the returned images, masks, and depth maps;
aspect ratio is preserved during cropping/resizing.
box_crop: Enable cropping of the image around the bounding box inferred
from the foreground region of the loaded segmentation mask; masks
and depth maps are cropped accordingly; cameras are corrected.
box_crop_mask_thr: The threshold used to separate pixels into foreground
and background based on the foreground_probability mask; if no value
is greater than this threshold, the loader lowers it and repeats.
box_crop_context: The amount of additional padding added to each
dimension of the cropping bounding box, relative to box size.
remove_empty_masks: Removes the frames with no active foreground pixels
in the segmentation mask after thresholding (see box_crop_mask_thr).
n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence
frames in each sequence uniformly without replacement if it has
more frames than that; applied before other frame-level filters.
seed: The seed of the random generator sampling #n_frames_per_sequence
random frames per sequence.
sort_frames: Enable frame annotations sorting to group frames from the
same sequences together and order them by timestamps.
eval_batches: A list of batches that form the evaluation set;
list of batch-sized lists of indices corresponding to __getitem__
of this class, thus it can be used directly as a batch sampler.
"""
frame_annotations_type: ClassVar[
Type[types.FrameAnnotation]
] = types.FrameAnnotation
path_manager: Optional[PathManager] = None
frame_annotations_file: str = ""
sequence_annotations_file: str = ""
subset_lists_file: str = ""
subsets: Optional[List[str]] = None
limit_to: int = 0
limit_sequences_to: int = 0
pick_sequence: Sequence[str] = ()
exclude_sequence: Sequence[str] = ()
limit_category_to: Sequence[int] = ()
dataset_root: str = ""
load_images: bool = True
load_depths: bool = True
load_depth_masks: bool = True
load_masks: bool = True
load_point_clouds: bool = False
max_points: int = 0
mask_images: bool = False
mask_depths: bool = False
image_height: Optional[int] = 256
image_width: Optional[int] = 256
box_crop: bool = False
box_crop_mask_thr: float = 0.4
box_crop_context: float = 1.0
remove_empty_masks: bool = False
n_frames_per_sequence: int = -1
seed: int = 0
sort_frames: bool = False
eval_batches: Optional[List[List[int]]] = None
frame_annots: List[FrameAnnotsEntry] = field(init=False)
seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
def __post_init__(self) -> None:
# pyre-fixme[16]: `ImplicitronDataset` has no attribute `subset_to_image_path`.
self.subset_to_image_path = None
self._load_frames()
self._load_sequences()
if self.sort_frames:
self._sort_frames()
self._load_subset_lists()
self._filter_db() # also computes sequence indices
print(str(self))
def seq_frame_index_to_dataset_index(
self,
seq_frame_index: List[
List[Union[Tuple[str, int, str], Tuple[str, int]]]
],
) -> List[List[int]]:
"""
Obtain indices into the dataset object given a list of frames specified as
`seq_frame_index = List[List[Tuple[sequence_name:str, frame_number:int]]]`.
"""
# TODO: check the frame numbers are unique
_dataset_seq_frame_n_index = {
seq: {
self.frame_annots[idx]["frame_annotation"].frame_number: idx
for idx in seq_idx
}
for seq, seq_idx in self.seq_to_idx.items()
}
def _get_batch_idx(seq_name, frame_no, path=None) -> int:
idx = _dataset_seq_frame_n_index[seq_name][frame_no]
if path is not None:
# Check that the loaded frame path is consistent
# with the one stored in self.frame_annots.
assert os.path.normpath(
self.frame_annots[idx]["frame_annotation"].image.path
) == os.path.normpath(
path
), f"Inconsistent batch {seq_name, frame_no, path}."
return idx
batches_idx = [[_get_batch_idx(*b) for b in batch] for batch in seq_frame_index]
return batches_idx
def __str__(self) -> str:
return f"ImplicitronDataset #frames={len(self.frame_annots)}"
def __len__(self) -> int:
return len(self.frame_annots)
def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]:
return entry["subset"]
def __getitem__(self, index) -> FrameData:
if index >= len(self.frame_annots):
raise IndexError(f"index {index} out of range {len(self.frame_annots)}")
entry = self.frame_annots[index]["frame_annotation"]
point_cloud = self.seq_annots[entry.sequence_name].point_cloud
frame_data = FrameData(
frame_number=_safe_as_tensor(entry.frame_number, torch.long),
frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float),
sequence_name=entry.sequence_name,
sequence_category=self.seq_annots[entry.sequence_name].category,
camera_quality_score=_safe_as_tensor(
self.seq_annots[entry.sequence_name].viewpoint_quality_score,
torch.float,
),
point_cloud_quality_score=_safe_as_tensor(
point_cloud.quality_score, torch.float
)
if point_cloud is not None
else None,
)
# The rest of the fields are optional
frame_data.frame_type = self._get_frame_type(self.frame_annots[index])
(
frame_data.fg_probability,
frame_data.mask_path,
frame_data.bbox_xywh,
clamp_bbox_xyxy,
) = self._load_crop_fg_probability(entry)
scale = 1.0
if self.load_images and entry.image is not None:
# original image size
frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long)
(
frame_data.image_rgb,
frame_data.image_path,
frame_data.mask_crop,
scale,
) = self._load_crop_images(
entry, frame_data.fg_probability, clamp_bbox_xyxy
)
if self.load_depths and entry.depth is not None:
(
frame_data.depth_map,
frame_data.depth_path,
frame_data.depth_mask,
) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability)
if entry.viewpoint is not None:
frame_data.camera = self._get_pytorch3d_camera(
entry,
scale,
clamp_bbox_xyxy,
)
if self.load_point_clouds and point_cloud is not None:
frame_data.sequence_point_cloud_path = pcl_path = os.path.join(
self.dataset_root, point_cloud.path
)
frame_data.sequence_point_cloud = _load_pointcloud(
self._local_path(pcl_path), max_points=self.max_points
)
return frame_data
def _load_crop_fg_probability(
self, entry: types.FrameAnnotation
) -> Tuple[
Optional[torch.Tensor],
Optional[str],
Optional[torch.Tensor],
Optional[torch.Tensor],
]:
fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy = (
None,
None,
None,
None,
)
if (self.load_masks or self.box_crop) and entry.mask is not None:
full_path = os.path.join(self.dataset_root, entry.mask.path)
mask = _load_mask(self._local_path(full_path))
if mask.shape[-2:] != entry.image.size:
raise ValueError(
f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!"
)
bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr))
if self.box_crop:
clamp_bbox_xyxy = _get_clamp_bbox(bbox_xywh, self.box_crop_context)
mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path)
fg_probability, _, _ = self._resize_image(mask, mode="nearest")
return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy
def _load_crop_images(
self,
entry: types.FrameAnnotation,
fg_probability: Optional[torch.Tensor],
clamp_bbox_xyxy: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor, float]:
assert self.dataset_root is not None and entry.image is not None
path = os.path.join(self.dataset_root, entry.image.path)
image_rgb = _load_image(self._local_path(path))
if image_rgb.shape[-2:] != entry.image.size:
raise ValueError(
f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
)
if self.box_crop:
assert clamp_bbox_xyxy is not None
image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path)
image_rgb, scale, mask_crop = self._resize_image(image_rgb)
if self.mask_images:
assert fg_probability is not None
image_rgb *= fg_probability
return image_rgb, path, mask_crop, scale
def _load_mask_depth(
self,
entry: types.FrameAnnotation,
clamp_bbox_xyxy: Optional[torch.Tensor],
fg_probability: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, str, torch.Tensor]:
entry_depth = entry.depth
assert entry_depth is not None
path = os.path.join(self.dataset_root, entry_depth.path)
depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment)
if self.box_crop:
assert clamp_bbox_xyxy is not None
depth_bbox_xyxy = _rescale_bbox(
clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:]
)
depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path)
depth_map, _, _ = self._resize_image(depth_map, mode="nearest")
if self.mask_depths:
assert fg_probability is not None
depth_map *= fg_probability
if self.load_depth_masks:
assert entry_depth.mask_path is not None
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
depth_mask = _load_depth_mask(self._local_path(mask_path))
if self.box_crop:
assert clamp_bbox_xyxy is not None
depth_mask_bbox_xyxy = _rescale_bbox(
clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:]
)
depth_mask = _crop_around_box(
depth_mask, depth_mask_bbox_xyxy, mask_path
)
depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest")
else:
depth_mask = torch.ones_like(depth_map)
return depth_map, path, depth_mask
def _get_pytorch3d_camera(
self,
entry: types.FrameAnnotation,
scale: float,
clamp_bbox_xyxy: Optional[torch.Tensor],
) -> PerspectiveCameras:
entry_viewpoint = entry.viewpoint
assert entry_viewpoint is not None
# principal point and focal length
principal_point = torch.tensor(
entry_viewpoint.principal_point, dtype=torch.float
)
focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float)
half_image_size_wh_orig = (
torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0
)
# first, we convert from the dataset's NDC convention to pixels
format = entry_viewpoint.intrinsics_format
if format.lower() == "ndc_norm_image_bounds":
# this is e.g. currently used in CO3D for storing intrinsics
rescale = half_image_size_wh_orig
elif format.lower() == "ndc_isotropic":
rescale = half_image_size_wh_orig.min()
else:
raise ValueError(f"Unknown intrinsics format: {format}")
# principal point and focal length in pixels
principal_point_px = half_image_size_wh_orig - principal_point * rescale
focal_length_px = focal_length * rescale
if self.box_crop:
assert clamp_bbox_xyxy is not None
principal_point_px -= clamp_bbox_xyxy[:2]
# now, convert from pixels to PyTorch3D v0.5+ NDC convention
if self.image_height is None or self.image_width is None:
out_size = list(reversed(entry.image.size))
else:
out_size = [self.image_width, self.image_height]
half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0
half_min_image_size_output = half_image_size_output.min()
# rescaled principal point and focal length in ndc
principal_point = (
half_image_size_output - principal_point_px * scale
) / half_min_image_size_output
focal_length = focal_length_px * scale / half_min_image_size_output
return PerspectiveCameras(
focal_length=focal_length[None],
principal_point=principal_point[None],
R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None],
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
)
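A worked numeric sketch of the intrinsics conversion above, assuming no box crop and no resizing (i.e. scale == 1) and made-up legacy "ndc_norm_image_bounds" values for an 800 x 600 (H x W) image:

```python
import torch

image_size_hw = (800, 600)
half_wh = torch.tensor([image_size_hw[1], image_size_hw[0]], dtype=torch.float) / 2.0  # [300., 400.]
principal_point_ndc = torch.tensor([0.0, 0.0])
focal_length_ndc = torch.tensor([2.0, 2.0])

# legacy NDC -> pixels (the rescale factor differs per axis in this format)
principal_point_px = half_wh - principal_point_ndc * half_wh   # [300., 400.]
focal_length_px = focal_length_ndc * half_wh                   # [600., 800.]

# pixels -> PyTorch3D v0.5+ isotropic NDC (normalize by half of the shorter side)
half_min = half_wh.min()                                        # 300.
print((half_wh - principal_point_px) / half_min)                # tensor([0., 0.])
print(focal_length_px / half_min)                               # tensor([2.0000, 2.6667])
```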
def _load_frames(self) -> None:
print(f"Loading Co3D frames from {self.frame_annotations_file}.")
local_file = self._local_path(self.frame_annotations_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
frame_annots_list = types.load_dataclass(
zipfile, List[self.frame_annotations_type]
)
if not frame_annots_list:
raise ValueError("Empty dataset!")
self.frame_annots = [
FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list
]
def _load_sequences(self) -> None:
print(f"Loading Co3D sequences from {self.sequence_annotations_file}.")
local_file = self._local_path(self.sequence_annotations_file)
with gzip.open(local_file, "rt", encoding="utf8") as zipfile:
seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation])
if not seq_annots:
raise ValueError("Empty sequences file!")
self.seq_annots = {entry.sequence_name: entry for entry in seq_annots}
def _load_subset_lists(self) -> None:
print(f"Loading Co3D subset lists from {self.subset_lists_file}.")
if not self.subset_lists_file:
return
with open(self._local_path(self.subset_lists_file), "r") as f:
subset_to_seq_frame = json.load(f)
frame_path_to_subset = {
path: subset
for subset, frames in subset_to_seq_frame.items()
for _, _, path in frames
}
for frame in self.frame_annots:
frame["subset"] = frame_path_to_subset.get(
frame["frame_annotation"].image.path, None
)
if frame["subset"] is None:
warnings.warn(
"Subset lists are given but don't include "
+ frame["frame_annotation"].image.path
)
def _sort_frames(self) -> None:
# Sort frames to have them grouped by sequence, ordered by timestamp
self.frame_annots = sorted(
self.frame_annots,
key=lambda f: (
f["frame_annotation"].sequence_name,
f["frame_annotation"].frame_timestamp or 0,
),
)
def _filter_db(self) -> None:
if self.remove_empty_masks:
print("Removing images with empty masks.")
old_len = len(self.frame_annots)
msg = "remove_empty_masks needs every MaskAnnotation.mass to be set."
def positive_mass(frame_annot: types.FrameAnnotation) -> bool:
mask = frame_annot.mask
if mask is None:
return False
if mask.mass is None:
raise ValueError(msg)
return mask.mass > 1
self.frame_annots = [
frame
for frame in self.frame_annots
if positive_mass(frame["frame_annotation"])
]
print("... filtered %d -> %d" % (old_len, len(self.frame_annots)))
# this has to be called after joining with categories!!
subsets = self.subsets
if subsets:
if not self.subset_lists_file:
raise ValueError(
"Subset filter is on but subset_lists_file was not given"
)
print(f"Limitting Co3D dataset to the '{subsets}' subsets.")
# truncate the list of subsets to the valid one
self.frame_annots = [
entry for entry in self.frame_annots if entry["subset"] in subsets
]
if len(self.frame_annots) == 0:
raise ValueError(f"There are no frames in the '{subsets}' subsets!")
self._invalidate_indexes(filter_seq_annots=True)
if len(self.limit_category_to) > 0:
print(f"Limitting dataset to categories: {self.limit_category_to}")
self.seq_annots = {
name: entry
for name, entry in self.seq_annots.items()
if entry.category in self.limit_category_to
}
# sequence filters
for prefix in ("pick", "exclude"):
orig_len = len(self.seq_annots)
attr = f"{prefix}_sequence"
arr = getattr(self, attr)
if len(arr) > 0:
print(f"{attr}: {str(arr)}")
self.seq_annots = {
name: entry
for name, entry in self.seq_annots.items()
if (name in arr) == (prefix == "pick")
}
print("... filtered %d -> %d" % (orig_len, len(self.seq_annots)))
if self.limit_sequences_to > 0:
self.seq_annots = dict(
islice(self.seq_annots.items(), self.limit_sequences_to)
)
# retain only frames from retained sequences
self.frame_annots = [
f
for f in self.frame_annots
if f["frame_annotation"].sequence_name in self.seq_annots
]
self._invalidate_indexes()
if self.n_frames_per_sequence > 0:
print(f"Taking max {self.n_frames_per_sequence} per sequence.")
keep_idx = []
for seq, seq_indices in self.seq_to_idx.items():
# infer the seed from the sequence name; this is reproducible
# and makes the selection differ for different sequences
seed = _seq_name_to_seed(seq) + self.seed
seq_idx_shuffled = random.Random(seed).sample(
sorted(seq_indices), len(seq_indices)
)
keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence])
print("... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx)))
self.frame_annots = [self.frame_annots[i] for i in keep_idx]
self._invalidate_indexes(filter_seq_annots=False)
# sequences are not decimated, so self.seq_annots is valid
if self.limit_to > 0 and self.limit_to < len(self.frame_annots):
print(
"limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to)
)
self.frame_annots = self.frame_annots[: self.limit_to]
self._invalidate_indexes(filter_seq_annots=True)
def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None:
# update seq_to_idx and filter seq_meta according to frame_annots change
# if filter_seq_annots, also updates seq_annots based on the changed seq_to_idx
self._invalidate_seq_to_idx()
if filter_seq_annots:
self.seq_annots = {
k: v for k, v in self.seq_annots.items() if k in self.seq_to_idx
}
def _invalidate_seq_to_idx(self) -> None:
seq_to_idx = defaultdict(list)
for idx, entry in enumerate(self.frame_annots):
seq_to_idx[entry["frame_annotation"].sequence_name].append(idx)
self.seq_to_idx = seq_to_idx
def _resize_image(
self, image, mode="bilinear"
) -> Tuple[torch.Tensor, float, torch.Tensor]:
image_height, image_width = self.image_height, self.image_width
if image_height is None or image_width is None:
# skip the resizing
imre_ = torch.from_numpy(image)
return imre_, 1.0, torch.ones_like(imre_[:1])
# takes numpy array, returns pytorch tensor
minscale = min(
image_height / image.shape[-2],
image_width / image.shape[-1],
)
imre = torch.nn.functional.interpolate(
torch.from_numpy(image)[None],
# pyre-ignore[6]
scale_factor=minscale,
mode=mode,
align_corners=False if mode == "bilinear" else None,
recompute_scale_factor=True,
)[0]
imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width)
imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
mask = torch.zeros(1, self.image_height, self.image_width)
mask[:, 0 : imre.shape[1] - 1, 0 : imre.shape[2] - 1] = 1.0
return imre_, minscale, mask
def _local_path(self, path: str) -> str:
if self.path_manager is None:
return path
return self.path_manager.get_local_path(path)
def get_frame_numbers_and_timestamps(
self, idxs: Sequence[int]
) -> List[Tuple[int, float]]:
out: List[Tuple[int, float]] = []
for idx in idxs:
frame_annotation = self.frame_annots[idx]["frame_annotation"]
out.append(
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
)
return out
def get_eval_batches(self) -> Optional[List[List[int]]]:
return self.eval_batches
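A hypothetical direct construction of `ImplicitronDataset`; the `dataset_zoo` helper earlier in this commit builds calls like this, and all paths below are placeholders:

```python
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset

dataset = ImplicitronDataset(
    frame_annotations_file="<CO3D_ROOT>/skateboard/frame_annotations.jgz",
    sequence_annotations_file="<CO3D_ROOT>/skateboard/sequence_annotations.jgz",
    subset_lists_file="<CO3D_ROOT>/skateboard/set_lists.json",
    dataset_root="<CO3D_ROOT>",
    subsets=["train_known"],
    box_crop=True,
    image_height=800,
    image_width=800,
)
frame_data = dataset[0]   # a FrameData with image, mask, depth and camera
print(len(dataset), frame_data.sequence_name)
```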
def _seq_name_to_seed(seq_name) -> int:
return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16)
def _load_image(path) -> np.ndarray:
with Image.open(path) as pil_im:
im = np.array(pil_im.convert("RGB"))
im = im.transpose((2, 0, 1))
im = im.astype(np.float32) / 255.0
return im
def _load_16big_png_depth(depth_png) -> np.ndarray:
with Image.open(depth_png) as depth_pil:
# the image is stored with 16-bit depth but PIL reads it as I (32 bit).
# we cast it to uint16, then reinterpret as float16, then cast to float32
depth = (
np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
.astype(np.float32)
.reshape((depth_pil.size[1], depth_pil.size[0]))
)
return depth
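The comment above describes a bit-level reinterpretation rather than a numeric conversion; a small numpy-only sketch of the same round trip (no file I/O):

```python
import numpy as np

depth = np.random.rand(4, 5).astype(np.float16)   # synthetic depth map
stored = depth.view(np.uint16)                    # bits as written to the 16-bit png
decoded = (
    np.frombuffer(stored.tobytes(), dtype=np.float16)
    .astype(np.float32)
    .reshape(depth.shape)
)
assert np.allclose(decoded, depth.astype(np.float32))
```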
def _load_1bit_png_mask(file: str) -> np.ndarray:
with Image.open(file) as pil_im:
mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32)
return mask
def _load_depth_mask(path) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth mask file name "%s"' % path)
m = _load_1bit_png_mask(path)
return m[None] # fake feature channel
def _load_depth(path, scale_adjustment) -> np.ndarray:
if not path.lower().endswith(".png"):
raise ValueError('unsupported depth file name "%s"' % path)
d = _load_16big_png_depth(path) * scale_adjustment
d[~np.isfinite(d)] = 0.0
return d[None] # fake feature channel
def _load_mask(path) -> np.ndarray:
with Image.open(path) as pil_im:
mask = np.array(pil_im)
mask = mask.astype(np.float32) / 255.0
return mask[None] # fake feature channel
def _get_1d_bounds(arr) -> Tuple[int, int]:
nz = np.flatnonzero(arr)
return nz[0], nz[-1]
def _get_bbox_from_mask(
mask, thr, decrease_quant: float = 0.05
) -> Tuple[int, int, int, int]:
# bbox in xywh
masks_for_box = np.zeros_like(mask)
while masks_for_box.sum() <= 1.0:
masks_for_box = (mask > thr).astype(np.float32)
thr -= decrease_quant
if thr <= 0.0:
warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")
x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))
return x0, y0, x1 - x0, y1 - y0
def _get_clamp_bbox(
bbox: torch.Tensor, box_crop_context: float = 0.0, impath: str = ""
) -> torch.Tensor:
# box_crop_context: rate of expansion for bbox
# returns possibly expanded bbox xyxy as float
# increase box size
if box_crop_context > 0.0:
c = box_crop_context
bbox = bbox.float()
bbox[0] -= bbox[2] * c / 2
bbox[1] -= bbox[3] * c / 2
bbox[2] += bbox[2] * c
bbox[3] += bbox[3] * c
if (bbox[2:] <= 1.0).any():
raise ValueError(
f"squashed image {impath}!! The bounding box contains no pixels."
)
bbox[2:] = torch.clamp(bbox[2:], 2)
bbox[2:] += bbox[0:2] + 1 # convert to [xmin, ymin, xmax, ymax]
# +1 because upper bound is not inclusive
return bbox
def _crop_around_box(tensor, bbox, impath: str = ""):
# bbox is xyxy, where the upper bound is corrected with +1
bbox[[0, 2]] = torch.clamp(bbox[[0, 2]], 0.0, tensor.shape[-1])
bbox[[1, 3]] = torch.clamp(bbox[[1, 3]], 0.0, tensor.shape[-2])
bbox = bbox.round().long()
tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]]
assert all(c > 0 for c in tensor.shape), f"squashed image {impath}"
return tensor
def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor:
assert bbox is not None
assert np.prod(orig_res) > 1e-8
# average ratio of dimensions
rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0
return bbox * rel_size
def _safe_as_tensor(data, dtype):
if data is None:
return None
return torch.tensor(data, dtype=dtype)
# NOTE this cache is per-worker; they are implemented as processes.
# each batch is loaded and collated by a single worker;
# since sequences tend to co-occur within batches, this is useful.
@functools.lru_cache(maxsize=256)
def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds:
pcl = IO().load_pointcloud(pcl_path)
if max_points > 0:
pcl = pcl.subsample(max_points)
return pcl

View File

@ -0,0 +1,203 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from dataclasses import dataclass, field
from typing import Iterator, List, Sequence, Tuple
import numpy as np
from torch.utils.data.sampler import Sampler
from .implicitron_dataset import ImplicitronDatasetBase
@dataclass(eq=False) # TODO: do we need this if not init from config?
class SceneBatchSampler(Sampler[List[int]]):
"""
A class for sampling training batches with a controlled composition
of sequences.
"""
dataset: ImplicitronDatasetBase
batch_size: int
num_batches: int
# the sampler first samples a random element k from this list and then
# takes k random frames per sequence
images_per_seq_options: Sequence[int]
# if True, will sample a contiguous interval of frames in the sequence
# it first finds the connected segments within the sequence of sufficient length,
# then samples a random pivot element among them and ideally uses it as a middle
# of the temporal window, shifting the borders where necessary.
# This strategy mitigates the bias against shorter segments and their boundaries.
sample_consecutive_frames: bool = False
# if a number > 0, then used to define the maximum difference in frame_number
# of neighbouring frames when forming connected segments; otherwise the whole
# sequence is considered a segment regardless of frame numbers
consecutive_frames_max_gap: int = 0
# same but for timestamps if they are available
consecutive_frames_max_gap_seconds: float = 0.1
seq_names: List[str] = field(init=False)
def __post_init__(self) -> None:
if self.batch_size <= 0:
raise ValueError(
"batch_size should be a positive integral value, "
f"but got batch_size={self.batch_size}"
)
if len(self.images_per_seq_options) < 1:
raise ValueError("n_per_seq_posibilities list cannot be empty")
self.seq_names = list(self.dataset.seq_to_idx.keys())
def __len__(self) -> int:
return self.num_batches
def __iter__(self) -> Iterator[List[int]]:
for batch_idx in range(len(self)):
batch = self._sample_batch(batch_idx)
yield batch
def _sample_batch(self, batch_idx) -> List[int]:
n_per_seq = np.random.choice(self.images_per_seq_options)
n_seqs = -(-self.batch_size // n_per_seq) # round up
chosen_seq = _capped_random_choice(self.seq_names, n_seqs, replace=False)
if self.sample_consecutive_frames:
frame_idx = []
for seq in chosen_seq:
segment_index = self._build_segment_index(
list(self.dataset.seq_to_idx[seq]), n_per_seq
)
segment, idx = segment_index[np.random.randint(len(segment_index))]
if len(segment) <= n_per_seq:
frame_idx.append(segment)
else:
start = np.clip(idx - n_per_seq // 2, 0, len(segment) - n_per_seq)
frame_idx.append(segment[start : start + n_per_seq])
else:
frame_idx = [
_capped_random_choice(
self.dataset.seq_to_idx[seq], n_per_seq, replace=False
)
for seq in chosen_seq
]
frame_idx = np.concatenate(frame_idx)[: self.batch_size].tolist()
if len(frame_idx) < self.batch_size:
warnings.warn(
"Batch size smaller than self.batch_size!"
+ " (This is fine for experiments with a single scene and viewpooling)"
)
return frame_idx
def _build_segment_index(
self, seq_frame_indices: List[int], size: int
) -> List[Tuple[List[int], int]]:
"""
Returns a list of (segment, index) tuples, one per eligible frame, where
segment is a list of frame indices in the contiguous segment the frame
belongs to, and index is the frame's index within that segment.
Segment references are repeated but the memory is shared.
"""
if (
self.consecutive_frames_max_gap > 0
or self.consecutive_frames_max_gap_seconds > 0.0
):
sequence_timestamps = _sort_frames_by_timestamps_then_numbers(
seq_frame_indices, self.dataset
)
# TODO: use new API to access frame numbers / timestamps
segments = self._split_to_segments(sequence_timestamps)
segments = _cull_short_segments(segments, size)
if not segments:
raise AssertionError("Empty segments after culling")
else:
segments = [seq_frame_indices]
# build an index of segment for random selection of a pivot frame
segment_index = [
(segment, i) for segment in segments for i in range(len(segment))
]
return segment_index
def _split_to_segments(
self, sequence_timestamps: List[Tuple[float, int, int]]
) -> List[List[int]]:
if (
self.consecutive_frames_max_gap <= 0
and self.consecutive_frames_max_gap_seconds <= 0.0
):
raise AssertionError("This function is only needed for non-trivial max_gap")
segments = []
last_no = -self.consecutive_frames_max_gap - 1 # will trigger a new segment
last_ts = -self.consecutive_frames_max_gap_seconds - 1.0
for ts, no, idx in sequence_timestamps:
if ts <= 0.0 and no <= last_no:
raise AssertionError(
"Frames are not ordered in seq_to_idx while timestamps are not given"
)
if (
no - last_no > self.consecutive_frames_max_gap > 0
or ts - last_ts > self.consecutive_frames_max_gap_seconds > 0.0
): # new group
segments.append([idx])
else:
segments[-1].append(idx)
last_no = no
last_ts = ts
return segments
def _sort_frames_by_timestamps_then_numbers(
seq_frame_indices: List[int], dataset: ImplicitronDatasetBase
) -> List[Tuple[float, int, int]]:
"""Build the list of triplets (timestamp, frame_no, dataset_idx).
We attempt to first sort by timestamp, then by frame number.
Timestamps are coalesced with 0s.
"""
nos_timestamps = dataset.get_frame_numbers_and_timestamps(seq_frame_indices)
return sorted(
[
(timestamp, frame_no, idx)
for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps)
]
)
def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]:
lengths = [(len(segment), segment) for segment in segments]
max_len, longest_segment = max(lengths)
if max_len < min_size:
return [longest_segment]
return [segment for segment in segments if len(segment) >= min_size]
def _capped_random_choice(x, size, replace: bool = True):
"""
If replace==True: randomly chooses `size` elements from x without replacement
if len(x) >= size, otherwise samples `size` elements with replacement.
If replace==False: randomly chooses `min(len(x), size)` elements from x
without replacement.
"""
len_x = x if isinstance(x, int) else len(x)
if replace:
return np.random.choice(x, size=size, replace=len_x < size)
else:
return np.random.choice(x, size=min(size, len_x), replace=False)
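A self-contained toy example of the sampler: any object exposing `seq_to_idx` is enough to drive it (the real use with `ImplicitronDataset` and `FrameData.collate` is shown in `dataloader_zoo` earlier in this commit); the import path is assumed:

```python
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler

class ToyDataset:
    # two sequences with 6 and 4 frames, indexed into a flat dataset of 10 frames
    seq_to_idx = {"seq_a": list(range(0, 6)), "seq_b": list(range(6, 10))}
    def __len__(self):
        return 10

sampler = SceneBatchSampler(
    ToyDataset(),
    batch_size=4,
    num_batches=3,
    images_per_seq_options=(2,),
)
for batch in sampler:
    print(batch)   # e.g. [3, 0, 7, 9]: 2 frames from each of 2 sampled sequences
```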

View File

@ -0,0 +1,331 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dataclasses
import gzip
import json
import sys
from dataclasses import MISSING, Field, dataclass
from typing import IO, Any, Optional, Tuple, Type, TypeVar, Union, cast
import numpy as np
_X = TypeVar("_X")
if sys.version_info >= (3, 8, 0):
from typing import get_args, get_origin
elif sys.version_info >= (3, 7, 0):
def get_origin(cls):
return getattr(cls, "__origin__", None)
def get_args(cls):
return getattr(cls, "__args__", None)
else:
raise ImportError("This module requires Python 3.7+")
TF3 = Tuple[float, float, float]
@dataclass
class ImageAnnotation:
# path to jpg file, relative w.r.t. dataset_root
path: str
# H x W
size: Tuple[int, int] # TODO: rename size_hw?
@dataclass
class DepthAnnotation:
# path to png file, relative w.r.t. dataset_root, storing `depth / scale_adjustment`
path: str
# a factor to convert png values to actual depth: `depth = png * scale_adjustment`
scale_adjustment: float
# path to png file, relative w.r.t. dataset_root, storing binary `depth` mask
mask_path: Optional[str]
@dataclass
class MaskAnnotation:
# path to png file storing (Prob(fg | pixel) * 255)
path: str
# (soft) number of pixels in the mask; sum(Prob(fg | pixel))
mass: Optional[float] = None
@dataclass
class ViewpointAnnotation:
# In right-multiply (PyTorch3D) format. X_cam = X_world @ R + T
R: Tuple[TF3, TF3, TF3]
T: TF3
focal_length: Tuple[float, float]
principal_point: Tuple[float, float]
intrinsics_format: str = "ndc_norm_image_bounds"
# Defines the co-ordinate system where focal_length and principal_point live.
# Possible values: ndc_isotropic | ndc_norm_image_bounds (default)
# ndc_norm_image_bounds: legacy PyTorch3D NDC format, where image boundaries
# correspond to [-1, 1] x [-1, 1], and the scale along x and y may differ
# ndc_isotropic: PyTorch3D 0.5+ NDC convention where the shorter side has
# the range [-1, 1], and the longer one has the range [-s, s]; s >= 1,
# where s is the aspect ratio. The scale is the same along x and y.
@dataclass
class FrameAnnotation:
"""A dataclass used to load annotations from json."""
# can be used to join with `SequenceAnnotation`
sequence_name: str
# 0-based, continuous frame number within sequence
frame_number: int
# timestamp in seconds from the video start
frame_timestamp: float
image: ImageAnnotation
depth: Optional[DepthAnnotation] = None
mask: Optional[MaskAnnotation] = None
viewpoint: Optional[ViewpointAnnotation] = None
@dataclass
class PointCloudAnnotation:
# path to ply file with points only, relative w.r.t. dataset_root
path: str
# the bigger the better
quality_score: float
n_points: Optional[int]
@dataclass
class VideoAnnotation:
# path to the original video file, relative w.r.t. dataset_root
path: str
# length of the video in seconds
length: float
@dataclass
class SequenceAnnotation:
sequence_name: str
category: str
video: Optional[VideoAnnotation] = None
point_cloud: Optional[PointCloudAnnotation] = None
# the bigger the better
viewpoint_quality_score: Optional[float] = None
def dump_dataclass(obj: Any, f: IO, binary: bool = False) -> None:
"""
Args:
f: A file object opened for writing.
obj: A @dataclass or collection hierarchy including dataclasses.
binary: Set to True if `f` is opened in binary mode, else False.
"""
if binary:
f.write(json.dumps(_asdict_rec(obj)).encode("utf8"))
else:
json.dump(_asdict_rec(obj), f)
def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
"""
Loads to a @dataclass or collection hierarchy including dataclasses
from a json recursively.
Call it like load_dataclass(f, typing.List[FrameAnnotation]).
raises KeyError if json has keys not mapping to the dataclass fields.
Args:
f: A file object opened for reading.
cls: The class of the loaded dataclass.
binary: Set to True if `f` is opened in binary mode, else False.
"""
if binary:
asdict = json.loads(f.read().decode("utf8"))
else:
asdict = json.load(f)
if isinstance(asdict, list):
# in the list case, run a faster "vectorized" version
cls = get_args(cls)[0]
res = list(_dataclass_list_from_dict_list(asdict, cls))
else:
res = _dataclass_from_dict(asdict, cls)
return res
def _dataclass_list_from_dict_list(dlist, typeannot):
"""
Vectorised version of `_dataclass_from_dict`.
The output should be equivalent to
`[_dataclass_from_dict(d, typeannot) for d in dlist]`.
Args:
dlist: list of objects to convert.
typeannot: type of each of those objects.
Returns:
iterator or list over converted objects of the same length as `dlist`.
Raises:
ValueError: it assumes the objects have None's in consistent places across
objects, otherwise it would ignore some values. This generally holds for
auto-generated annotations, but otherwise use `_dataclass_from_dict`.
"""
cls = get_origin(typeannot) or typeannot
if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
return dlist
elif any(obj is None for obj in dlist):
# filter out Nones and recurse on the resulting list
idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
idx, notnone = zip(*idx_notnone)
converted = _dataclass_list_from_dict_list(notnone, typeannot)
res = [None] * len(dlist)
for i, obj in zip(idx, converted):
res[i] = obj
return res
# otherwise, we dispatch by the type of the provided annotation to convert to
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
# For namedtuple, call the function recursively on the lists of corresponding keys
types = cls._field_types.values()
dlist_T = zip(*dlist)
res_T = [
_dataclass_list_from_dict_list(key_list, tp)
for key_list, tp in zip(dlist_T, types)
]
return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
elif issubclass(cls, (list, tuple)):
# For list/tuple, call the function recursively on the lists of corresponding positions
types = get_args(typeannot)
if len(types) == 1: # probably List; replicate for all items
types = types * len(dlist[0])
dlist_T = zip(*dlist)
res_T = (
_dataclass_list_from_dict_list(pos_list, tp)
for pos_list, tp in zip(dlist_T, types)
)
if issubclass(cls, tuple):
return list(zip(*res_T))
else:
return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
elif issubclass(cls, dict):
# For the dictionary, call the function recursively on concatenated keys and vertices
key_t, val_t = get_args(typeannot)
all_keys_res = _dataclass_list_from_dict_list(
[k for obj in dlist for k in obj.keys()], key_t
)
all_vals_res = _dataclass_list_from_dict_list(
[k for obj in dlist for k in obj.values()], val_t
)
indices = np.cumsum([len(obj) for obj in dlist])
assert indices[-1] == len(all_keys_res)
keys = np.split(list(all_keys_res), indices[:-1])
vals = np.split(list(all_vals_res), indices[:-1])
return [cls(zip(*k, v)) for k, v in zip(keys, vals)]
elif not dataclasses.is_dataclass(typeannot):
return dlist
# dataclass node: 2nd recursion base; call the function recursively on the lists
# of the corresponding fields
assert dataclasses.is_dataclass(cls)
fieldtypes = {
f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
for f in dataclasses.fields(typeannot)
}
# NOTE the default object is shared here
key_lists = (
_dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
for k, (type_, default) in fieldtypes.items()
)
transposed = zip(*key_lists)
return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
def _dataclass_from_dict(d, typeannot):
cls = get_origin(typeannot) or typeannot
if d is None:
return d
elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
types = cls._field_types.values()
return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)])
elif issubclass(cls, (list, tuple)):
types = get_args(typeannot)
if len(types) == 1: # probably List; replicate for all items
types = types * len(d)
return cls(_dataclass_from_dict(v, tp) for v, tp in zip(d, types))
elif issubclass(cls, dict):
key_t, val_t = get_args(typeannot)
return cls(
(_dataclass_from_dict(k, key_t), _dataclass_from_dict(v, val_t))
for k, v in d.items()
)
elif not dataclasses.is_dataclass(typeannot):
return d
assert dataclasses.is_dataclass(cls)
fieldtypes = {f.name: _unwrap_type(f.type) for f in dataclasses.fields(typeannot)}
return cls(**{k: _dataclass_from_dict(v, fieldtypes[k]) for k, v in d.items()})
def _unwrap_type(tp):
# strips Optional wrapper, if any
if get_origin(tp) is Union:
args = get_args(tp)
if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721
# this is typing.Optional
return args[0] if args[1] is type(None) else args[1] # noqa: E721
return tp
def _get_dataclass_field_default(field: Field) -> Any:
if field.default_factory is not MISSING:
return field.default_factory()
elif field.default is not MISSING:
return field.default
else:
return None
def _asdict_rec(obj):
return dataclasses._asdict_inner(obj, dict)
def dump_dataclass_jgzip(outfile: str, obj: Any) -> None:
"""
Dumps obj to a gzipped json outfile.
Args:
obj: A @dataclass or collection hierarchy including dataclasses.
outfile: The path to the output file.
"""
with gzip.GzipFile(outfile, "wb") as f:
dump_dataclass(obj, cast(IO, f), binary=True)
def load_dataclass_jgzip(outfile, cls):
"""
Loads a dataclass from a gzipped json outfile.
Args:
outfile: The path to the loaded file.
cls: The type annotation of the loaded dataclass.
Returns:
loaded_dataclass: The loaded dataclass.
"""
with gzip.GzipFile(outfile, "rb") as f:
return load_dataclass(cast(IO, f), cls, binary=True)
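A minimal round-trip sketch of the two gzipped-json helpers above. The `Record` dataclass and the output path are hypothetical, and the import path is an assumption (the helpers live in the dataset types module shown above); treat this as illustrative rather than a tested recipe.

```python
import dataclasses
from typing import List

# Import path is an assumption; the helpers are defined in the file above.
from pytorch3d.implicitron.dataset.types import (
    dump_dataclass_jgzip,
    load_dataclass_jgzip,
)

# Hypothetical dataclass purely for illustration.
@dataclasses.dataclass
class Record:
    name: str
    scores: List[float]

obj = Record(name="seq_01", scores=[0.1, 0.2])
dump_dataclass_jgzip("/tmp/record.jgz", obj)              # write gzipped json
loaded = load_dataclass_jgzip("/tmp/record.jgz", Record)  # parse it back
print(loaded == obj)  # expected: True
```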

View File

@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional
import torch
DATASET_TYPE_TRAIN = "train"
DATASET_TYPE_TEST = "test"
DATASET_TYPE_KNOWN = "known"
DATASET_TYPE_UNKNOWN = "unseen"
def is_known_frame(
frame_type: List[str], device: Optional[str] = None
) -> torch.BoolTensor:
"""
Given a list `frame_type` of frame types in a batch, return a tensor
of boolean flags expressing whether the corresponding frame is a known frame.
"""
return torch.tensor(
[ft.endswith(DATASET_TYPE_KNOWN) for ft in frame_type],
dtype=torch.bool,
device=device,
)
def is_train_frame(
frame_type: List[str], device: Optional[str] = None
) -> torch.BoolTensor:
"""
Given a list `frame_type` of frame types in a batch, return a tensor
of boolean flags expressing whether the corresponding frame is a training frame.
"""
return torch.tensor(
[ft.startswith(DATASET_TYPE_TRAIN) for ft in frame_type],
dtype=torch.bool,
device=device,
)
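These two helpers are typically used to split a collated batch into source ("known") and target ("unseen") frames. A minimal sketch, using frame-type strings of the form the dataset produces (e.g. "train_known", "test_unseen"):

```python
import torch

from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame

frame_type = ["test_unseen", "test_known", "test_known"]
known = is_known_frame(frame_type)  # tensor([False, True, True])
train = is_train_frame(frame_type)  # tensor([False, False, False])

# Indices of target (unseen) and source (known) frames within the batch.
target_idx = torch.where(~known)[0]
source_idx = torch.where(known)[0]
```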

View File

@ -0,0 +1,95 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple, cast
import torch
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.structures import Pointclouds
from .implicitron_dataset import FrameData, ImplicitronDataset
def get_implicitron_sequence_pointcloud(
dataset: ImplicitronDataset,
sequence_name: Optional[str] = None,
mask_points: bool = True,
max_frames: int = -1,
num_workers: int = 0,
load_dataset_point_cloud: bool = False,
) -> Tuple[Pointclouds, FrameData]:
"""
Make a point cloud by sampling random points from each frame of the dataset.
"""
if len(dataset) == 0:
raise ValueError("The dataset is empty.")
if not dataset.load_depths:
raise ValueError("The dataset has to load depths (dataset.load_depths=True).")
if mask_points and not dataset.load_masks:
raise ValueError(
"For mask_points=True, the dataset has to load masks"
+ " (dataset.load_masks=True)."
)
# setup the indices of frames loaded from the dataset db
sequence_entries = list(range(len(dataset)))
if sequence_name is not None:
sequence_entries = [
ei
for ei in sequence_entries
if dataset.frame_annots[ei]["frame_annotation"].sequence_name
== sequence_name
]
if len(sequence_entries) == 0:
raise ValueError(
f'There are no dataset entries for sequence name "{sequence_name}".'
)
# subsample loaded frames if needed
if (max_frames > 0) and (len(sequence_entries) > max_frames):
sequence_entries = [
sequence_entries[i]
for i in torch.randperm(len(sequence_entries))[:max_frames].sort().values
]
# take only the part of the dataset corresponding to the sequence entries
sequence_dataset = torch.utils.data.Subset(dataset, sequence_entries)
# load the required part of the dataset
loader = torch.utils.data.DataLoader(
sequence_dataset,
batch_size=len(sequence_dataset),
shuffle=False,
num_workers=num_workers,
collate_fn=FrameData.collate,
)
frame_data = next(iter(loader)) # there's only one batch
# scene point cloud
if load_dataset_point_cloud:
if not dataset.load_point_clouds:
raise ValueError(
"For load_dataset_point_cloud=True, the dataset has to"
+ " load point clouds (dataset.load_point_clouds=True)."
)
point_cloud = frame_data.sequence_point_cloud
else:
point_cloud = get_rgbd_point_cloud(
frame_data.camera,
frame_data.image_rgb,
frame_data.depth_map,
(cast(torch.Tensor, frame_data.fg_probability) > 0.5).float()
if frame_data.fg_probability is not None
else None,
mask_points=mask_points,
)
return point_cloud, frame_data
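A hedged usage sketch: `dataset` stands for an `ImplicitronDataset` constructed elsewhere (e.g. via `dataset_zoo`) with `load_depths=True` and `load_masks=True`, and the sequence name is a placeholder, so this is illustrative rather than runnable as-is.

```python
# `dataset` is assumed to be an ImplicitronDataset with depths and masks loaded.
point_cloud, frame_data = get_implicitron_sequence_pointcloud(
    dataset,
    sequence_name="some_sequence",  # placeholder sequence name
    mask_points=True,
    max_frames=50,
    num_workers=4,
)
print(int(point_cloud.num_points_per_cloud()[0]))  # number of fused points
```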

View File

@ -0,0 +1,216 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import dataclasses
import os
from typing import Optional, cast
import lpips
import torch
from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo
from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
ImplicitronDatasetBase,
)
from pytorch3d.implicitron.dataset.utils import is_known_frame
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
aggregate_nvs_results,
eval_batch,
pretty_print_nvs_metrics,
summarize_nvs_eval_results,
)
from pytorch3d.implicitron.models.model_dbir import ModelDBIR
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
from tqdm import tqdm
def main() -> None:
"""
Evaluates new view synthesis metrics of a simple depth-based image rendering
(DBIR) model for multisequence/singlesequence tasks for several categories.
The evaluation is conducted on the same data as in [1] and, hence, the results
are directly comparable to the numbers reported in [1].
References:
[1] J. Reizenstein, R. Shapovalov, P. Henzler, L. Sbordone,
P. Labatut, D. Novotny:
Common Objects in 3D: Large-Scale Learning
and Evaluation of Real-life 3D Category Reconstruction
"""
task_results = {}
for task in ("singlesequence", "multisequence"):
task_results[task] = []
for category in CO3D_CATEGORIES[: (20 if task == "singlesequence" else 10)]:
for single_sequence_id in (0, 1) if task == "singlesequence" else (None,):
category_result = evaluate_dbir_for_category(
category, task=task, single_sequence_id=single_sequence_id
)
print("")
print(
f"Results for task={task}; category={category};"
+ (
f" sequence={single_sequence_id}:"
if single_sequence_id is not None
else ":"
)
)
pretty_print_nvs_metrics(category_result)
print("")
task_results[task].append(category_result)
_print_aggregate_results(task, task_results)
for task in task_results:
_print_aggregate_results(task, task_results)
def evaluate_dbir_for_category(
category: str = "apple",
bg_color: float = 0.0,
task: str = "singlesequence",
single_sequence_id: Optional[int] = None,
num_workers: int = 16,
):
"""
Evaluates new view synthesis metrics of a simple depth-based image rendering
(DBIR) model for a given task, category, and sequence (in case task=='singlesequence').
Args:
category: Object category.
bg_color: Background color of the renders.
task: Evaluation task. Either singlesequence or multisequence.
single_sequence_id: The ID of the evaluation sequence for the singlesequence task.
num_workers: The number of workers for the employed dataloaders.
Returns:
category_result: A dictionary of quantitative metrics.
"""
single_sequence_id = single_sequence_id if single_sequence_id is not None else -1
torch.manual_seed(42)
if task not in ["multisequence", "singlesequence"]:
raise ValueError("'task' has to be either 'multisequence' or 'singlesequence'")
datasets = dataset_zoo(
category=category,
dataset_root=os.environ["CO3D_DATASET_ROOT"],
assert_single_seq=task == "singlesequence",
dataset_name=f"co3d_{task}",
test_on_train=False,
load_point_clouds=True,
test_restrict_sequence_id=single_sequence_id,
)
dataloaders = dataloader_zoo(
datasets,
dataset_name=f"co3d_{task}",
)
test_dataset = datasets["test"]
test_dataloader = dataloaders["test"]
if task == "singlesequence":
# all_source_cameras are needed for evaluation of the
# target camera difficulty
# pyre-fixme[16]: `ImplicitronDataset` has no attribute `frame_annots`.
sequence_name = test_dataset.frame_annots[0]["frame_annotation"].sequence_name
all_source_cameras = _get_all_source_cameras(
test_dataset, sequence_name, num_workers=num_workers
)
else:
all_source_cameras = None
image_size = cast(ImplicitronDataset, test_dataset).image_width
if image_size is None:
raise ValueError("Image size should be set in the dataset")
# init the simple DBIR model
model = ModelDBIR(
image_size=image_size,
bg_color=bg_color,
max_points=int(1e5),
)
model.cuda()
# init the lpips model for eval
lpips_model = lpips.LPIPS(net="vgg")
lpips_model = lpips_model.cuda()
per_batch_eval_results = []
print("Evaluating DBIR model ...")
for frame_data in tqdm(test_dataloader):
frame_data = dataclass_to_cuda_(frame_data)
preds = model(**dataclasses.asdict(frame_data))
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
per_batch_eval_results.append(
eval_batch(
frame_data,
nvs_prediction,
bg_color=bg_color,
lpips_model=lpips_model,
source_cameras=all_source_cameras,
)
)
category_result_flat, category_result = summarize_nvs_eval_results(
per_batch_eval_results, task
)
return category_result["results"]
def _print_aggregate_results(task, task_results) -> None:
"""
Prints the aggregate metrics for a given task.
"""
aggregate_task_result = aggregate_nvs_results(task_results[task])
print("")
print(f"Aggregate results for task={task}:")
pretty_print_nvs_metrics(aggregate_task_result)
print("")
def _get_all_source_cameras(
dataset: ImplicitronDatasetBase, sequence_name: str, num_workers: int = 8
):
"""
Loads all training cameras of a given sequence.
The set of all seen cameras is needed for evaluating the viewpoint difficulty
for the singlescene evaluation.
Args:
dataset: Co3D dataset object.
sequence_name: The name of the sequence.
num_workers: The number of workers for the utilized dataloader.
"""
# load all source cameras of the sequence
seq_idx = dataset.seq_to_idx[sequence_name]
dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx)
(all_frame_data,) = torch.utils.data.DataLoader(
dataset_for_loader,
shuffle=False,
batch_size=len(dataset_for_loader),
num_workers=num_workers,
collate_fn=FrameData.collate,
)
is_known = is_known_frame(all_frame_data.frame_type)
source_cameras = all_frame_data.camera[torch.where(is_known)[0]]
return source_cameras
if __name__ == "__main__":
main()
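The evaluation can also be driven for a single category directly from Python; the sketch below reuses `evaluate_dbir_for_category` and `pretty_print_nvs_metrics` from the file above and assumes a GPU plus a local CO3D download (the dataset path is a placeholder).

```python
import os

os.environ["CO3D_DATASET_ROOT"] = "/path/to/co3d"  # placeholder path

# evaluate_dbir_for_category / pretty_print_nvs_metrics are defined or imported above.
category_result = evaluate_dbir_for_category(
    category="apple",
    task="singlesequence",
    single_sequence_id=0,
)
pretty_print_nvs_metrics(category_result)
```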

View File

@ -0,0 +1,649 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData
from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame
from pytorch3d.implicitron.tools import vis_utils
from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps
from pytorch3d.implicitron.tools.image_utils import mask_background
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth, iou, rgb_l1
from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud
from pytorch3d.implicitron.tools.vis_utils import make_depth_image
from pytorch3d.renderer.camera_utils import join_cameras_as_batch
from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras
from pytorch3d.vis.plotly_vis import plot_scene
from tabulate import tabulate
from visdom import Visdom
EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9]
@dataclass
class NewViewSynthesisPrediction:
"""
Holds the tensors that describe a result of synthesizing new views.
"""
depth_render: Optional[torch.Tensor] = None
image_render: Optional[torch.Tensor] = None
mask_render: Optional[torch.Tensor] = None
camera_distance: Optional[torch.Tensor] = None
@dataclass
class _Visualizer:
image_render: torch.Tensor
image_rgb_masked: torch.Tensor
depth_render: torch.Tensor
depth_map: torch.Tensor
depth_mask: torch.Tensor
visdom_env: str = "eval_debug"
_viz: Visdom = field(init=False)
def __post_init__(self):
self._viz = vis_utils.get_visdom_connection()
def show_rgb(
self, loss_value: float, metric_name: str, loss_mask_now: torch.Tensor
):
self._viz.images(
torch.cat(
(
self.image_render,
self.image_rgb_masked,
loss_mask_now.repeat(1, 3, 1, 1),
),
dim=3,
),
env=self.visdom_env,
win=metric_name,
opts={"title": f"{metric_name}_{loss_value:1.2f}"},
)
def show_depth(
self, depth_loss: float, name_postfix: str, loss_mask_now: torch.Tensor
):
self._viz.images(
torch.cat(
(
make_depth_image(self.depth_render, loss_mask_now),
make_depth_image(self.depth_map, loss_mask_now),
),
dim=3,
),
env=self.visdom_env,
win="depth_abs" + name_postfix,
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}"},
)
self._viz.images(
loss_mask_now,
env=self.visdom_env,
win="depth_abs" + name_postfix + "_mask",
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"},
)
self._viz.images(
self.depth_mask,
env=self.visdom_env,
win="depth_abs" + name_postfix + "_maskd",
opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"},
)
# show the 3D plot
# pyre-fixme[9]: viewpoint_trivial has type `PerspectiveCameras`; used as
# `TensorProperties`.
viewpoint_trivial: PerspectiveCameras = PerspectiveCameras().to(
loss_mask_now.device
)
pcl_pred = get_rgbd_point_cloud(
viewpoint_trivial,
self.image_render,
self.depth_render,
# mask_crop,
torch.ones_like(self.depth_render),
# loss_mask_now,
)
pcl_gt = get_rgbd_point_cloud(
viewpoint_trivial,
self.image_rgb_masked,
self.depth_map,
# mask_crop,
torch.ones_like(self.depth_map),
# loss_mask_now,
)
_pcls = {
pn: p
for pn, p in zip(("pred_depth", "gt_depth"), (pcl_pred, pcl_gt))
if int(p.num_points_per_cloud()) > 0
}
plotlyplot = plot_scene(
{f"pcl{name_postfix}": _pcls},
camera_scale=1.0,
pointcloud_max_points=10000,
pointcloud_marker_size=1,
)
self._viz.plotlyplot(
plotlyplot,
env=self.visdom_env,
win=f"pcl{name_postfix}",
)
def eval_batch(
frame_data: FrameData,
nvs_prediction: NewViewSynthesisPrediction,
bg_color: Union[torch.Tensor, str, float] = "black",
mask_thr: float = 0.5,
lpips_model=None,
visualize: bool = False,
visualize_visdom_env: str = "eval_debug",
break_after_visualising: bool = True,
source_cameras: Optional[List[CamerasBase]] = None,
) -> Dict[str, Any]:
"""
Produce performance metrics for a single batch of new-view synthesis
predictions.
Given a set of known views (for which frame_data.frame_type.endswith('known')
is True), a new-view synthesis method (NVS) is tasked to generate new views
of the scene from the viewpoint of the target views (for which
frame_data.frame_type.endswith('known') is False). The resulting
synthesized new views, stored in `nvs_prediction`, are compared to the
target ground truth in `frame_data` in terms of geometry and appearance
resulting in a dictionary of metrics returned by the `eval_batch` function.
Args:
frame_data: A FrameData object containing the input to the new view
synthesis method.
nvs_prediction: The data describing the synthesized new views.
bg_color: The background color of the generated new views and the
ground truth.
lpips_model: A pre-trained model for evaluating the LPIPS metric.
visualize: If True, visualizes the results to Visdom.
source_cameras: A list of all training cameras for evaluating the
difficulty of the target views.
Returns:
results: A dictionary holding evaluation metrics.
Throws:
ValueError if frame_data does not have frame_type, camera, or image_rgb
ValueError if the batch has a mix of training and test samples
ValueError if the batch frames are not [unseen, known, known, ...]
ValueError if one of the required fields in nvs_prediction is missing
"""
REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render", "image_render", "depth_render"]
frame_type = frame_data.frame_type
if frame_type is None:
raise ValueError("Frame type has not been set.")
# we check that all those fields are not None but Pyre can't infer that properly
# TODO: assign to local variables
if frame_data.image_rgb is None:
raise ValueError("Image is not in the evaluation batch.")
if frame_data.camera is None:
raise ValueError("Camera is not in the evaluation batch.")
if any(not hasattr(nvs_prediction, k) for k in REQUIRED_NVS_PREDICTION_FIELDS):
raise ValueError("One of the required predicted fields is missing")
# obtain copies to make sure we don't edit the original data
nvs_prediction = copy.deepcopy(nvs_prediction)
frame_data = copy.deepcopy(frame_data)
# mask the ground truth depth in case frame_data contains the depth mask
if frame_data.depth_map is not None and frame_data.depth_mask is not None:
frame_data.depth_map *= frame_data.depth_mask
if not isinstance(frame_type, list): # not batch FrameData
frame_type = [frame_type]
is_train = is_train_frame(frame_type)
if not (is_train[0] == is_train).all():
raise ValueError("All frames in the eval batch have to be either train/test.")
# pyre-fixme[16]: `Optional` has no attribute `device`.
is_known = is_known_frame(frame_type, device=frame_data.image_rgb.device)
if not ((is_known[1:] == 1).all() and (is_known[0] == 0).all()):
raise ValueError(
"For evaluation the first element of the batch has to be"
+ " a target view while the rest should be source views."
) # TODO: do we need to enforce this?
# take only the first (target image)
for k in REQUIRED_NVS_PREDICTION_FIELDS:
setattr(nvs_prediction, k, getattr(nvs_prediction, k)[:1])
for k in [
"depth_map",
"image_rgb",
"fg_probability",
"mask_crop",
]:
if not hasattr(frame_data, k) or getattr(frame_data, k) is None:
continue
setattr(frame_data, k, getattr(frame_data, k)[:1])
if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0:
warnings.warn("Empty or missing depth map in evaluation!")
# eval all results in the resolution of the frame_data image
# pyre-fixme[16]: `Optional` has no attribute `shape`.
image_resol = list(frame_data.image_rgb.shape[2:])
# threshold the masks to make ground truth binary masks
mask_fg, mask_crop = [
(getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop")
]
image_rgb_masked = mask_background(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
frame_data.image_rgb,
mask_fg,
bg_color=bg_color,
)
# resize to the target resolution
for k in REQUIRED_NVS_PREDICTION_FIELDS:
imode = "bilinear" if k == "image_render" else "nearest"
val = getattr(nvs_prediction, k)
setattr(
nvs_prediction,
k,
# pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
# `List[typing.Any]`.
torch.nn.functional.interpolate(val, size=image_resol, mode=imode),
)
# clamp predicted images
# pyre-fixme[16]: `Optional` has no attribute `clamp`.
image_render = nvs_prediction.image_render.clamp(0.0, 1.0)
if visualize:
visualizer = _Visualizer(
image_render=image_render,
image_rgb_masked=image_rgb_masked,
# pyre-fixme[6]: Expected `Tensor` for 3rd param but got
# `Optional[torch.Tensor]`.
depth_render=nvs_prediction.depth_render,
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
# `Optional[torch.Tensor]`.
depth_map=frame_data.depth_map,
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
depth_mask=frame_data.depth_mask[:1],
visdom_env=visualize_visdom_env,
)
results: Dict[str, Any] = {}
results["iou"] = iou(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
nvs_prediction.mask_render,
mask_fg,
mask=mask_crop,
)
for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("", "_fg")):
loss_mask_now = mask_crop * loss_fg_mask
for rgb_metric_name, rgb_metric_fun in zip(
("psnr", "rgb_l1"), (calc_psnr, rgb_l1)
):
metric_name = rgb_metric_name + name_postfix
results[metric_name] = rgb_metric_fun(
image_render,
image_rgb_masked,
mask=loss_mask_now,
)
if visualize:
visualizer.show_rgb(
results[metric_name].item(), metric_name, loss_mask_now
)
if name_postfix == "_fg":
# only record depth metrics for the foreground
_, abs_ = eval_depth(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
nvs_prediction.depth_render,
# pyre-fixme[6]: Expected `Tensor` for 2nd param but got
# `Optional[torch.Tensor]`.
frame_data.depth_map,
get_best_scale=True,
mask=loss_mask_now,
crop=5,
)
results["depth_abs" + name_postfix] = abs_.mean()
if visualize:
visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now)
if break_after_visualising:
import pdb
pdb.set_trace()
if lpips_model is not None:
im1, im2 = [
2.0 * im.clamp(0.0, 1.0) - 1.0
for im in (image_rgb_masked, nvs_prediction.image_render)
]
results["lpips"] = lpips_model.forward(im1, im2).item()
# convert all metrics to floats
results = {k: float(v) for k, v in results.items()}
if source_cameras is None:
# pyre-fixme[16]: Optional has no attribute __getitem__
source_cameras = frame_data.camera[torch.where(is_known)[0]]
results["meta"] = {
# calculate the camera difficulties and add to results
"camera_difficulty": calculate_camera_difficulties(
frame_data.camera[0],
source_cameras,
)[0].item(),
# store the size of the batch (corresponds to n_src_views+1)
"batch_size": int(is_known.numel()),
# store the type of the target frame
# pyre-fixme[16]: `None` has no attribute `__getitem__`.
"frame_type": str(frame_data.frame_type[0]),
}
return results
def average_per_batch_results(
results_per_batch: List[Dict[str, Any]],
idx: Optional[torch.Tensor] = None,
) -> dict:
"""
Average a list of per-batch metrics `results_per_batch`.
Optionally, if `idx` is given, only a subset of the per-batch
metrics, indexed by `idx`, is averaged.
"""
result_keys = list(results_per_batch[0].keys())
result_keys.remove("meta")
if idx is not None:
results_per_batch = [results_per_batch[i] for i in idx]
if len(results_per_batch) == 0:
return {k: float("NaN") for k in result_keys}
return {
k: float(np.array([r[k] for r in results_per_batch]).mean())
for k in result_keys
}
def calculate_camera_difficulties(
cameras_target: CamerasBase,
cameras_source: CamerasBase,
) -> torch.Tensor:
"""
Calculate the difficulties of the target cameras, given a set of known
cameras `cameras_source`.
Returns:
a tensor of shape (len(cameras_target),)
"""
ious = [
volumetric_camera_overlaps(
join_cameras_as_batch(
# pyre-fixme[6]: Expected `CamerasBase` for 1st param but got
# `Optional[pytorch3d.renderer.utils.TensorProperties]`.
[cameras_target[cami], cameras_source.to(cameras_target.device)]
)
)[0, :]
for cami in range(cameras_target.R.shape[0])
]
camera_difficulties = torch.stack(
[_reduce_camera_iou_overlap(iou[1:]) for iou in ious]
)
return camera_difficulties
def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tensor:
"""
Calculate the final camera difficulty by computing the average of the
ious of the two most similar cameras.
Returns:
single-element Tensor
"""
# pyre-ignore[16] topk not recognized
return ious.topk(k=min(topk, len(ious) - 1)).values.mean()
def get_camera_difficulty_bin_edges(task: str):
"""
Get the edges of camera difficulty bins.
"""
_eps = 1e-5
if task == "multisequence":
# TODO: extract those to constants
diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4)
diff_bin_edges[0] = 0.0 - _eps
elif task == "singlesequence":
diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float()
else:
raise ValueError(f"No such eval task {task}.")
diff_bin_names = ["hard", "medium", "easy"]
return diff_bin_edges, diff_bin_names
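# Editor's note (illustrative): with the edges above, "singlesequence" buckets the
# camera-overlap difficulty score into
#   hard: (0, 0.97], medium: (0.97, 0.98], easy: (0.98, 1.0],
# while "multisequence" uses approximately
#   hard: (0, 0.67], medium: (0.67, 0.83], easy: (0.83, 1.0].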
def summarize_nvs_eval_results(
per_batch_eval_results: List[Dict[str, Any]],
task: str = "singlesequence",
):
"""
Compile the per-batch evaluation results `per_batch_eval_results` into
a set of aggregate metrics. The produced metrics depend on the task.
Args:
per_batch_eval_results: Metrics of each per-batch evaluation.
task: The type of the new-view synthesis task.
Either 'singlesequence' or 'multisequence'.
Returns:
nvs_results_flat: A flattened dict of all aggregate metrics.
aux_out: A dictionary holding a set of auxiliary results.
"""
n_batches = len(per_batch_eval_results)
eval_sets: List[Optional[str]] = []
if task == "singlesequence":
eval_sets = [None]
# assert n_batches==100
elif task == "multisequence":
eval_sets = ["train", "test"]
# assert n_batches==1000
else:
raise ValueError(task)
batch_sizes = torch.tensor(
[r["meta"]["batch_size"] for r in per_batch_eval_results]
).long()
camera_difficulty = torch.tensor(
[r["meta"]["camera_difficulty"] for r in per_batch_eval_results]
).float()
is_train = is_train_frame([r["meta"]["frame_type"] for r in per_batch_eval_results])
# init the result database dict
results = []
diff_bin_edges, diff_bin_names = get_camera_difficulty_bin_edges(task)
n_diff_edges = diff_bin_edges.numel()
# add per set averages
for SET in eval_sets:
if SET is None:
# task=='singlesequence'
ok_set = torch.ones(n_batches, dtype=torch.bool)
set_name = "test"
else:
# task=='multisequence'
ok_set = is_train == int(SET == "train")
set_name = SET
# eval each difficulty bin, including a full average result (diff_bin=None)
for diff_bin in [None, *list(range(n_diff_edges - 1))]:
if diff_bin is None:
# average over all results
in_bin = ok_set
diff_bin_name = "all"
else:
b1, b2 = diff_bin_edges[diff_bin : (diff_bin + 2)]
in_bin = ok_set & (camera_difficulty > b1) & (camera_difficulty <= b2)
diff_bin_name = diff_bin_names[diff_bin]
bin_results = average_per_batch_results(
per_batch_eval_results, idx=torch.where(in_bin)[0]
)
results.append(
{
"subset": set_name,
"subsubset": f"diff={diff_bin_name}",
"metrics": bin_results,
}
)
if task == "multisequence":
# split based on n_src_views
n_src_views = batch_sizes - 1
for n_src in EVAL_N_SRC_VIEWS:
ok_src = ok_set & (n_src_views == n_src)
n_src_results = average_per_batch_results(
per_batch_eval_results,
idx=torch.where(ok_src)[0],
)
results.append(
{
"subset": set_name,
"subsubset": f"n_src={int(n_src)}",
"metrics": n_src_results,
}
)
aux_out = {"results": results}
return flatten_nvs_results(results), aux_out
def _get_flat_nvs_metric_key(result, metric_name) -> str:
metric_key_postfix = f"|subset={result['subset']}|{result['subsubset']}"
metric_key = f"{metric_name}{metric_key_postfix}"
return metric_key
def flatten_nvs_results(results):
"""
Takes input `results` list of dicts of the form:
```
[
{
'subset':'train/test/...',
'subsubset': 'src=1/src=2/...',
'metrics': nvs_eval_metrics}
},
...
]
```
And converts to a flat dict as follows:
{
'<metric_name>|subset=train/test/...|src=1/src=2/...': <metric_value>,
...
}
"""
results_flat = {}
for result in results:
for metric_name, metric_val in result["metrics"].items():
metric_key = _get_flat_nvs_metric_key(result, metric_name)
assert metric_key not in results_flat
results_flat[metric_key] = metric_val
return results_flat
def pretty_print_nvs_metrics(results) -> None:
subsets, subsubsets = [
_ordered_set([r[k] for r in results]) for k in ("subset", "subsubset")
]
metrics = _ordered_set([metric for r in results for metric in r["metrics"]])
for subset in subsets:
tab = {}
for metric in metrics:
tab[metric] = []
header = ["metric"]
for subsubset in subsubsets:
metric_vals = [
r["metrics"][metric]
for r in results
if r["subsubset"] == subsubset and r["subset"] == subset
]
if len(metric_vals) > 0:
tab[metric].extend(metric_vals)
header.extend(subsubsets)
if any(len(v) > 0 for v in tab.values()):
print(f"===== NVS results; subset={subset} =====")
print(
tabulate(
[[metric, *v] for metric, v in tab.items()],
# pyre-fixme[61]: `header` is undefined, or not always defined.
headers=header,
)
)
def _ordered_set(list_):
return list(OrderedDict((i, 0) for i in list_).keys())
def aggregate_nvs_results(task_results):
"""
Aggregate nvs results.
For singlescene, this averages over all categories and scenes,
for multiscene, the average is over all per-category results.
"""
task_results_cat = [r_ for r in task_results for r_ in r]
subsets, subsubsets = [
_ordered_set([r[k] for r in task_results_cat]) for k in ("subset", "subsubset")
]
metrics = _ordered_set(
[metric for r in task_results_cat for metric in r["metrics"]]
)
average_results = []
for subset in subsets:
for subsubset in subsubsets:
metrics_lists = [
r["metrics"]
for r in task_results_cat
if r["subsubset"] == subsubset and r["subset"] == subset
]
avg_metrics = {}
for metric in metrics:
avg_metrics[metric] = float(
np.nanmean(
np.array([metric_list[metric] for metric_list in metrics_lists])
)
)
average_results.append(
{
"subset": subset,
"subsubset": subsubset,
"metrics": avg_metrics,
}
)
return average_results
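To make the result layout concrete, here is a small hand-made `results` list in the shape produced by `summarize_nvs_eval_results`; the metric values are invented purely to illustrate the flattened key format.

```python
results = [
    {"subset": "test", "subsubset": "diff=all", "metrics": {"psnr": 21.3, "iou": 0.82}},
    {"subset": "test", "subsubset": "diff=hard", "metrics": {"psnr": 18.9, "iou": 0.74}},
]

flat = flatten_nvs_results(results)
# {'psnr|subset=test|diff=all': 21.3, 'iou|subset=test|diff=all': 0.82,
#  'psnr|subset=test|diff=hard': 18.9, 'iou|subset=test|diff=hard': 0.74}

pretty_print_nvs_metrics(results)  # prints one table per subset
```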

View File

@ -0,0 +1,172 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from collections import defaultdict
from typing import Dict, List, Optional, Union
import torch
from pytorch3d.implicitron.tools.config import Configurable
# TODO: probabilistic embeddings?
class Autodecoder(Configurable, torch.nn.Module):
"""
Autodecoder module
Settings:
encoding_dim: Embedding dimension for the decoder.
n_instances: The maximum number of instances stored by the autodecoder.
init_scale: Scale factor for the initial autodecoder weights.
ignore_input: If `True`, optimizes a single code for any input.
"""
encoding_dim: int = 0
n_instances: int = 0
init_scale: float = 1.0
ignore_input: bool = False
def __post_init__(self):
super().__init__()
if self.n_instances <= 0:
# Do not init the codes at all in case we have 0 instances.
return
self._autodecoder_codes = torch.nn.Embedding(
self.n_instances,
self.encoding_dim,
scale_grad_by_freq=True,
)
with torch.no_grad():
# weight has been initialised from Normal(0, 1)
self._autodecoder_codes.weight *= self.init_scale
self._sequence_map = self._build_sequence_map()
# Make sure to register hooks for correct handling of saving/loading
# the module's _sequence_map.
self._register_load_state_dict_pre_hook(self._load_sequence_map_hook)
self._register_state_dict_hook(_save_sequence_map_hook)
def _build_sequence_map(
self, sequence_map_dict: Optional[Dict[str, int]] = None
) -> Dict[str, int]:
"""
Args:
sequence_map_dict: A dictionary used to initialize the sequence_map.
Returns:
sequence_map: a dictionary of key: id pairs.
"""
# increments the counter when asked for a new value
sequence_map = defaultdict(iter(range(self.n_instances)).__next__)
if sequence_map_dict is not None:
# Assign all keys from the loaded sequence_map_dict to self._sequence_map.
# Since this is done in the original order, it should generate
# the same set of key:id pairs. We check this with an assert to be sure.
for x, x_id in sequence_map_dict.items():
x_id_ = sequence_map[x]
assert x_id == x_id_
return sequence_map
def calc_squared_encoding_norm(self):
if self.n_instances <= 0:
return None
return (self._autodecoder_codes.weight ** 2).mean()
def get_encoding_dim(self) -> int:
if self.n_instances <= 0:
return 0
return self.encoding_dim
def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]:
"""
Args:
x: A batch of `N` sequence identifiers. Either a long tensor of size
`(N,)` keys in [0, n_instances), or a list of `N` string keys that
are hashed to codes (without collisions).
Returns:
codes: A tensor of shape `(N, self.encoding_dim)` containing the
sequence-specific autodecoder codes.
"""
if self.n_instances == 0:
return None
if self.ignore_input:
x = ["singleton"]
if isinstance(x[0], str):
try:
x = torch.tensor(
# pyre-ignore[29]
[self._sequence_map[elem] for elem in x],
dtype=torch.long,
device=next(self.parameters()).device,
)
except StopIteration:
raise ValueError("Not enough n_instances in the autodecoder")
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
return self._autodecoder_codes(x)
def _load_sequence_map_hook(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
"""
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
Returns:
Constructed sequence_map if it exists in the state_dict
else raises a warning only.
"""
sequence_map_key = prefix + "_sequence_map"
if sequence_map_key in state_dict:
sequence_map_dict = state_dict.pop(sequence_map_key)
self._sequence_map = self._build_sequence_map(
sequence_map_dict=sequence_map_dict
)
else:
warnings.warn("No sequence map in Autodecoder state dict!")
def _save_sequence_map_hook(
self,
state_dict,
prefix,
local_metadata,
) -> None:
"""
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
"""
sequence_map_key = prefix + "_sequence_map"
sequence_map_dict = dict(self._sequence_map.items())
state_dict[sequence_map_key] = sequence_map_dict
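A hedged construction sketch outside the Hydra config pipeline. It assumes that, as with other `Configurable` classes, `expand_args_fields` has to be applied before direct instantiation, and that the module path matches the relative import used elsewhere in this commit.

```python
from pytorch3d.implicitron.models.autodecoder import Autodecoder  # path assumed
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(Autodecoder)  # assumed prerequisite for direct construction
autodecoder = Autodecoder(n_instances=100, encoding_dim=64)

# String keys are mapped to rows of the embedding table on first use.
codes = autodecoder(["seq_a", "seq_b"])
print(codes.shape)  # expected: torch.Size([2, 64])
```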

View File

@ -0,0 +1,883 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
import warnings
from dataclasses import field
from typing import Any, Dict, List, Optional, Tuple
import torch
import tqdm
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
NewViewSynthesisPrediction,
)
from pytorch3d.implicitron.tools import image_utils, vis_utils
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples
from pytorch3d.implicitron.tools.utils import cat_dataclass
from pytorch3d.renderer import RayBundle, utils as rend_utils
from pytorch3d.renderer.cameras import CamerasBase
from visdom import Visdom
from .autodecoder import Autodecoder
from .implicit_function.base import ImplicitFunctionBase
from .implicit_function.idr_feature_field import IdrFeatureField # noqa
from .implicit_function.neural_radiance_field import ( # noqa
NeRFormerImplicitFunction,
NeuralRadianceFieldImplicitFunction,
)
from .implicit_function.scene_representation_networks import ( # noqa
SRNHyperNetImplicitFunction,
SRNImplicitFunction,
)
from .metrics import ViewMetrics
from .renderer.base import (
BaseRenderer,
EvaluationMode,
ImplicitFunctionWrapper,
RendererOutput,
RenderSamplingMode,
)
from .renderer.lstm_renderer import LSTMRenderer # noqa
from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa
from .renderer.ray_sampler import RaySampler
from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa
from .resnet_feature_extractor import ResNetFeatureExtractor
from .view_pooling.feature_aggregation import FeatureAggregatorBase
from .view_pooling.view_sampling import ViewSampler
STD_LOG_VARS = ["objective", "epoch", "sec/it"]
# pyre-ignore: 13
class GenericModel(Configurable, torch.nn.Module):
"""
GenericModel is a wrapper for the neural implicit
rendering and reconstruction pipeline which consists
of the following sequence of 7 steps (steps 2-4 are normally
skipped in the overfitting scenario, since conditioning on source views
does not add much information; otherwise they should be present altogether):
(1) Ray Sampling
------------------
Rays are sampled from an image grid based on the target view(s).
_____________
(2) Feature Extraction (optional)
-----------------------
A feature extractor (e.g. a convolutional
neural net) is used to extract image features
from the source view(s).
(3) View Sampling (optional)
------------------
Image features are sampled at the 2D projections
of a set of 3D points along each of the sampled
target rays from (1).
(4) Feature Aggregation (optional)
------------------
Aggregate features and masks sampled from
image view(s) in (3).
____________
(5) Implicit Function Evaluation
------------------
Evaluate the implicit function(s) at the sampled ray points
(optionally pass in the aggregated image features from (4)).
(6) Rendering
------------------
Render the image into the target cameras by raymarching along
the sampled rays and aggregating the colors and densities
output by the implicit function in (5).
(7) Loss Computation
------------------
Compute losses based on the predicted target image(s).
The `forward` function of GenericModel executes
this sequence of steps. Currently, steps 1, 3, 4, 5, 6
can be customized by initializing a subclass of the appropriate
baseclass and adding the newly created module to the registry.
Please see https://github.com/fairinternal/pytorch3d/blob/co3d/projects/implicitron_trainer/README.md#custom-plugins
for more details on how to create and register a custom component.
In the config .yaml files for experiments, the parameters below are
contained in the `generic_model_args` node. As GenericModel
derives from Configurable, the input arguments are
parsed by the run_auto_creation function to initialize the
necessary member modules. Please see implicitron_trainer/README.md
for more details on this process.
Args:
mask_images: Whether or not to mask the RGB image background given the
foreground mask (the `fg_probability` argument of `GenericModel.forward`)
mask_depths: Whether or not to mask the depth image background given the
foreground mask (the `fg_probability` argument of `GenericModel.forward`)
render_image_width: Width of the output image to render
render_image_height: Height of the output image to render
mask_threshold: If greater than 0.0, the foreground mask is
thresholded by this value before being applied to the RGB/Depth images
output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by
splatting onto an image grid. Default: False.
bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0)
view_pool: If True, features are sampled from the source image(s)
at the projected 2d locations of the sampled 3d ray points from the target
view(s), i.e. this activates step (3) above.
num_passes: The specified implicit_function is initialized num_passes
times and run sequentially.
chunk_size_grid: The total number of points which can be rendered
per chunk. This is used to compute the number of rays used
per chunk when the chunked version of the renderer is used (in order
to fit rendering on all rays in memory)
render_features_dimensions: The number of output features to render.
Defaults to 3, corresponding to RGB images.
n_train_target_views: The number of cameras to render into at training
time; first `n_train_target_views` in the batch are considered targets,
the rest are sources.
sampling_mode_training: The sampling method to use during training. Must be
a value from the RenderSamplingMode Enum.
sampling_mode_evaluation: Same as above but for evaluation.
sequence_autodecoder: An instance of `Autodecoder`. This is used to generate an encoding
of the image (referred to as the global_code) that can be used to model aspects of
the scene such as multiple objects or morphing objects. It is up to the implicit
function definition how to use it, but the most typical way is to broadcast and
concatenate to the other inputs for the implicit function.
raysampler: An instance of RaySampler which is used to emit
rays from the target view(s).
renderer_class_type: The name of the renderer class which is available in the global
registry.
renderer: A renderer class which inherits from BaseRenderer. This is used to
generate the images from the target view(s).
image_feature_extractor: A module for extracting features from an input image.
view_sampler: An instance of ViewSampler which is used for sampling of
image-based features at the 2D projections of a set
of 3D points.
feature_aggregator_class_type: The name of the feature aggregator class which
is available in the global registry.
feature_aggregator: A feature aggregator class which inherits from
FeatureAggregatorBase. Typically, the aggregated features and their
masks are output by a `ViewSampler` which samples feature tensors extracted
from a set of source images. FeatureAggregator executes step (4) above.
implicit_function_class_type: The type of implicit function to use which
is available in the global registry.
implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions
are initialized and stored in self._implicit_functions.
loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation
for `ViewMetrics` class for available loss functions.
log_vars: A list of variable names which should be logged.
The names should correspond to a subset of the keys of the
dict `preds` output by the `forward` function.
"""
mask_images: bool = True
mask_depths: bool = True
render_image_width: int = 400
render_image_height: int = 400
mask_threshold: float = 0.5
output_rasterized_mc: bool = False
bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0)
view_pool: bool = False
num_passes: int = 1
chunk_size_grid: int = 4096
render_features_dimensions: int = 3
tqdm_trigger_threshold: int = 16
n_train_target_views: int = 1
sampling_mode_training: str = "mask_sample"
sampling_mode_evaluation: str = "full_grid"
# ---- autodecoder settings
sequence_autodecoder: Autodecoder
# ---- raysampler
raysampler: RaySampler
# ---- renderer configs
renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer"
renderer: BaseRenderer
# ---- view sampling settings - used if view_pool=True
# (This is only created if view_pool is True)
image_feature_extractor: ResNetFeatureExtractor
view_sampler: ViewSampler
# ---- ---- view sampling feature aggregator settings
feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator"
feature_aggregator: FeatureAggregatorBase
# ---- implicit function settings
implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction"
# This is just a model, never constructed.
# The actual implicit functions live in self._implicit_functions
implicit_function: ImplicitFunctionBase
# ---- loss weights
loss_weights: Dict[str, float] = field(
default_factory=lambda: {
"loss_rgb_mse": 1.0,
"loss_prev_stage_rgb_mse": 1.0,
"loss_mask_bce": 0.0,
"loss_prev_stage_mask_bce": 0.0,
}
)
# ---- variables to be logged (logger automatically ignores if not computed)
log_vars: List[str] = field(
default_factory=lambda: [
"loss_rgb_psnr_fg",
"loss_rgb_psnr",
"loss_rgb_mse",
"loss_rgb_huber",
"loss_depth_abs",
"loss_depth_abs_fg",
"loss_mask_neg_iou",
"loss_mask_bce",
"loss_mask_beta_prior",
"loss_eikonal",
"loss_density_tv",
"loss_depth_neg_penalty",
"loss_autodecoder_norm",
# metrics that are only logged in 2+ stage renderers
"loss_prev_stage_rgb_mse",
"loss_prev_stage_rgb_psnr_fg",
"loss_prev_stage_rgb_psnr",
"loss_prev_stage_mask_bce",
*STD_LOG_VARS,
]
)
def __post_init__(self):
super().__init__()
self.view_metrics = ViewMetrics()
self._check_and_preprocess_renderer_configs()
self.raysampler_args["sampling_mode_training"] = self.sampling_mode_training
self.raysampler_args["sampling_mode_evaluation"] = self.sampling_mode_evaluation
self.raysampler_args["image_width"] = self.render_image_width
self.raysampler_args["image_height"] = self.render_image_height
run_auto_creation(self)
self._implicit_functions = self._construct_implicit_functions()
self.print_loss_weights()
def forward(
self,
*, # force keyword-only arguments
image_rgb: Optional[torch.Tensor],
camera: CamerasBase,
fg_probability: Optional[torch.Tensor],
mask_crop: Optional[torch.Tensor],
depth_map: Optional[torch.Tensor],
sequence_name: Optional[List[str]],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs,
) -> Dict[str, Any]:
"""
Args:
image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images;
the first `min(B, n_train_target_views)` images are considered targets and
are used to supervise the renders; the rest corresponding to the source
viewpoints from which features will be extracted.
camera: An instance of CamerasBase containing a batch of `B` cameras corresponding
to the viewpoints of target images, from which the rays will be sampled,
and source images, which will be used for intersecting with target rays.
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of
foreground masks.
mask_crop: A binary tensor of shape `(B, 1, H, W)` denoting valid
regions in the input images (i.e. regions that do not correspond
to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to
"mask_sample", rays will be sampled in the non-zero regions.
depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
sequence_name: A list of `B` strings corresponding to the sequence names
from which images `image_rgb` were extracted. They are used to match
target frames with relevant source frames.
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering.
Returns:
preds: A dictionary containing all outputs of the forward pass including the
rendered images, depths, masks, losses and other metrics.
"""
image_rgb, fg_probability, depth_map = self._preprocess_input(
image_rgb, fg_probability, depth_map
)
# Obtain the batch size from the camera as this is the only required input.
batch_size = camera.R.shape[0]
# Determine the number of target views, i.e. cameras we render into.
n_targets = (
1
if evaluation_mode == EvaluationMode.EVALUATION
else batch_size
if self.n_train_target_views <= 0
else min(self.n_train_target_views, batch_size)
)
# Select the target cameras.
target_cameras = camera[list(range(n_targets))]
# Determine the used ray sampling mode.
sampling_mode = RenderSamplingMode(
self.sampling_mode_training
if evaluation_mode == EvaluationMode.TRAINING
else self.sampling_mode_evaluation
)
# (1) Sample rendering rays with the ray sampler.
ray_bundle: RayBundle = self.raysampler(
target_cameras,
evaluation_mode,
mask=mask_crop[:n_targets]
if mask_crop is not None and sampling_mode == RenderSamplingMode.MASK_SAMPLE
else None,
)
# custom_args hold additional arguments to the implicit function.
custom_args = {}
if self.view_pool:
if sequence_name is None:
raise ValueError("sequence_name must be provided for view pooling")
# (2) Extract features for the image
img_feats = self.image_feature_extractor(image_rgb, fg_probability)
# (3) Sample features and masks at the ray points
curried_view_sampler = lambda pts: self.view_sampler( # noqa: E731
pts=pts,
seq_id_pts=sequence_name[:n_targets],
camera=camera,
seq_id_camera=sequence_name,
feats=img_feats,
masks=mask_crop,
) # returns feats_sampled, masks_sampled
# (4) Aggregate features from multiple views
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
curried_view_pool = lambda pts: self.feature_aggregator( # noqa: E731
*curried_view_sampler(pts=pts),
pts=pts,
camera=camera,
) # TODO: do we need to pass a callback rather than compute here?
# precomputing will be faster for 2 passes
# -> but this is important for non-nerf
custom_args["fun_viewpool"] = curried_view_pool
global_code = None
if self.sequence_autodecoder.n_instances > 0:
if sequence_name is None:
raise ValueError("sequence_name must be provided for autodecoder.")
global_code = self.sequence_autodecoder(sequence_name[:n_targets])
custom_args["global_code"] = global_code
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self,
# torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], torch.Tensor,
# torch.nn.Module]` is not a function.
for func in self._implicit_functions:
func.bind_args(**custom_args)
object_mask: Optional[torch.Tensor] = None
if fg_probability is not None:
sampled_fb_prob = rend_utils.ndc_grid_sample(
fg_probability[:n_targets], ray_bundle.xys, mode="nearest"
)
object_mask = sampled_fb_prob > 0.5
# (5)-(6) Implicit function evaluation and Rendering
rendered = self._render(
ray_bundle=ray_bundle,
sampling_mode=sampling_mode,
evaluation_mode=evaluation_mode,
implicit_functions=self._implicit_functions,
object_mask=object_mask,
)
# Unbind the custom arguments to prevent pytorch from storing
# large buffers of intermediate results due to points in the
# bound arguments.
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self,
# torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], torch.Tensor,
# torch.nn.Module]` is not a function.
for func in self._implicit_functions:
func.unbind_args()
preds = self._get_view_metrics(
raymarched=rendered,
xys=ray_bundle.xys,
image_rgb=None if image_rgb is None else image_rgb[:n_targets],
depth_map=None if depth_map is None else depth_map[:n_targets],
fg_probability=None
if fg_probability is None
else fg_probability[:n_targets],
mask_crop=None if mask_crop is None else mask_crop[:n_targets],
)
if sampling_mode == RenderSamplingMode.MASK_SAMPLE:
if self.output_rasterized_mc:
# Visualize the monte-carlo pixel renders by splatting onto
# an image grid.
(
preds["images_render"],
preds["depths_render"],
preds["masks_render"],
) = self._rasterize_mc_samples(
ray_bundle.xys,
rendered.features,
rendered.depths,
masks=rendered.masks,
)
elif sampling_mode == RenderSamplingMode.FULL_GRID:
preds["images_render"] = rendered.features.permute(0, 3, 1, 2)
preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2)
preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2)
preds["nvs_prediction"] = NewViewSynthesisPrediction(
image_render=preds["images_render"],
depth_render=preds["depths_render"],
mask_render=preds["masks_render"],
)
else:
raise AssertionError("Unreachable state")
# calc the AD penalty, returns None if autodecoder is not active
ad_penalty = self.sequence_autodecoder.calc_squared_encoding_norm()
if ad_penalty is not None:
preds["loss_autodecoder_norm"] = ad_penalty
# (7) Compute losses
# finally get the optimization objective using self.loss_weights
objective = self._get_objective(preds)
if objective is not None:
preds["objective"] = objective
return preds
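# Editor's note (usage sketch, hedged): since all arguments are keyword-only and
# mirror FrameData fields, forward can be called on a collated batch roughly as
#   preds = model(
#       image_rgb=frame_data.image_rgb,
#       camera=frame_data.camera,
#       fg_probability=frame_data.fg_probability,
#       mask_crop=frame_data.mask_crop,
#       depth_map=frame_data.depth_map,
#       sequence_name=frame_data.sequence_name,
#       evaluation_mode=EvaluationMode.TRAINING,
#   )
#   preds["objective"].backward()
# where `frame_data` and `model` come from the surrounding training loop.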
def _get_objective(self, preds) -> Optional[torch.Tensor]:
"""
A helper function to compute the overall loss as the dot product
of individual loss functions with the corresponding weights.
"""
losses_weighted = [
preds[k] * float(w)
for k, w in self.loss_weights.items()
if (k in preds and w != 0.0)
]
if len(losses_weighted) == 0:
warnings.warn("No main objective found.")
return None
loss = sum(losses_weighted)
assert torch.is_tensor(loss)
return loss
def visualize(
self,
viz: Visdom,
visdom_env_imgs: str,
preds: Dict[str, Any],
prefix: str,
) -> None:
"""
Helper function to visualize the predictions generated
in the forward pass.
Args:
viz: Visdom connection object
visdom_env_imgs: name of visdom environment for the images.
preds: predictions dict like returned by forward()
prefix: prepended to the names of images
"""
if not viz.check_connection():
print("no visdom server! -> skipping batch vis")
return
idx_image = 0
title = f"{prefix}_im{idx_image}"
vis_utils.visualize_basics(viz, preds, visdom_env_imgs, title=title)
def _render(
self,
*,
ray_bundle: RayBundle,
object_mask: Optional[torch.Tensor],
sampling_mode: RenderSamplingMode,
**kwargs,
) -> RendererOutput:
"""
Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
object_mask: A tensor of shape `(B, 3, H, W)` denoting the silhouette of the object
in the image. This is required for the SignedDistanceFunctionRenderer.
sampling_mode: The sampling method to use. Must be a value from the
RenderSamplingMode Enum.
Returns:
An instance of RendererOutput
"""
if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0:
return _apply_chunked(
self.renderer,
_chunk_generator(
self.chunk_size_grid,
ray_bundle,
object_mask,
self.tqdm_trigger_threshold,
**kwargs,
),
lambda batch: _tensor_collator(batch, ray_bundle.lengths.shape[:-1]),
)
else:
# pyre-fixme[29]: `BaseRenderer` is not a function.
return self.renderer(
ray_bundle=ray_bundle,
object_mask=object_mask,
**kwargs,
)
def _get_viewpooled_feature_dim(self):
return (
self.feature_aggregator.get_aggregated_feature_dim(
self.image_feature_extractor.get_feat_dims()
)
if self.view_pool
else 0
)
def _check_and_preprocess_renderer_configs(self):
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
"stratified_sampling_coarse_training"
] = self.raysampler_args["stratified_point_sampling_training"]
self.renderer_MultiPassEmissionAbsorptionRenderer_args[
"stratified_sampling_coarse_evaluation"
] = self.raysampler_args["stratified_point_sampling_evaluation"]
self.renderer_SignedDistanceFunctionRenderer_args[
"render_features_dimensions"
] = self.render_features_dimensions
self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[
"object_bounding_sphere"
] = self.raysampler_args["scene_extent"]
def create_image_feature_extractor(self):
"""
Custom creation function called by run_auto_creation so that the
image_feature_extractor is not created if it is not needed.
"""
if self.view_pool:
self.image_feature_extractor = ResNetFeatureExtractor(
**self.image_feature_extractor_args
)
def create_implicit_function(self) -> None:
"""
No-op called by run_auto_creation so that self.implicit_function
does not get created. __post_init__ creates the implicit function(s)
in wrappers explicitly in self._implicit_functions.
"""
pass
def _construct_implicit_functions(self):
"""
After run_auto_creation has been called, the arguments
for each of the possible implicit function methods are
available. `GenericModel` arguments are first validated
based on the custom requirements for each specific
implicit function method. Then the required implicit
function(s) are initialized.
"""
# nerf preprocessing
nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args
nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args
nerf_args["latent_dim"] = nerformer_args["latent_dim"] = (
self._get_viewpooled_feature_dim()
+ self.sequence_autodecoder.get_encoding_dim()
)
nerf_args["color_dim"] = nerformer_args[
"color_dim"
] = self.render_features_dimensions
# idr preprocessing
idr = self.implicit_function_IdrFeatureField_args
idr["feature_vector_size"] = self.render_features_dimensions
idr["encoding_dim"] = self.sequence_autodecoder.get_encoding_dim()
# srn preprocessing
srn = self.implicit_function_SRNImplicitFunction_args
srn.raymarch_function_args.latent_dim = (
self._get_viewpooled_feature_dim()
+ self.sequence_autodecoder.get_encoding_dim()
)
# srn_hypernet preprocessing
srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args
srn_hypernet_args = srn_hypernet.hypernet_args
srn_hypernet_args.latent_dim_hypernet = (
self.sequence_autodecoder.get_encoding_dim()
)
srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim()
# check that for srn, srn_hypernet, idr we have self.num_passes=1
implicit_function_type = registry.get(
ImplicitFunctionBase, self.implicit_function_class_type
)
if self.num_passes != 1 and not implicit_function_type.allows_multiple_passes():
raise ValueError(
self.implicit_function_class_type
+ f"requires num_passes=1 not {self.num_passes}"
)
if implicit_function_type.requires_pooling_without_aggregation():
has_aggregation = hasattr(self.feature_aggregator, "reduction_functions")
if not self.view_pool or has_aggregation:
raise ValueError(
"Chosen implicit function requires view pooling without aggregation."
)
config_name = f"implicit_function_{self.implicit_function_class_type}_args"
config = getattr(self, config_name, None)
if config is None:
raise ValueError(f"{config_name} not present")
implicit_functions_list = [
ImplicitFunctionWrapper(implicit_function_type(**config))
for _ in range(self.num_passes)
]
return torch.nn.ModuleList(implicit_functions_list)
def print_loss_weights(self) -> None:
"""
Print a table of the loss weights.
"""
print("-------\nloss_weights:")
for k, w in self.loss_weights.items():
print(f"{k:40s}: {w:1.2e}")
print("-------")
def _preprocess_input(
self,
image_rgb: Optional[torch.Tensor],
fg_probability: Optional[torch.Tensor],
depth_map: Optional[torch.Tensor],
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Helper function to preprocess the input images and optional depth maps
to apply masking if required.
Args:
image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images
corresponding to the source viewpoints from which features will be extracted
fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch
of foreground masks with values in [0, 1].
depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
Returns:
Modified image_rgb, fg_mask, depth_map
"""
fg_mask = fg_probability
if fg_mask is not None and self.mask_threshold > 0.0:
# threshold masks
warnings.warn("Thresholding masks!")
fg_mask = (fg_mask >= self.mask_threshold).type_as(fg_mask)
if self.mask_images and fg_mask is not None and image_rgb is not None:
# mask the image
warnings.warn("Masking images!")
image_rgb = image_utils.mask_background(
image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(self.bg_color)
)
if self.mask_depths and fg_mask is not None and depth_map is not None:
# mask the depths
assert (
self.mask_threshold > 0.0
), "Depths should be masked only with thresholded masks"
warnings.warn("Masking depths!")
depth_map = depth_map * fg_mask
return image_rgb, fg_mask, depth_map
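
The masking logic above amounts to a threshold plus a composite against the background color. A standalone sketch in plain PyTorch (the composite approximates what `image_utils.mask_background` is assumed to do; shapes are illustrative):

```python
import torch

image_rgb = torch.rand(2, 3, 8, 8)        # (B, 3, H, W) source images
fg_probability = torch.rand(2, 1, 8, 8)   # (B, 1, H, W) soft masks in [0, 1]
depth_map = torch.rand(2, 1, 8, 8)        # (B, 1, H, W) depths
mask_threshold = 0.5
bg_color = torch.tensor([1.0, 1.0, 1.0])  # white background

# threshold the soft masks to hard 0/1 masks
fg_mask = (fg_probability >= mask_threshold).type_as(fg_probability)
# replace background pixels with bg_color, keep foreground pixels
image_rgb = image_rgb * fg_mask + bg_color.view(1, 3, 1, 1) * (1.0 - fg_mask)
# zero out background depths
depth_map = depth_map * fg_mask
```
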
def _get_view_metrics(
self,
raymarched: RendererOutput,
xys: torch.Tensor,
image_rgb: Optional[torch.Tensor] = None,
depth_map: Optional[torch.Tensor] = None,
fg_probability: Optional[torch.Tensor] = None,
mask_crop: Optional[torch.Tensor] = None,
keys_prefix: str = "loss_",
):
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
metrics = self.view_metrics(
image_sampling_grid=xys,
images_pred=raymarched.features,
images=image_rgb,
depths_pred=raymarched.depths,
depths=depth_map,
masks_pred=raymarched.masks,
masks=fg_probability,
masks_crop=mask_crop,
keys_prefix=keys_prefix,
**raymarched.aux,
)
if raymarched.prev_stage:
metrics.update(
self._get_view_metrics(
raymarched.prev_stage,
xys,
image_rgb,
depth_map,
fg_probability,
mask_crop,
keys_prefix=(keys_prefix + "prev_stage_"),
)
)
return metrics
@torch.no_grad()
def _rasterize_mc_samples(
self,
xys: torch.Tensor,
features: torch.Tensor,
depth: torch.Tensor,
masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Rasterizes Monte-Carlo features back onto the image.
Args:
xys: B x ... x 2 2D point locations in PyTorch3D NDC convention
features: B x ... x C tensor containing per-point rendered features.
depth: B x ... x 1 tensor containing per-point rendered depth.
masks: B x ... x 1 tensor containing per-point foreground masks, or None.
Returns:
The rasterized images, depths and masks as image-shaped tensors.
"""
ba = xys.shape[0]
# Flatten the features and xy locations.
features_depth_ras = torch.cat(
(
features.reshape(ba, -1, features.shape[-1]),
depth.reshape(ba, -1, 1),
),
dim=-1,
)
xys_ras = xys.reshape(ba, -1, 2)
if masks is not None:
masks_ras = masks.reshape(ba, -1, 1)
else:
masks_ras = None
if min(self.render_image_height, self.render_image_width) <= 0:
raise ValueError(
"Need to specify a positive"
" self.render_image_height and self.render_image_width"
" for MC rasterisation."
)
# Estimate the rasterization point radius so that we approximately fill
# the whole image given the number of rasterized points.
pt_radius = 2.0 / math.sqrt(xys.shape[1])
# Rasterize the samples.
features_depth_render, masks_render = rasterize_mc_samples(
xys_ras,
features_depth_ras,
(self.render_image_height, self.render_image_width),
radius=pt_radius,
masks=masks_ras,
)
images_render = features_depth_render[:, :-1]
depths_render = features_depth_render[:, -1:]
return images_render, depths_render, masks_render
def _apply_chunked(func, chunk_generator, tensor_collator):
"""
Helper function to apply a function on a sequence of
chunked inputs yielded by a generator and collate
the result.
"""
processed_chunks = [
func(*chunk_args, **chunk_kwargs)
for chunk_args, chunk_kwargs in chunk_generator
]
return cat_dataclass(processed_chunks, tensor_collator)
def _tensor_collator(batch, new_dims) -> torch.Tensor:
"""
Helper function to reshape the batch to the desired shape
"""
return torch.cat(batch, dim=1).reshape(*new_dims, -1)
def _chunk_generator(
chunk_size: int,
ray_bundle: RayBundle,
object_mask: Optional[torch.Tensor],
tqdm_trigger_threshold: int,
*args,
**kwargs,
):
"""
Helper function which yields chunks of rays from the
input ray_bundle, to be used when the number of rays is
large and will not fit in memory for rendering.
"""
(
batch_size,
*spatial_dim,
n_pts_per_ray,
) = ray_bundle.lengths.shape # B x ... x n_pts_per_ray
if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
raise ValueError(
f"chunk_size_grid ({chunk_size}) should be divisible "
f"by n_pts_per_ray ({n_pts_per_ray})"
)
n_rays = math.prod(spatial_dim)
# special handling for raytracing-based methods
n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
chunk_size_in_rays = -(-n_rays // n_chunks)
iter = range(0, n_rays, chunk_size_in_rays)
if len(iter) >= tqdm_trigger_threshold:
iter = tqdm.tqdm(iter)
for start_idx in iter:
end_idx = min(start_idx + chunk_size_in_rays, n_rays)
ray_bundle_chunk = RayBundle(
origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx],
directions=ray_bundle.directions.reshape(batch_size, -1, 3)[
:, start_idx:end_idx
],
lengths=ray_bundle.lengths.reshape(
batch_size, math.prod(spatial_dim), n_pts_per_ray
)[:, start_idx:end_idx],
xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx],
)
extra_args = kwargs.copy()
if object_mask is not None:
extra_args["object_mask"] = object_mask.reshape(batch_size, -1, 1)[
:, start_idx:end_idx
]
yield [ray_bundle_chunk, *args], extra_args
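
For intuition, the ceil-division arithmetic used above can be exercised on its own. The helper below is a standalone sketch (its name is illustrative, it is not part of the diff) reproducing how `_chunk_generator` derives the number of chunks and the per-chunk ray count:

```python
import math

def chunk_ray_slices(n_rays: int, n_pts_per_ray: int, chunk_size: int):
    """Return (start, end) ray index ranges whose point count stays within chunk_size."""
    if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0:
        raise ValueError("chunk_size should be divisible by n_pts_per_ray")
    # ceil divisions written with the same negation trick as above
    n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size)
    chunk_size_in_rays = -(-n_rays // n_chunks)
    return [
        (start, min(start + chunk_size_in_rays, n_rays))
        for start in range(0, n_rays, chunk_size_in_rays)
    ]

# 1024 rays with 64 points each and a 16384-point budget -> 4 chunks of 256 rays
print(chunk_ray_slices(1024, 64, 16384))
```
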


@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Optional
from pytorch3d.implicitron.tools.config import ReplaceableBase
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import RayBundle
class ImplicitFunctionBase(ABC, ReplaceableBase):
def __init__(self):
super().__init__()
@abstractmethod
def forward(
self,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
**kwargs,
):
raise NotImplementedError()
@staticmethod
def allows_multiple_passes() -> bool:
"""
Returns True if this implicit function allows
multiple passes.
"""
return False
@staticmethod
def requires_pooling_without_aggregation() -> bool:
"""
Returns True if this implicit function needs
pooling without aggregation.
"""
return False
def on_bind_args(self) -> None:
"""
Called when the custom args are fixed in the main model forward pass.
"""
pass
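
As a concrete illustration of this plug-in surface, a user-defined implicit function can subclass `ImplicitFunctionBase` and register itself so the configuration system can select it by class name. The toy class below is only a sketch, not part of the codebase, and the module path used to import the base class is an assumption about the package layout:

```python
from typing import Optional

import torch
from pytorch3d.implicitron.tools.config import registry
# assumed module path for the ImplicitFunctionBase shown above
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import RayBundle


@registry.register
class ConstantColorImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
    """Toy implicit function: constant density and white color everywhere."""

    density: float = 0.1

    def __post_init__(self):
        super().__init__()

    def forward(
        self,
        ray_bundle: RayBundle,
        fun_viewpool=None,
        camera: Optional[CamerasBase] = None,
        global_code=None,
        **kwargs,
    ):
        shape = ray_bundle.lengths.shape  # (minibatch, ..., n_pts_per_ray)
        device = ray_bundle.lengths.device
        densities = torch.full((*shape, 1), self.density, device=device)
        colors = torch.ones((*shape, 3), device=device)
        return densities, colors, {}
```
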


@ -0,0 +1,142 @@
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/lioryariv/idr/blob/main/code/model/
# implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv
import math
from typing import Sequence
import torch
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer.implicit import HarmonicEmbedding
from torch import nn
from .base import ImplicitFunctionBase
@registry.register
class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
feature_vector_size: int = 3
d_in: int = 3
d_out: int = 1
dims: Sequence[int] = (512, 512, 512, 512, 512, 512, 512, 512)
geometric_init: bool = True
bias: float = 1.0
skip_in: Sequence[int] = ()
weight_norm: bool = True
n_harmonic_functions_xyz: int = 0
pooled_feature_dim: int = 0
encoding_dim: int = 0
def __post_init__(self):
super().__init__()
dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
self.embed_fn = None
if self.n_harmonic_functions_xyz > 0:
self.embed_fn = HarmonicEmbedding(
self.n_harmonic_functions_xyz, append_input=True
)
dims[0] = self.embed_fn.get_output_dim()
if self.pooled_feature_dim > 0:
dims[0] += self.pooled_feature_dim
if self.encoding_dim > 0:
dims[0] += self.encoding_dim
self.num_layers = len(dims)
out_dim = 0
layers = []
for layer_idx in range(self.num_layers - 1):
if layer_idx + 1 in self.skip_in:
out_dim = dims[layer_idx + 1] - dims[0]
else:
out_dim = dims[layer_idx + 1]
lin = nn.Linear(dims[layer_idx], out_dim)
if self.geometric_init:
if layer_idx == self.num_layers - 2:
torch.nn.init.normal_(
lin.weight,
mean=math.pi ** 0.5 / dims[layer_idx] ** 0.5,
std=0.0001,
)
torch.nn.init.constant_(lin.bias, -self.bias)
elif self.n_harmonic_functions_xyz > 0 and layer_idx == 0:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
torch.nn.init.normal_(
lin.weight[:, :3], 0.0, 2 ** 0.5 / out_dim ** 0.5
)
elif self.n_harmonic_functions_xyz > 0 and layer_idx in self.skip_in:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, 2 ** 0.5 / out_dim ** 0.5)
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
else:
torch.nn.init.constant_(lin.bias, 0.0)
torch.nn.init.normal_(lin.weight, 0.0, 2 ** 0.5 / out_dim ** 0.5)
if self.weight_norm:
lin = nn.utils.weight_norm(lin)
layers.append(lin)
self.linear_layers = torch.nn.ModuleList(layers)
self.out_dim = out_dim
self.softplus = nn.Softplus(beta=100)
# pyre-fixme[14]: `forward` overrides method defined in `ImplicitFunctionBase`
# inconsistently.
def forward(
self,
# ray_bundle: RayBundle,
rays_points_world: torch.Tensor, # TODO: unify the APIs
fun_viewpool=None,
global_code=None,
):
# this field only uses point locations
# rays_points_world = ray_bundle_to_ray_points(ray_bundle)
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
if rays_points_world.numel() == 0 or (
self.embed_fn is None and fun_viewpool is None and global_code is None
):
return torch.tensor(
[], device=rays_points_world.device, dtype=rays_points_world.dtype
).view(0, self.out_dim)
embedding = None
if self.embed_fn is not None:
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
embedding = self.embed_fn(rays_points_world)
if fun_viewpool is not None:
assert rays_points_world.ndim == 2
pooled_feature = fun_viewpool(rays_points_world[None])
# TODO: pooled features are 4D!
embedding = torch.cat((embedding, pooled_feature), dim=-1)
if global_code is not None:
assert embedding.ndim == 2
assert global_code.shape[0] == 1 # TODO: generalize to batches!
# This will require changing raytracer code
# embedding = embedding[None].expand(global_code.shape[0], *embedding.shape)
embedding = torch.cat(
(embedding, global_code[0, None, :].expand(*embedding.shape[:-1], -1)),
dim=-1,
)
x = embedding
for layer_idx in range(self.num_layers - 1):
if layer_idx in self.skip_in:
x = torch.cat([x, embedding], dim=-1) / 2 ** 0.5
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
x = self.linear_layers[layer_idx](x)
if layer_idx < self.num_layers - 2:
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
x = self.softplus(x)
return x # TODO: unify the APIs
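
A minimal usage sketch for querying the field at a flat batch of points. Both the import path and the use of `expand_args_fields` to materialize the dataclass-style fields before direct construction are assumptions about the config system:

```python
import torch
from pytorch3d.implicitron.tools.config import expand_args_fields  # assumed helper

# assumed module path for the IdrFeatureField defined above
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import IdrFeatureField

expand_args_fields(IdrFeatureField)            # turn annotated fields into constructor args
field = IdrFeatureField(n_harmonic_functions_xyz=6)

points = torch.randn(1024, 3)                  # flat batch of world-space 3D points
out = field(points)                            # (1024, d_out + feature_vector_size) == (1024, 4)
sdf_values, features = out[:, :1], out[:, 1:]  # signed distances and per-point features
```
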


@ -0,0 +1,542 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import field
from typing import List, Optional
import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding
from .base import ImplicitFunctionBase
from .utils import create_embeddings_for_implicit_function
class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
n_harmonic_functions_xyz: int = 10
n_harmonic_functions_dir: int = 4
n_hidden_neurons_dir: int = 128
latent_dim: int = 0
input_xyz: bool = True
xyz_ray_dir_in_camera_coords: bool = False
color_dim: int = 3
"""
Args:
n_harmonic_functions_xyz: The number of harmonic functions
used to form the harmonic embedding of 3D point locations.
n_harmonic_functions_dir: The number of harmonic functions
used to form the harmonic embedding of the ray directions.
n_hidden_neurons_xyz: The number of hidden units in the
fully connected layers of the MLP that accepts the 3D point
locations and outputs the occupancy field with the intermediate
features.
n_hidden_neurons_dir: The number of hidden units in the
fully connected layers of the MLP that accepts the intermediate
features and ray directions and outputs the radiance field
(per-point colors).
n_layers_xyz: The number of layers of the MLP that outputs the
occupancy field.
append_xyz: The list of indices of the skip layers of the occupancy MLP.
"""
def __post_init__(self):
super().__init__()
# The harmonic embedding layer converts input 3D coordinates
# to a representation that is more suitable for
# processing with a deep neural network.
self.harmonic_embedding_xyz = HarmonicEmbedding(
self.n_harmonic_functions_xyz, append_input=True
)
self.harmonic_embedding_dir = HarmonicEmbedding(
self.n_harmonic_functions_dir, append_input=True
)
if not self.input_xyz and self.latent_dim <= 0:
raise ValueError("The latent dimension has to be > 0 if xyz is not input!")
embedding_dim_dir = self.harmonic_embedding_dir.get_output_dim()
self.xyz_encoder = self._construct_xyz_encoder(
input_dim=self.get_xyz_embedding_dim()
)
self.intermediate_linear = torch.nn.Linear(
self.n_hidden_neurons_xyz, self.n_hidden_neurons_xyz
)
_xavier_init(self.intermediate_linear)
self.density_layer = torch.nn.Linear(self.n_hidden_neurons_xyz, 1)
_xavier_init(self.density_layer)
# Zero the bias of the density layer to avoid
# a completely transparent initialization.
self.density_layer.bias.data[:] = 0.0 # fixme: Sometimes this is not enough
self.color_layer = torch.nn.Sequential(
LinearWithRepeat(
self.n_hidden_neurons_xyz + embedding_dim_dir, self.n_hidden_neurons_dir
),
torch.nn.ReLU(True),
torch.nn.Linear(self.n_hidden_neurons_dir, self.color_dim),
torch.nn.Sigmoid(),
)
def get_xyz_embedding_dim(self):
return (
self.harmonic_embedding_xyz.get_output_dim() * int(self.input_xyz)
+ self.latent_dim
)
def _construct_xyz_encoder(self, input_dim: int):
raise NotImplementedError()
def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
"""
This function takes per-point `features` predicted by `self.xyz_encoder`
and evaluates the color model in order to attach to each
point a 3D vector of its RGB color.
"""
# Normalize the ray_directions to unit l2 norm.
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
# Obtain the harmonic embedding of the normalized ray directions.
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
return self.color_layer((self.intermediate_linear(features), rays_embedding))
@staticmethod
def allows_multiple_passes() -> bool:
"""
Returns True as this implicit function allows
multiple passes. Overridden from ImplicitFunctionBase.
"""
return True
def forward(
self,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
**kwargs,
):
"""
The forward function accepts the parametrizations of
3D points sampled along projection rays. The forward
pass is responsible for attaching a 3D vector
and a 1D scalar representing the point's
RGB color and opacity respectively.
Args:
ray_bundle: A RayBundle object containing the following variables:
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords.
directions: A tensor of shape `(minibatch, ..., 3)`
containing the direction vectors of sampling rays in world coords.
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
containing the lengths at which the rays are sampled.
fun_viewpool: an optional callback with the signature
fun_viewpool(points) -> pooled_features
where points is a [N_TGT x N x 3] tensor of world coords,
and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
of the features pooled from the context images.
Returns:
rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
denoting the opacity of each ray point.
rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
denoting the color of each ray point.
"""
# We first convert the ray parametrizations to world
# coordinates with `ray_bundle_to_ray_points`.
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
# rays_points_world.shape = [minibatch x ... x pts_per_ray x 3]
embeds = create_embeddings_for_implicit_function(
xyz_world=ray_bundle_to_ray_points(ray_bundle),
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
# for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
xyz_embedding_function=self.harmonic_embedding_xyz
if self.input_xyz
else None,
global_code=global_code,
fun_viewpool=fun_viewpool,
xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords,
camera=camera,
)
# embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
features = self.xyz_encoder(embeds)
# features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
# NNs operate on the flattened rays; reshaping to the correct spatial size
# TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
features = features.reshape(*rays_points_world.shape[:-1], -1)
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
raw_densities = self.density_layer(features)
# raw_densities.shape = [minibatch x ... x 1] in [0-1]
if self.xyz_ray_dir_in_camera_coords:
if camera is None:
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
directions = ray_bundle.directions @ camera.R
else:
directions = ray_bundle.directions
rays_colors = self._get_colors(features, directions)
# rays_colors.shape = [minibatch x ... x 3] in [0-1]
return raw_densities, rays_colors, {}
@registry.register
class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase):
transformer_dim_down_factor: float = 1.0
n_hidden_neurons_xyz: int = 256
n_layers_xyz: int = 8
append_xyz: List[int] = field(default_factory=lambda: [5])
def _construct_xyz_encoder(self, input_dim: int):
return MLPWithInputSkips(
self.n_layers_xyz,
input_dim,
self.n_hidden_neurons_xyz,
input_dim,
self.n_hidden_neurons_xyz,
input_skips=self.append_xyz,
)
@registry.register
class NeRFormerImplicitFunction(NeuralRadianceFieldBase):
transformer_dim_down_factor: float = 2.0
n_hidden_neurons_xyz: int = 80
n_layers_xyz: int = 2
append_xyz: List[int] = field(default_factory=lambda: [1])
def _construct_xyz_encoder(self, input_dim: int):
return TransformerWithInputSkips(
self.n_layers_xyz,
input_dim,
self.n_hidden_neurons_xyz,
input_dim,
self.n_hidden_neurons_xyz,
input_skips=self.append_xyz,
dim_down_factor=self.transformer_dim_down_factor,
)
@staticmethod
def requires_pooling_without_aggregation() -> bool:
"""
Returns True as this implicit function needs
pooling without aggregation. Overridden from ImplicitFunctionBase.
"""
return True
class MLPWithInputSkips(torch.nn.Module):
"""
Implements the multi-layer perceptron architecture of the Neural Radiance Field.
As such, `MLPWithInputSkips` is a multi layer perceptron consisting
of a sequence of linear layers with ReLU activations.
Additionally, for a set of predefined layers `input_skips`, the forward pass
appends a skip tensor `z` to the output of the preceding layer.
Note that this follows the architecture described in the Supplementary
Material (Fig. 7) of [1].
References:
[1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik
and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng:
NeRF: Representing Scenes as Neural Radiance Fields for View
Synthesis, ECCV2020
"""
def _make_affine_layer(self, input_dim, hidden_dim):
l1 = torch.nn.Linear(input_dim, hidden_dim * 2)
l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2)
_xavier_init(l1)
_xavier_init(l2)
return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2)
def _apply_affine_layer(self, layer, x, z):
mu_log_std = layer(z)
mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1)
std = torch.nn.functional.softplus(log_std)
return (x - mu) * std
def __init__(
self,
n_layers: int = 8,
input_dim: int = 39,
output_dim: int = 256,
skip_dim: int = 39,
hidden_dim: int = 256,
input_skips: List[int] = [5],
skip_affine_trans: bool = False,
no_last_relu=False,
):
"""
Args:
n_layers: The number of linear layers of the MLP.
input_dim: The number of channels of the input tensor.
output_dim: The number of channels of the output.
skip_dim: The number of channels of the tensor `z` appended when
evaluating the skip layers.
hidden_dim: The number of hidden units of the MLP.
input_skips: The list of layer indices at which we append the skip
tensor `z`.
"""
super().__init__()
layers = []
skip_affine_layers = []
for layeri in range(n_layers):
dimin = hidden_dim if layeri > 0 else input_dim
dimout = hidden_dim if layeri + 1 < n_layers else output_dim
if layeri > 0 and layeri in input_skips:
if skip_affine_trans:
skip_affine_layers.append(
self._make_affine_layer(skip_dim, hidden_dim)
)
else:
dimin = hidden_dim + skip_dim
linear = torch.nn.Linear(dimin, dimout)
_xavier_init(linear)
layers.append(
torch.nn.Sequential(linear, torch.nn.ReLU(True))
if not no_last_relu or layeri + 1 < n_layers
else linear
)
self.mlp = torch.nn.ModuleList(layers)
if skip_affine_trans:
self.skip_affines = torch.nn.ModuleList(skip_affine_layers)
self._input_skips = set(input_skips)
self._skip_affine_trans = skip_affine_trans
def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None):
"""
Args:
x: The input tensor of shape `(..., input_dim)`.
z: The input skip tensor of shape `(..., skip_dim)` which is appended
to layers whose indices are specified by `input_skips`.
Returns:
y: The output tensor of shape `(..., output_dim)`.
"""
y = x
if z is None:
# if the skip tensor is None, we use `x` instead.
z = x
skipi = 0
for li, layer in enumerate(self.mlp):
if li in self._input_skips:
if self._skip_affine_trans:
y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
else:
y = torch.cat((y, z), dim=-1)
skipi += 1
y = layer(y)
return y
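
A quick shape check of the skip mechanism, assuming `MLPWithInputSkips` as defined above is in scope (tensor sizes are illustrative; the 39-channel input matches the class defaults):

```python
import torch

mlp = MLPWithInputSkips(
    n_layers=8, input_dim=39, output_dim=256, skip_dim=39, hidden_dim=256, input_skips=[5]
)
x = torch.randn(2, 1, 128, 32, 39)  # (batch, n_src, n_rays, pts_per_ray, dim)
y = mlp(x)                          # the 39-dim input is re-appended before layer 5
print(y.shape)                      # torch.Size([2, 1, 128, 32, 256])
```
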
class TransformerWithInputSkips(torch.nn.Module):
def __init__(
self,
n_layers: int = 8,
input_dim: int = 39,
output_dim: int = 256,
skip_dim: int = 39,
hidden_dim: int = 64,
input_skips: List[int] = [5],
dim_down_factor: float = 1,
):
"""
Args:
n_layers: The number of linear layers of the MLP.
input_dim: The number of channels of the input tensor.
output_dim: The number of channels of the output.
skip_dim: The number of channels of the tensor `z` appended when
evaluating the skip layers.
hidden_dim: The number of hidden units of the MLP.
input_skips: The list of layer indices at which we append the skip
tensor `z`.
"""
super().__init__()
self.first = torch.nn.Linear(input_dim, hidden_dim)
_xavier_init(self.first)
self.skip_linear = torch.nn.ModuleList()
layers_pool, layers_ray = [], []
dimout = 0
for layeri in range(n_layers):
dimin = int(round(hidden_dim / (dim_down_factor ** layeri)))
dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1))))
print(f"Tr: {dimin} -> {dimout}")
for _i, l in enumerate((layers_pool, layers_ray)):
l.append(
TransformerEncoderLayer(
d_model=[dimin, dimout][_i],
nhead=4,
dim_feedforward=hidden_dim,
dropout=0.0,
d_model_out=dimout,
)
)
if layeri in input_skips:
self.skip_linear.append(torch.nn.Linear(input_dim, dimin))
self.last = torch.nn.Linear(dimout, output_dim)
_xavier_init(self.last)
self.layers_pool, self.layers_ray = (
torch.nn.ModuleList(layers_pool),
torch.nn.ModuleList(layers_ray),
)
self._input_skips = set(input_skips)
def forward(
self,
x: torch.Tensor,
z: Optional[torch.Tensor] = None,
):
"""
Args:
x: The input tensor of shape
`(minibatch, n_pooled_feats, ..., n_ray_pts, input_dim)`.
z: The input skip tensor of shape
`(minibatch, n_pooled_feats, ..., n_ray_pts, skip_dim)`
which is appended to layers whose indices are specified by `input_skips`.
Returns:
y: The output tensor of shape
`(minibatch, ..., n_ray_pts, output_dim)`; the pooled-features
dimension is aggregated away by a softmax-weighted sum.
"""
if z is None:
# if the skip tensor is None, we use `x` instead.
z = x
y = self.first(x)
B, n_pool, n_rays, n_pts, dim = y.shape
# y_p in n_pool, n_pts, B x n_rays x dim
y_p = y.permute(1, 3, 0, 2, 4)
skipi = 0
dimh = dim
for li, (layer_pool, layer_ray) in enumerate(
zip(self.layers_pool, self.layers_ray)
):
y_pool_attn = y_p.reshape(n_pool, n_pts * B * n_rays, dimh)
if li in self._input_skips:
z_skip = self.skip_linear[skipi](z)
y_pool_attn = y_pool_attn + z_skip.permute(1, 3, 0, 2, 4).reshape(
n_pool, n_pts * B * n_rays, dimh
)
skipi += 1
# n_pool x B*n_rays*n_pts x dim
y_pool_attn, pool_attn = layer_pool(y_pool_attn, src_key_padding_mask=None)
dimh = y_pool_attn.shape[-1]
y_ray_attn = (
y_pool_attn.view(n_pool, n_pts, B * n_rays, dimh)
.permute(1, 0, 2, 3)
.reshape(n_pts, n_pool * B * n_rays, dimh)
)
# n_pts x n_pool*B*n_rays x dim
y_ray_attn, ray_attn = layer_ray(
y_ray_attn,
src_key_padding_mask=None,
)
y_p = y_ray_attn.view(n_pts, n_pool, B * n_rays, dimh).permute(1, 0, 2, 3)
y = y_p.view(n_pool, n_pts, B, n_rays, dimh).permute(2, 0, 3, 1, 4)
W = torch.softmax(y[..., :1], dim=1)
y = (y * W).sum(dim=1)
y = self.last(y)
return y
class TransformerEncoderLayer(torch.nn.Module):
r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
This standard encoder layer is based on the paper "Attention Is All You Need".
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
in a different way during application.
Args:
d_model: the number of expected features in the input (required).
nhead: the number of heads in the multiheadattention models (required).
dim_feedforward: the dimension of the feedforward network model (default=2048).
dropout: the dropout value (default=0.1).
activation: the activation function of intermediate layer, relu or gelu (default=relu).
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> out, attn = encoder_layer(src)
"""
def __init__(
self, d_model, nhead, dim_feedforward=2048, dropout=0.1, d_model_out=-1
):
super(TransformerEncoderLayer, self).__init__()
self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout)
# Implementation of Feedforward model
self.linear1 = torch.nn.Linear(d_model, dim_feedforward)
self.dropout = torch.nn.Dropout(dropout)
d_model_out = d_model if d_model_out <= 0 else d_model_out
self.linear2 = torch.nn.Linear(dim_feedforward, d_model_out)
self.norm1 = torch.nn.LayerNorm(d_model)
self.norm2 = torch.nn.LayerNorm(d_model_out)
self.dropout1 = torch.nn.Dropout(dropout)
self.dropout2 = torch.nn.Dropout(dropout)
self.activation = torch.nn.functional.relu
def forward(self, src, src_mask=None, src_key_padding_mask=None):
r"""Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional).
Shape:
see the docs in Transformer class.
"""
src2, attn = self.self_attn(
src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
)
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
d_out = src2.shape[-1]
src = src[..., :d_out] + self.dropout2(src2)[..., :d_out]
src = self.norm2(src)
return src, attn
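
A small usage sketch of this custom encoder layer, assuming the class defined above is in scope; note the extra `d_model_out` argument that lets the channel dimension shrink between layers, and that the layer returns attention weights alongside the output:

```python
import torch

layer = TransformerEncoderLayer(
    d_model=64, nhead=4, dim_feedforward=128, dropout=0.0, d_model_out=32
)
src = torch.randn(10, 32, 64)  # (sequence, batch, d_model)
out, attn = layer(src)
print(out.shape, attn.shape)   # torch.Size([10, 32, 32]) torch.Size([32, 10, 10])
```
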
def _xavier_init(linear) -> None:
"""
Performs the Xavier weight initialization of the linear layer `linear`.
"""
torch.nn.init.xavier_uniform_(linear.weight.data)


@ -0,0 +1,411 @@
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/vsitzmann/scene-representation-networks
# Copyright (c) 2019 Vincent Sitzmann
from typing import Any, Optional, Tuple, cast
import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping
from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation
from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.implicit import HarmonicEmbedding
from .base import ImplicitFunctionBase
from .utils import create_embeddings_for_implicit_function
def _kaiming_normal_init(module: torch.nn.Module) -> None:
if isinstance(module, (torch.nn.Linear, LinearWithRepeat)):
torch.nn.init.kaiming_normal_(
module.weight, a=0.0, nonlinearity="relu", mode="fan_in"
)
class SRNRaymarchFunction(Configurable, torch.nn.Module):
n_harmonic_functions: int = 3 # 0 means raw 3D coord inputs
n_hidden_units: int = 256
n_layers: int = 2
in_features: int = 3
out_features: int = 256
latent_dim: int = 0
xyz_in_camera_coords: bool = False
# The internal network can be set as an output of an SRNHyperNet.
# Note that, in order to avoid Pytorch's automatic registering of the
# raymarch_function module on construction, we input the network wrapped
# as a 1-tuple.
# raymarch_function should ideally be typed as Optional[Tuple[Callable]]
# but Omegaconf.structured doesn't like that. TODO: revisit after new
# release of omegaconf including https://github.com/omry/omegaconf/pull/749 .
raymarch_function: Any = None
def __post_init__(self):
super().__init__()
self._harmonic_embedding = HarmonicEmbedding(
self.n_harmonic_functions, append_input=True
)
input_embedding_dim = (
HarmonicEmbedding.get_output_dim_static(
self.in_features,
self.n_harmonic_functions,
True,
)
+ self.latent_dim
)
if self.raymarch_function is not None:
self._net = self.raymarch_function[0]
else:
self._net = pytorch_prototyping.FCBlock(
hidden_ch=self.n_hidden_units,
num_hidden_layers=self.n_layers,
in_features=input_embedding_dim,
out_features=self.out_features,
)
def forward(
self,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
**kwargs,
):
"""
Args:
ray_bundle: A RayBundle object containing the following variables:
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords.
directions: A tensor of shape `(minibatch, ..., 3)`
containing the direction vectors of sampling rays in world coords.
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
containing the lengths at which the rays are sampled.
fun_viewpool: an optional callback with the signature
fun_viewpool(points) -> pooled_features
where points is a [N_TGT x N x 3] tensor of world coords,
and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor
of the features pooled from the context images.
Returns:
rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
denoting the opacity of each ray point.
rays_colors: Set to None.
"""
# We first convert the ray parametrizations to world
# coordinates with `ray_bundle_to_ray_points`.
rays_points_world = ray_bundle_to_ray_points(ray_bundle)
embeds = create_embeddings_for_implicit_function(
xyz_world=ray_bundle_to_ray_points(ray_bundle),
# pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
# for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`.
xyz_embedding_function=self._harmonic_embedding,
global_code=global_code,
fun_viewpool=fun_viewpool,
xyz_in_camera_coords=self.xyz_in_camera_coords,
camera=camera,
)
# Before running the network, we have to resize embeds to ndims=3,
# otherwise the SRN layers consume huge amounts of memory.
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
raymarch_features = self._net(
embeds.view(embeds.shape[0], -1, embeds.shape[-1])
)
# raymarch_features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
# NNs operate on the flattened rays; reshaping to the correct spatial size
raymarch_features = raymarch_features.reshape(*rays_points_world.shape[:-1], -1)
return raymarch_features, None
class SRNPixelGenerator(Configurable, torch.nn.Module):
n_harmonic_functions: int = 4
n_hidden_units: int = 256
n_hidden_units_color: int = 128
n_layers: int = 2
in_features: int = 256
out_features: int = 3
ray_dir_in_camera_coords: bool = False
def __post_init__(self):
super().__init__()
self._harmonic_embedding = HarmonicEmbedding(
self.n_harmonic_functions, append_input=True
)
self._net = pytorch_prototyping.FCBlock(
hidden_ch=self.n_hidden_units,
num_hidden_layers=self.n_layers,
in_features=self.in_features,
out_features=self.n_hidden_units,
)
self._density_layer = torch.nn.Linear(self.n_hidden_units, 1)
self._density_layer.apply(_kaiming_normal_init)
embedding_dim_dir = self._harmonic_embedding.get_output_dim(input_dims=3)
self._color_layer = torch.nn.Sequential(
LinearWithRepeat(
self.n_hidden_units + embedding_dim_dir,
self.n_hidden_units_color,
),
torch.nn.LayerNorm([self.n_hidden_units_color]),
torch.nn.ReLU(inplace=True),
torch.nn.Linear(self.n_hidden_units_color, self.out_features),
)
self._color_layer.apply(_kaiming_normal_init)
# TODO: merge with NeuralRadianceFieldBase's _get_colors
def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor):
"""
This function takes per-point `features` predicted by `self._net`
and evaluates the color model in order to attach to each
point a 3D vector of its RGB color.
"""
# Normalize the ray_directions to unit l2 norm.
rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
# Obtain the harmonic embedding of the normalized ray directions.
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
rays_embedding = self._harmonic_embedding(rays_directions_normed)
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
return self._color_layer((features, rays_embedding))
def forward(
self,
raymarch_features: torch.Tensor,
ray_bundle: RayBundle,
camera: Optional[CamerasBase] = None,
**kwargs,
):
"""
Args:
raymarch_features: Features from the raymarching network of shape
`(minibatch, ..., self.in_features)`
ray_bundle: A RayBundle object containing the following variables:
origins: A tensor of shape `(minibatch, ..., 3)` denoting the
origins of the sampling rays in world coords.
directions: A tensor of shape `(minibatch, ..., 3)`
containing the direction vectors of sampling rays in world coords.
lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)`
containing the lengths at which the rays are sampled.
Returns:
rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)`
denoting the opacity of each ray point.
rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)`
denoting the color of each ray point.
"""
# raymarch_features.shape = [minibatch x ... x pts_per_ray x 3]
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
features = self._net(raymarch_features)
# features.shape = [minibatch x ... x self.n_hidden_units]
if self.ray_dir_in_camera_coords:
if camera is None:
raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords")
directions = ray_bundle.directions @ camera.R
else:
directions = ray_bundle.directions
# NNs operate on the flattened rays; reshaping to the correct spatial size
features = features.reshape(*raymarch_features.shape[:-1], -1)
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
raw_densities = self._density_layer(features)
rays_colors = self._get_colors(features, directions)
return raw_densities, rays_colors
class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
"""
This is a raymarching function which has a forward like SRNRaymarchFunction
but instead of the weights being parameters of the module, they
are the output of another network, the hypernet, which takes the global_code
as input. All the dataclass members of SRNRaymarchFunction are here with the
same meaning. In addition, there are members with names ending `_hypernet`
which affect the hypernet.
Because this class may be called repeatedly for the same global_code, the
output of the hypernet is cached in self.cached_srn_raymarch_function.
This member must be manually set to None whenever the global_code changes.
"""
n_harmonic_functions: int = 3 # 0 means raw 3D coord inputs
n_hidden_units: int = 256
n_layers: int = 2
n_hidden_units_hypernet: int = 256
n_layers_hypernet: int = 1
in_features: int = 3
out_features: int = 256
latent_dim_hypernet: int = 0
latent_dim: int = 0
xyz_in_camera_coords: bool = False
def __post_init__(self):
super().__init__()
raymarch_input_embedding_dim = (
HarmonicEmbedding.get_output_dim_static(
self.in_features,
self.n_harmonic_functions,
True,
)
+ self.latent_dim
)
self._hypernet = hyperlayers.HyperFC(
hyper_in_ch=self.latent_dim_hypernet,
hyper_num_hidden_layers=self.n_layers_hypernet,
hyper_hidden_ch=self.n_hidden_units_hypernet,
hidden_ch=self.n_hidden_units,
num_hidden_layers=self.n_layers,
in_ch=raymarch_input_embedding_dim,
out_ch=self.n_hidden_units,
)
self.cached_srn_raymarch_function: Optional[Tuple[SRNRaymarchFunction]] = None
def _run_hypernet(self, global_code: torch.Tensor) -> Tuple[SRNRaymarchFunction]:
"""
Runs the hypernet and returns a 1-tuple containing the generated
srn_raymarch_function.
"""
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
net = self._hypernet(global_code)
# use the hyper-net generated network to instantiate the raymarch module
srn_raymarch_function = SRNRaymarchFunction(
n_harmonic_functions=self.n_harmonic_functions,
n_hidden_units=self.n_hidden_units,
n_layers=self.n_layers,
in_features=self.in_features,
out_features=self.out_features,
latent_dim=self.latent_dim,
xyz_in_camera_coords=self.xyz_in_camera_coords,
raymarch_function=(net,),
)
# move the generated raymarch function to the correct device
srn_raymarch_function.to(global_code.device)
return (srn_raymarch_function,)
def forward(
self,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
**kwargs,
):
if global_code is None:
raise ValueError("SRN Hypernetwork requires a non-trivial global code.")
# The raymarching network is cached in case the function is called repeatedly
# across LSTM iterations for the same global_code.
if self.cached_srn_raymarch_function is None:
# generate the raymarching network from the hypernet
# pyre-fixme[16]: `SRNRaymarchHyperNet` has no attribute
self.cached_srn_raymarch_function = self._run_hypernet(global_code)
(srn_raymarch_function,) = cast(
Tuple[SRNRaymarchFunction], self.cached_srn_raymarch_function
)
return srn_raymarch_function(
ray_bundle=ray_bundle,
fun_viewpool=fun_viewpool,
camera=camera,
global_code=None, # the hypernetwork takes the global code
)
@registry.register
# pyre-fixme[13]: Uninitialized attribute
class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
raymarch_function: SRNRaymarchFunction
pixel_generator: SRNPixelGenerator
def __post_init__(self):
super().__init__()
run_auto_creation(self)
def forward(
self,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
raymarch_features: Optional[torch.Tensor] = None,
**kwargs,
):
predict_colors = raymarch_features is not None
if predict_colors:
return self.pixel_generator(
raymarch_features=raymarch_features,
ray_bundle=ray_bundle,
camera=camera,
**kwargs,
)
else:
return self.raymarch_function(
ray_bundle=ray_bundle,
fun_viewpool=fun_viewpool,
camera=camera,
global_code=global_code,
**kwargs,
)
@registry.register
# pyre-fixme[13]: Uninitialized attribute
class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
"""
This implicit function uses a hypernetwork to generate the
SRNRaymarchFunction, and the result is cached. Whenever the
global_code changes, `on_bind_args` must be called to clear
the cache.
"""
hypernet: SRNRaymarchHyperNet
pixel_generator: SRNPixelGenerator
def __post_init__(self):
super().__init__()
run_auto_creation(self)
def forward(
self,
ray_bundle: RayBundle,
fun_viewpool=None,
camera: Optional[CamerasBase] = None,
global_code=None,
raymarch_features: Optional[torch.Tensor] = None,
**kwargs,
):
predict_colors = raymarch_features is not None
if predict_colors:
return self.pixel_generator(
raymarch_features=raymarch_features,
ray_bundle=ray_bundle,
camera=camera,
**kwargs,
)
else:
return self.hypernet(
ray_bundle=ray_bundle,
fun_viewpool=fun_viewpool,
camera=camera,
global_code=global_code,
**kwargs,
)
def on_bind_args(self):
"""
The global_code may have changed, so we reset the hypernet.
"""
self.hypernet.cached_srn_raymarch_function = None


@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Callable, Optional
import torch
from pytorch3d.renderer.cameras import CamerasBase
def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor):
"""
Expands the `global_code` of shape (minibatch, dim)
so that it can be appended to `embeds` of shape (minibatch, ..., dim2),
and appends to the last dimension of `embeds`.
"""
bs = embeds.shape[0]
global_code_broadcast = global_code.view(bs, *([1] * (embeds.ndim - 2)), -1).expand(
*embeds.shape[:-1],
global_code.shape[-1],
)
return torch.cat([embeds, global_code_broadcast], dim=-1)
def create_embeddings_for_implicit_function(
xyz_world: torch.Tensor,
xyz_in_camera_coords: bool,
global_code: Optional[torch.Tensor],
camera: Optional[CamerasBase],
fun_viewpool: Optional[Callable],
xyz_embedding_function: Optional[Callable],
) -> torch.Tensor:
bs, *spatial_size, pts_per_ray, _ = xyz_world.shape
if xyz_in_camera_coords:
if camera is None:
raise ValueError("Camera must be given if xyz_in_camera_coords")
ray_points_for_embed = (
camera.get_world_to_view_transform()
.transform_points(xyz_world.view(bs, -1, 3))
.view(xyz_world.shape)
)
else:
ray_points_for_embed = xyz_world
if xyz_embedding_function is None:
embeds = torch.empty(
bs,
1,
math.prod(spatial_size),
pts_per_ray,
0,
dtype=xyz_world.dtype,
device=xyz_world.device,
)
else:
embeds = xyz_embedding_function(ray_points_for_embed).reshape(
bs,
1,
math.prod(spatial_size),
pts_per_ray,
-1,
) # flatten spatial, add n_src dim
if fun_viewpool is not None:
# viewpooling
embeds_viewpooled = fun_viewpool(xyz_world.reshape(bs, -1, 3))
embed_shape = (
bs,
embeds_viewpooled.shape[1],
math.prod(spatial_size),
pts_per_ray,
-1,
)
embeds_viewpooled = embeds_viewpooled.reshape(*embed_shape)
if embeds is not None:
embeds = torch.cat([embeds.expand(*embed_shape), embeds_viewpooled], dim=-1)
else:
embeds = embeds_viewpooled
if global_code is not None:
# append the broadcasted global code to embeds
embeds = broadcast_global_code(embeds, global_code)
return embeds
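
The shape bookkeeping in `broadcast_global_code` can be checked with a small standalone example (plain PyTorch, illustrative sizes):

```python
import torch

embeds = torch.randn(2, 1, 128, 16, 39)  # (B, n_src, n_rays, pts_per_ray, dim)
global_code = torch.randn(2, 64)         # (B, code_dim)

# view to (B, 1, 1, 1, code_dim), expand over the spatial dims, append to embeds
gc = global_code.view(2, 1, 1, 1, -1).expand(*embeds.shape[:-1], 64)
out = torch.cat([embeds, gc], dim=-1)
print(out.shape)                         # torch.Size([2, 1, 128, 16, 103])
```
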


@ -0,0 +1,230 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import Dict, Optional
import torch
from pytorch3d.implicitron.tools import metric_utils as utils
from pytorch3d.renderer import utils as rend_utils
class ViewMetrics(torch.nn.Module):
def forward(
self,
image_sampling_grid: torch.Tensor,
images: Optional[torch.Tensor] = None,
images_pred: Optional[torch.Tensor] = None,
depths: Optional[torch.Tensor] = None,
depths_pred: Optional[torch.Tensor] = None,
masks: Optional[torch.Tensor] = None,
masks_pred: Optional[torch.Tensor] = None,
masks_crop: Optional[torch.Tensor] = None,
grad_theta: Optional[torch.Tensor] = None,
density_grid: Optional[torch.Tensor] = None,
keys_prefix: str = "loss_",
mask_renders_by_pred: bool = False,
) -> Dict[str, torch.Tensor]:
"""
Calculates various differentiable metrics useful for supervising
differentiable rendering pipelines.
Args:
image_sampling_grid: A tensor of shape `(B, ..., 2)` containing 2D
image locations at which the predictions are defined.
All ground truth inputs are sampled at these
locations in order to extract values that correspond
to the predictions.
images: A tensor of shape `(B, H, W, 3)` containing ground truth
rgb values.
images_pred: A tensor of shape `(B, ..., 3)` containing predicted
rgb values.
depths: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth
depth values.
depths_pred: A tensor of shape `(B, ..., 1)` containing predicted
depth values.
masks: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth
foreground masks.
masks_pred: A tensor of shape `(B, ..., 1)` containing predicted
foreground masks.
grad_theta: A tensor of shape `(B, ..., 3)` containing an evaluation
of a gradient of a signed distance function w.r.t.
input 3D coordinates used to compute the eikonal loss.
density_grid: A tensor of shape `(B, Hg, Wg, Dg, 1)` containing a
`Hg x Wg x Dg` voxel grid of density values.
keys_prefix: A common prefix for all keys in the output dictionary
containing all metrics.
mask_renders_by_pred: If `True`, masks rendered images by the predicted
`masks_pred` prior to computing all rgb metrics.
Returns:
metrics: A dictionary `{metric_name_i: metric_value_i}` keyed by the
names of the output metrics `metric_name_i` with their corresponding
values `metric_value_i` represented as 0-dimensional float tensors.
The calculated metrics are:
rgb_huber: A robust huber loss between `image_pred` and `image`.
rgb_mse: Mean squared error between `image_pred` and `image`.
rgb_psnr: Peak signal-to-noise ratio between `image_pred` and `image`.
rgb_psnr_fg: Peak signal-to-noise ratio between the foreground
region of `image_pred` and `image` as defined by `mask`.
rgb_mse_fg: Mean squared error between the foreground
region of `image_pred` and `image` as defined by `mask`.
mask_neg_iou: (1 - intersection-over-union) between `mask_pred`
and `mask`.
mask_bce: Binary cross entropy between `mask_pred` and `mask`.
mask_beta_prior: A loss enforcing strictly binary values
of `mask_pred`: `log(mask_pred) + log(1-mask_pred)`
depth_abs: Mean per-pixel L1 distance between
`depth_pred` and `depth`.
depth_abs_fg: Mean per-pixel L1 distance between the foreground
region of `depth_pred` and `depth` as defined by `mask`.
eikonal: Eikonal regularizer `(||grad_theta|| - 1)**2`.
density_tv: The Total Variation regularizer of density
values in `density_grid` (sum of L1 distances of values
of all 4-neighbouring cells).
depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative
predicted depth values.
"""
# TODO: extract functions
# reshape from B x ... x DIM to B x DIM x -1 x 1
images_pred, masks_pred, depths_pred = [
_reshape_nongrid_var(x) for x in [images_pred, masks_pred, depths_pred]
]
# reshape the sampling grid as well
# TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var
# now that we use rend_utils.ndc_grid_sample
image_sampling_grid = image_sampling_grid.reshape(
image_sampling_grid.shape[0], -1, 1, 2
)
# closure with the given image_sampling_grid
def sample(tensor, mode):
if tensor is None:
return tensor
return rend_utils.ndc_grid_sample(tensor, image_sampling_grid, mode=mode)
# eval all results in this size
images = sample(images, mode="bilinear")
depths = sample(depths, mode="nearest")
masks = sample(masks, mode="nearest")
masks_crop = sample(masks_crop, mode="nearest")
if masks_crop is None and images_pred is not None:
masks_crop = torch.ones_like(images_pred[:, :1])
if masks_crop is None and depths_pred is not None:
masks_crop = torch.ones_like(depths_pred[:, :1])
preds = {}
if images is not None and images_pred is not None:
# TODO: mask_renders_by_pred is always false; simplify
preds.update(
_rgb_metrics(
images,
images_pred,
masks,
masks_pred,
masks_crop,
mask_renders_by_pred,
)
)
if masks_pred is not None:
preds["mask_beta_prior"] = utils.beta_prior(masks_pred)
if masks is not None and masks_pred is not None:
preds["mask_neg_iou"] = utils.neg_iou_loss(
masks_pred, masks, mask=masks_crop
)
preds["mask_bce"] = utils.calc_bce(masks_pred, masks, mask=masks_crop)
if depths is not None and depths_pred is not None:
assert masks_crop is not None
_, abs_ = utils.eval_depth(
depths_pred, depths, get_best_scale=True, mask=masks_crop, crop=0
)
preds["depth_abs"] = abs_.mean()
if masks is not None:
mask = masks * masks_crop
_, abs_ = utils.eval_depth(
depths_pred, depths, get_best_scale=True, mask=mask, crop=0
)
preds["depth_abs_fg"] = abs_.mean()
# regularizers
if grad_theta is not None:
preds["eikonal"] = _get_eikonal_loss(grad_theta)
if density_grid is not None:
preds["density_tv"] = _get_grid_tv_loss(density_grid)
if depths_pred is not None:
preds["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depths_pred)
if keys_prefix is not None:
preds = {(keys_prefix + k): v for k, v in preds.items()}
return preds
def _rgb_metrics(
images, images_pred, masks, masks_pred, masks_crop, mask_renders_by_pred
):
assert masks_crop is not None
if mask_renders_by_pred:
images = images[..., masks_pred.reshape(-1), :]
masks_crop = masks_crop[..., masks_pred.reshape(-1), :]
masks = masks is not None and masks[..., masks_pred.reshape(-1), :]
rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True)
rgb_loss = utils.huber(rgb_squared, scaling=0.03)
crop_mass = masks_crop.sum().clamp(1.0)
# print("IMAGE:", images.mean().item(), images_pred.mean().item()) # TEMP
preds = {
"rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass,
"rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass,
"rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop),
}
if masks is not None:
masks = masks_crop * masks
preds["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks)
preds["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0)
return preds
def _get_eikonal_loss(grad_theta):
return ((grad_theta.norm(2, dim=1) - 1) ** 2).mean()
def _get_grid_tv_loss(grid, log_domain: bool = True, eps: float = 1e-5):
if log_domain:
if (grid <= -eps).any():
warnings.warn("Grid has negative values; this will produce NaN loss")
grid = torch.log(grid + eps)
# this is an isotropic version, note that it ignores last rows/cols
return torch.mean(
utils.safe_sqrt(
(grid[..., :-1, :-1, 1:] - grid[..., :-1, :-1, :-1]) ** 2
+ (grid[..., :-1, 1:, :-1] - grid[..., :-1, :-1, :-1]) ** 2
+ (grid[..., 1:, :-1, :-1] - grid[..., :-1, :-1, :-1]) ** 2,
eps=1e-5,
)
)
def _get_depth_neg_penalty_loss(depth):
neg_penalty = depth.clamp(min=None, max=0.0) ** 2
return torch.mean(neg_penalty)
def _reshape_nongrid_var(x):
if x is None:
return None
ba, *_, dim = x.shape
return x.reshape(ba, -1, 1, dim).permute(0, 3, 1, 2).contiguous()
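
Two of the regularizers above are simple closed-form expressions; the standalone sketch below mirrors `_get_eikonal_loss` and `_get_depth_neg_penalty_loss` in plain PyTorch (function names and shapes are illustrative):

```python
import torch

def eikonal_loss(grad_theta: torch.Tensor) -> torch.Tensor:
    # (||grad_theta|| - 1)^2 averaged; encourages a unit-norm SDF gradient
    return ((grad_theta.norm(2, dim=1) - 1.0) ** 2).mean()

def depth_neg_penalty(depth: torch.Tensor) -> torch.Tensor:
    # min(depth, 0)^2 averaged; penalizes negative predicted depths
    return (depth.clamp(max=0.0) ** 2).mean()

print(eikonal_loss(torch.randn(4, 3, 1024)), depth_neg_penalty(torch.randn(4, 1, 1024)))
```
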


@ -0,0 +1,139 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List
import torch
from pytorch3d.implicitron.dataset.utils import is_known_frame
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
NewViewSynthesisPrediction,
)
from pytorch3d.implicitron.tools.point_cloud_utils import (
get_rgbd_point_cloud,
render_point_cloud_pytorch3d,
)
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
class ModelDBIR(torch.nn.Module):
"""
A simple depth-based image rendering model.
"""
def __init__(
self,
image_size: int = 256,
bg_color: float = 0.0,
max_points: int = -1,
):
"""
Initializes a simple DBIR model.
Args:
image_size: The height and width of the rendered (square) images.
bg_color: The color of the background.
max_points: Maximum number of points in the point cloud
formed by unprojecting all source view depths.
If more points are present, they are randomly subsampled
to `max_points` points without replacement.
"""
super().__init__()
self.image_size = image_size
self.bg_color = bg_color
self.max_points = max_points
def forward(
self,
camera: CamerasBase,
image_rgb: torch.Tensor,
depth_map: torch.Tensor,
fg_probability: torch.Tensor,
frame_type: List[str],
**kwargs,
) -> Dict[str, Any]: # TODO: return a namedtuple or dataclass
"""
Given a set of source cameras, images, and depth maps, unprojects
all RGBD maps to a colored point cloud and renders into the target views.
Args:
camera: A batch of `N` PyTorch3D cameras.
image_rgb: A batch of `N` images of shape `(N, 3, H, W)`.
depth_map: A batch of `N` depth maps of shape `(N, 1, H, W)`.
fg_probability: A batch of `N` foreground probability maps
of shape `(N, 1, H, W)`.
frame_type: A list of `N` strings containing frame type indicators
which specify target and source views.
Returns:
preds: A dict with the following fields:
nvs_prediction: The rendered colors, depth and mask
of the target views.
point_cloud: The point cloud of the scene. Its renders are
stored in `nvs_prediction`.
"""
is_known = is_known_frame(frame_type)
is_known_idx = torch.where(is_known)[0]
mask_fg = (fg_probability > 0.5).type_as(image_rgb)
point_cloud = get_rgbd_point_cloud(
camera[is_known_idx],
image_rgb[is_known_idx],
depth_map[is_known_idx],
mask_fg[is_known_idx],
)
pcl_size = int(point_cloud.num_points_per_cloud())
if (self.max_points > 0) and (pcl_size > self.max_points):
prm = torch.randperm(pcl_size)[: self.max_points]
point_cloud = Pointclouds(
point_cloud.points_padded()[:, prm, :],
# pyre-fixme[16]: Optional type has no attribute `__getitem__`.
features=point_cloud.features_padded()[:, prm, :],
)
is_target_idx = torch.where(~is_known)[0]
depth_render, image_render, mask_render = [], [], []
# render into target frames in a for loop to save memory
for tgt_idx in is_target_idx:
_image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d(
camera[int(tgt_idx)],
point_cloud,
render_size=(self.image_size, self.image_size),
point_radius=1e-2,
topk=10,
bg_color=self.bg_color,
)
_image_render = _image_render.clamp(0.0, 1.0)
# the mask is the set of pixels with opacity bigger than eps
_mask_render = (_mask_render > 1e-4).float()
depth_render.append(_depth_render)
image_render.append(_image_render)
mask_render.append(_mask_render)
nvs_prediction = NewViewSynthesisPrediction(
**{
k: torch.cat(v, dim=0)
for k, v in zip(
["depth_render", "image_render", "mask_render"],
[depth_render, image_render, mask_render],
)
}
)
preds = {
"nvs_prediction": nvs_prediction,
"point_cloud": point_cloud,
}
return preds
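
The `max_points` subsampling step reduces to indexing with a random permutation (sampling without replacement). A standalone equivalent on raw tensors, without the `Pointclouds` wrapper (illustrative sizes):

```python
import torch

points = torch.randn(1, 50_000, 3)    # padded points of a single cloud
features = torch.rand(1, 50_000, 3)   # per-point RGB features
max_points = 10_000

perm = torch.randperm(points.shape[1])[:max_points]  # random subset, no replacement
points, features = points[:, perm, :], features[:, perm, :]
# points.shape == features.shape == torch.Size([1, 10000, 3])
```
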


@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
import torch
from pytorch3d.implicitron.tools.config import ReplaceableBase
class EvaluationMode(Enum):
TRAINING = "training"
EVALUATION = "evaluation"
class RenderSamplingMode(Enum):
MASK_SAMPLE = "mask_sample"
FULL_GRID = "full_grid"
@dataclass
class RendererOutput:
"""
A structure for storing the output of a renderer.
Args:
features: rendered features (usually RGB colors), (B, ..., C) tensor.
depths: rendered ray-termination depth map, in NDC coordinates, (B, ..., 1) tensor.
masks: rendered object mask, values in [0, 1], (B, ..., 1) tensor.
prev_stage: for multi-pass renderers (e.g. in NeRF),
a reference to the output of the previous stage.
normals: surface normals, for renderers that estimate them; (B, ..., 3) tensor.
points: ray-termination points in the world coordinates, (B, ..., 3) tensor.
aux: dict for implementation-specific renderer outputs.
"""
features: torch.Tensor
depths: torch.Tensor
masks: torch.Tensor
prev_stage: Optional[RendererOutput] = None
normals: Optional[torch.Tensor] = None
points: Optional[torch.Tensor] = None # TODO: redundant with depths
aux: Dict[str, Any] = field(default_factory=lambda: {})
class ImplicitFunctionWrapper(torch.nn.Module):
def __init__(self, fn: torch.nn.Module):
super().__init__()
self._fn = fn
self.bound_args = {}
def bind_args(self, **bound_args):
self.bound_args = bound_args
self._fn.on_bind_args()
def unbind_args(self):
self.bound_args = {}
def forward(self, *args, **kwargs):
return self._fn(*args, **{**kwargs, **self.bound_args})
class BaseRenderer(ABC, ReplaceableBase):
"""
Base class for all Renderer implementations.
"""
def __init__(self):
super().__init__()
@abstractmethod
def forward(
self,
ray_bundle,
implicit_functions: List[ImplicitFunctionWrapper],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs
) -> RendererOutput:
"""
Each Renderer should implement its own forward function
that returns an instance of RendererOutput.
Args:
ray_bundle: A RayBundle object containing the following variables:
origins: A tensor of shape (minibatch, ..., 3) denoting
the origins of the rendering rays.
directions: A tensor of shape (minibatch, ..., 3)
containing the direction vectors of rendering rays.
lengths: A tensor of shape
(minibatch, ..., num_points_per_ray) containing the
lengths at which the ray points are sampled.
The coordinates of the points on the rays are thus computed
as `origins + lengths * directions`.
xys: A tensor of shape
(minibatch, ..., 2) containing the
xy locations of each ray's pixel in the NDC screen space.
implicit_functions: List of ImplicitFunctionWrappers which define the
implicit function methods to be used. Most Renderers only allow
a single implicit function. Currently, only the MultiPassEARenderer
allows specifying multiple values in the list.
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering.
**kwargs: In addition to the named args, custom keyword args can be specified.
For example in the SignedDistanceFunctionRenderer, an object_mask is
required which needs to be passed via the kwargs.
Returns:
instance of RendererOutput
"""
pass
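# The following self-contained sketch (not part of the original module; tensor
# shapes and values are illustrative assumptions) shows how a two-pass renderer
# is expected to chain its outputs through `prev_stage`.
if __name__ == "__main__":
    n_rays, n_feat = 128, 3
    coarse = RendererOutput(
        features=torch.rand(1, n_rays, n_feat),
        depths=torch.rand(1, n_rays, 1),
        masks=torch.ones(1, n_rays, 1),
    )
    fine = RendererOutput(
        features=torch.rand(1, n_rays, n_feat),
        depths=torch.rand(1, n_rays, 1),
        masks=torch.ones(1, n_rays, 1),
        prev_stage=coarse,  # the fine pass keeps a reference to the coarse pass
    )
    assert fine.prev_stage is coarse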

View File

@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Tuple
import torch
from pytorch3d.implicitron.tools.config import registry
from pytorch3d.renderer import RayBundle
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
@registry.register
class LSTMRenderer(BaseRenderer, torch.nn.Module):
"""
Implements the learnable LSTM raymarching function from SRN [1].
Settings:
num_raymarch_steps: The number of LSTM raymarching steps.
init_depth: Initializes the bias of the last raymarching LSTM layer so that
the farthest point from the camera reaches a far z-plane that
lies `init_depth` units from the camera plane.
init_depth_noise_std: The standard deviation of the random normal noise
added to the initial depth of each marched ray.
hidden_size: The dimensionality of the LSTM's hidden state.
n_feature_channels: The number of feature channels returned by the
implicit_function evaluated at each raymarching step.
verbose: If `True`, prints raymarching debug info.
References:
[1] Sitzmann, V. and Zollhöfer, M. and Wetzstein, G..
"Scene representation networks: Continuous 3d-structure-aware
neural scene representations." NeurIPS 2019.
"""
num_raymarch_steps: int = 10
init_depth: float = 17.0
init_depth_noise_std: float = 5e-4
hidden_size: int = 16
n_feature_channels: int = 256
verbose: bool = False
def __post_init__(self):
super().__init__()
self._lstm = torch.nn.LSTMCell(
input_size=self.n_feature_channels,
hidden_size=self.hidden_size,
)
self._lstm.apply(_init_recurrent_weights)
_lstm_forget_gate_init(self._lstm)
self._out_layer = torch.nn.Linear(self.hidden_size, 1)
one_step = self.init_depth / self.num_raymarch_steps
self._out_layer.bias.data.fill_(one_step)
self._out_layer.weight.data.normal_(mean=0.0, std=1e-3)
def forward(
self,
ray_bundle: RayBundle,
implicit_functions: List[ImplicitFunctionWrapper],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs,
) -> RendererOutput:
"""
Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
implicit_functions: A single-element list of ImplicitFunctionWrappers which
defines the implicit function to be used.
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering, specifically the RayPointRefiner and the density_noise_std.
Returns:
instance of RendererOutput
"""
if len(implicit_functions) != 1:
raise ValueError("LSTM renderer expects a single implicit function.")
implicit_function = implicit_functions[0]
if ray_bundle.lengths.shape[-1] != 1:
raise ValueError(
"LSTM renderer requires a ray-bundle with a single point per ray"
+ " which is the initial raymarching point."
)
# jitter the initial depths
ray_bundle_t = ray_bundle._replace(
lengths=ray_bundle.lengths
+ torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std
)
states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None]
signed_distance = torch.zeros_like(ray_bundle_t.lengths)
raymarch_features = None
for t in range(self.num_raymarch_steps + 1):
# move signed_distance along each ray
ray_bundle_t = ray_bundle_t._replace(
lengths=ray_bundle_t.lengths + signed_distance
)
# eval the raymarching function
raymarch_features, _ = implicit_function(
ray_bundle_t,
raymarch_features=None,
)
if self.verbose:
# print some stats
print(
f"{t}: mu={float(signed_distance.mean()):1.2e};"
+ f" std={float(signed_distance.std()):1.2e};"
# pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
# typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
# param but got `Tensor`.
+ f" mu_d={float(ray_bundle_t.lengths.mean()):1.2e};"
# pyre-fixme[6]: Expected `Union[bytearray, bytes, str,
# typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st
# param but got `Tensor`.
+ f" std_d={float(ray_bundle_t.lengths.std()):1.2e};"
)
if t == self.num_raymarch_steps:
break
# run the lstm marcher
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
state_h, state_c = self._lstm(
raymarch_features.view(-1, raymarch_features.shape[-1]),
states[-1],
)
if state_h.requires_grad:
state_h.register_hook(lambda x: x.clamp(min=-10, max=10))
# predict the next step size
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
signed_distance = self._out_layer(state_h).view(ray_bundle_t.lengths.shape)
# log the lstm states
states.append((state_h, state_c))
opacity_logits, features = implicit_function(
raymarch_features=raymarch_features,
ray_bundle=ray_bundle_t,
)
mask = torch.sigmoid(opacity_logits)
depth = ray_bundle_t.lengths * ray_bundle_t.directions.norm(
dim=-1, keepdim=True
)
return RendererOutput(
features=features[..., 0, :],
depths=depth,
masks=mask[..., 0, :],
)
def _init_recurrent_weights(self) -> None:
# copied from SRN codebase
for m in self.modules():
if type(m) in [torch.nn.GRU, torch.nn.LSTM, torch.nn.RNN]:
for name, param in m.named_parameters():
if "weight_ih" in name:
torch.nn.init.kaiming_normal_(param.data)
elif "weight_hh" in name:
torch.nn.init.orthogonal_(param.data)
elif "bias" in name:
param.data.fill_(0)
def _lstm_forget_gate_init(lstm_layer) -> None:
# copied from SRN codebase
for name, parameter in lstm_layer.named_parameters():
if "bias" not in name:
continue
n = parameter.size(0)
start, end = n // 4, n // 2
parameter.data[start:end].fill_(1.0)

View File

@ -0,0 +1,171 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Tuple
import torch
from pytorch3d.implicitron.tools.config import registry
from .base import BaseRenderer, EvaluationMode, RendererOutput
from .ray_point_refiner import RayPointRefiner
from .raymarcher import GenericRaymarcher
@registry.register
class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module):
"""
Implements the multi-pass rendering function, in particular,
with emission-absorption ray marching used in NeRF [1]. First, it evaluates
opacity-based ray-point weights and then optionally (in case more implicit
functions are given) resamples points using importance sampling and evaluates
new weights.
During each ray marching pass, features, depth map, and masks
are integrated: Let o_i be the opacity estimated by the implicit function,
and d_i be the offset between points `i` and `i+1` along the respective ray.
Ray marching is performed using the following equations:
```
ray_opacity_n = cap_fn(sum_i=1^n cap_fn(d_i * o_i)),
weight_n = weight_fn(cap_fn(d_i * o_i), 1 - ray_opacity_{n-1}),
```
and the final rendered quantities are computed by a dot-product of ray values
with the weights, e.g. `features = sum_n(weight_n * ray_features_n)`.
See below for possible values of `cap_fn` and `weight_fn`.
Settings:
n_pts_per_ray_fine_training: The number of points sampled per ray for the
fine rendering pass during training.
n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the
fine rendering pass during evaluation.
stratified_sampling_coarse_training: Enable/disable stratified sampling during
training.
stratified_sampling_coarse_evaluation: Enable/disable stratified sampling during
evaluation.
append_coarse_samples_to_fine: Add the fine ray points to the coarse points
after sampling.
bg_color: The background color. A tuple of either 1 element or of D elements,
where D matches the feature dimensionality; it is broadcasted when necessary.
density_noise_std_train: Standard deviation of the noise added to the
opacity field.
capping_function: The capping function of the raymarcher.
Options:
- "exponential" (`cap_fn(x) = 1 - exp(-x)`)
- "cap1" (`cap_fn(x) = min(x, 1)`)
Set to "exponential" for the standard Emission Absorption raymarching.
weight_function: The weighting function of the raymarcher.
Options:
- "product" (`weight_fn(w, x) = w * x`)
- "minimum" (`weight_fn(w, x) = min(w, x)`)
Set to "product" for the standard Emission Absorption raymarching.
background_opacity: The raw opacity value (i.e. before exponentiation)
of the background.
blend_output: If `True`, alpha-blends the output renders with the
background color using the rendered opacity mask.
References:
[1] Mildenhall, Ben, et al. "Nerf: Representing scenes as neural radiance
fields for view synthesis." ECCV 2020.
"""
n_pts_per_ray_fine_training: int = 64
n_pts_per_ray_fine_evaluation: int = 64
stratified_sampling_coarse_training: bool = True
stratified_sampling_coarse_evaluation: bool = False
append_coarse_samples_to_fine: bool = True
bg_color: Tuple[float, ...] = (0.0,)
density_noise_std_train: float = 0.0
capping_function: str = "exponential" # exponential | cap1
weight_function: str = "product" # product | minimum
background_opacity: float = 1e10
blend_output: bool = False
def __post_init__(self):
super().__init__()
self._refiners = {
EvaluationMode.TRAINING: RayPointRefiner(
n_pts_per_ray=self.n_pts_per_ray_fine_training,
random_sampling=self.stratified_sampling_coarse_training,
add_input_samples=self.append_coarse_samples_to_fine,
),
EvaluationMode.EVALUATION: RayPointRefiner(
n_pts_per_ray=self.n_pts_per_ray_fine_evaluation,
random_sampling=self.stratified_sampling_coarse_evaluation,
add_input_samples=self.append_coarse_samples_to_fine,
),
}
self._raymarcher = GenericRaymarcher(
1,
self.bg_color,
capping_function=self.capping_function,
weight_function=self.weight_function,
background_opacity=self.background_opacity,
blend_output=self.blend_output,
)
def forward(
self,
ray_bundle,
implicit_functions=[],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
**kwargs
) -> RendererOutput:
"""
Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
implicit_functions: List of ImplicitFunctionWrappers which
define the implicit functions to be used sequentially in
the raymarching step. The output of raymarching with
implicit_functions[n-1] is refined, and then used as
input for raymarching with implicit_functions[n].
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering
Returns:
instance of RendererOutput
"""
if not implicit_functions:
raise ValueError("EA renderer expects implicit functions")
return self._run_raymarcher(
ray_bundle,
implicit_functions,
None,
evaluation_mode,
)
def _run_raymarcher(
self, ray_bundle, implicit_functions, prev_stage, evaluation_mode
):
density_noise_std = (
self.density_noise_std_train
if evaluation_mode == EvaluationMode.TRAINING
else 0.0
)
features, depth, mask, weights, aux = self._raymarcher(
*implicit_functions[0](ray_bundle),
ray_lengths=ray_bundle.lengths,
density_noise_std=density_noise_std,
)
output = RendererOutput(
features=features, depths=depth, masks=mask, aux=aux, prev_stage=prev_stage
)
# we may need to make a recursive call
if len(implicit_functions) > 1:
fine_ray_bundle = self._refiners[evaluation_mode](ray_bundle, weights)
output = self._run_raymarcher(
fine_ray_bundle,
implicit_functions[1:],
output,
evaluation_mode,
)
return output
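# A hand-computed sketch (illustrative only; the densities and segment lengths are
# assumed values) of the emission-absorption weights obtained with the default
# "exponential" capping and "product" weighting functions documented above.
if __name__ == "__main__":
    d = torch.tensor([0.1, 0.1, 0.1])  # distances between consecutive ray points
    o = torch.tensor([0.5, 2.0, 8.0])  # raw (non-exponentiated) densities
    cap = 1.0 - torch.exp(-d * o)  # cap_fn(d_i * o_i)
    # transmittance accumulated up to (but excluding) each point
    transmittance = torch.cumprod(torch.cat([torch.ones(1), 1.0 - cap[:-1]]), dim=0)
    weights = cap * transmittance
    # a rendered quantity would be (weights[:, None] * ray_features).sum(dim=0);
    # the per-ray weights never sum above 1
    assert float(weights.sum()) <= 1.0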

View File

@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields
from pytorch3d.renderer import RayBundle
from pytorch3d.renderer.implicit.sample_pdf import sample_pdf
@expand_args_fields
# pyre-fixme[13]: Attribute `n_pts_per_ray` is never initialized.
# pyre-fixme[13]: Attribute `random_sampling` is never initialized.
class RayPointRefiner(Configurable, torch.nn.Module):
"""
Implements the importance sampling of points along rays.
The input is a `RayBundle` object with a `ray_weights` tensor
which specifies the probabilities of sampling a point along each ray.
This raysampler is used for the fine rendering pass of NeRF.
As such, the forward pass accepts the RayBundle output by the
raysampling of the coarse rendering pass. Hence, it does not
take cameras as input.
Args:
n_pts_per_ray: The number of points to sample along each ray.
random_sampling: If `False`, returns equispaced percentiles of the
distribution defined by the input weights, otherwise performs
sampling from that distribution.
add_input_samples: Concatenates and returns the sampled values
together with the input samples.
"""
n_pts_per_ray: int
random_sampling: bool
add_input_samples: bool = True
def __post_init__(self) -> None:
super().__init__()
def forward(
self,
input_ray_bundle: RayBundle,
ray_weights: torch.Tensor,
**kwargs,
) -> RayBundle:
"""
Args:
input_ray_bundle: An instance of `RayBundle` specifying the
source rays for sampling of the probability distribution.
ray_weights: A tensor of shape
`(..., input_ray_bundle.lengths.shape[-1])` with non-negative
elements defining the probability distribution to sample
ray points from.
Returns:
ray_bundle: A new `RayBundle` instance containing the input ray
points together with `n_pts_per_ray` additionally sampled
points per ray. For each ray, the lengths are sorted.
"""
z_vals = input_ray_bundle.lengths
with torch.no_grad():
z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5)
z_samples = sample_pdf(
z_vals_mid.view(-1, z_vals_mid.shape[-1]),
ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1],
self.n_pts_per_ray,
det=not self.random_sampling,
).view(*z_vals.shape[:-1], self.n_pts_per_ray)
if self.add_input_samples:
# Add the new samples to the input ones.
z_vals = torch.cat((z_vals, z_samples), dim=-1)
else:
z_vals = z_samples
# Resort by depth.
z_vals, _ = torch.sort(z_vals, dim=-1)
return RayBundle(
origins=input_ray_bundle.origins,
directions=input_ray_bundle.directions,
lengths=z_vals,
xys=input_ray_bundle.xys,
)
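# A self-contained sketch (illustrative only; the ray counts and the uniform
# weights are assumed) of refining a coarse ray bundle with 16 importance-sampled
# points per ray.
if __name__ == "__main__":
    n_rays, n_coarse = 4, 32
    coarse_bundle = RayBundle(
        origins=torch.zeros(1, n_rays, 3),
        directions=torch.nn.functional.normalize(torch.rand(1, n_rays, 3), dim=-1),
        lengths=torch.linspace(0.1, 8.0, n_coarse).expand(1, n_rays, n_coarse),
        xys=torch.zeros(1, n_rays, 2),
    )
    refiner = RayPointRefiner(n_pts_per_ray=16, random_sampling=False)
    refined_bundle = refiner(coarse_bundle, ray_weights=torch.ones(1, n_rays, n_coarse))
    # the refined bundle keeps the 32 input samples and adds 16 new sorted ones
    assert refined_bundle.lengths.shape == (1, n_rays, n_coarse + 16)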

View File

@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import field
from typing import Optional, Tuple
import torch
from pytorch3d.implicitron.tools import camera_utils
from pytorch3d.implicitron.tools.config import Configurable
from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle
from pytorch3d.renderer.cameras import CamerasBase
from .base import EvaluationMode, RenderSamplingMode
class RaySampler(Configurable, torch.nn.Module):
"""
Samples a fixed number of points along rays which are in turn sampled for
each camera in a batch.
This class utilizes `NDCMultinomialRaysampler`, which allows either randomly
sampling rays from an input foreground saliency mask
(`RenderSamplingMode.MASK_SAMPLE`), or sampling on a rectangular image grid
(`RenderSamplingMode.FULL_GRID`). The sampling mode can be set separately
for training and evaluation by setting `self.sampling_mode_training`
and `self.sampling_mode_evaluation` accordingly.
The class allows two modes of sampling points along the rays:
1) Sampling between fixed near and far z-planes:
Active when `self.scene_extent <= 0`, samples points along each ray
with approximately uniform spacing of z-coordinates between
the minimum depth `self.min_depth` and the maximum depth `self.max_depth`.
This sampling is useful for rendering scenes where the camera is
at a constant distance from the focal point of the scene.
2) Adaptive near/far plane estimation around the world scene center:
Active when `self.scene_extent > 0`. Samples points on each
ray between near and far planes whose depths are determined based on
the distance from the camera center to a predefined scene center.
More specifically,
`min_depth = max(
(self.scene_center-camera_center).norm() - self.scene_extent, eps
)` and
`max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`.
This sampling is ideal for object-centric scenes whose contents are
centered around a known `self.scene_center` and fit into a bounding sphere
with a radius of `self.scene_extent`.
Similar to the sampling mode, the sampling parameters can be set separately
for training and evaluation.
Settings:
image_width: The horizontal size of the image grid.
image_height: The vertical size of the image grid.
scene_center: The xyz coordinates of the center of the scene used
along with `scene_extent` to compute the min and max depth planes
for sampling ray-points.
scene_extent: The radius of the scene bounding sphere centered at `scene_center`.
If `scene_extent <= 0`, the raysampler samples points between
`self.min_depth` and `self.max_depth` depths instead.
sampling_mode_training: The ray sampling mode for training. This should be a str
option from the `RenderSamplingMode` enum.
sampling_mode_evaluation: Same as above but for evaluation.
n_pts_per_ray_training: The number of points sampled along each ray during training.
n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation.
n_rays_per_image_sampled_from_mask: The number of rays to be sampled from the image grid.
min_depth: The minimum depth of a ray-point. Active when `self.scene_extent > 0`.
max_depth: The maximum depth of a ray-point. Active when `self.scene_extent > 0`.
stratified_point_sampling_training: if set, performs stratified random sampling
along the ray; otherwise takes ray points at deterministic offsets.
stratified_point_sampling_evaluation: Same as above but for evaluation.
"""
image_width: int = 400
image_height: int = 400
scene_center: Tuple[float, float, float] = field(
default_factory=lambda: (0.0, 0.0, 0.0)
)
scene_extent: float = 0.0
sampling_mode_training: str = "mask_sample"
sampling_mode_evaluation: str = "full_grid"
n_pts_per_ray_training: int = 64
n_pts_per_ray_evaluation: int = 64
n_rays_per_image_sampled_from_mask: int = 1024
min_depth: float = 0.1
max_depth: float = 8.0
# stratified sampling vs taking points at deterministic offsets
stratified_point_sampling_training: bool = True
stratified_point_sampling_evaluation: bool = False
def __post_init__(self):
super().__init__()
self.scene_center = torch.FloatTensor(self.scene_center)
self._sampling_mode = {
EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training),
EvaluationMode.EVALUATION: RenderSamplingMode(
self.sampling_mode_evaluation
),
}
self._raysamplers = {
EvaluationMode.TRAINING: NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_training,
min_depth=self.min_depth,
max_depth=self.max_depth,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.TRAINING]
== RenderSamplingMode.MASK_SAMPLE
else None,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_training,
),
EvaluationMode.EVALUATION: NDCMultinomialRaysampler(
image_width=self.image_width,
image_height=self.image_height,
n_pts_per_ray=self.n_pts_per_ray_evaluation,
min_depth=self.min_depth,
max_depth=self.max_depth,
n_rays_per_image=self.n_rays_per_image_sampled_from_mask
if self._sampling_mode[EvaluationMode.EVALUATION]
== RenderSamplingMode.MASK_SAMPLE
else None,
unit_directions=True,
stratified_sampling=self.stratified_point_sampling_evaluation,
),
}
def forward(
self,
cameras: CamerasBase,
evaluation_mode: EvaluationMode,
mask: Optional[torch.Tensor] = None,
) -> RayBundle:
"""
Args:
cameras: A batch of `batch_size` cameras from which the rays are emitted.
evaluation_mode: one of `EvaluationMode.TRAINING` or
`EvaluationMode.EVALUATION` which determines the sampling mode
that is used.
mask: Active for the `RenderSamplingMode.MASK_SAMPLE` sampling mode.
Defines a non-negative mask of shape
`(batch_size, image_height, image_width)` where each per-pixel
value is proportional to the probability of sampling the
corresponding pixel's ray.
Returns:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
"""
sample_mask = None
if (
# pyre-fixme[29]
self._sampling_mode[evaluation_mode] == RenderSamplingMode.MASK_SAMPLE
and mask is not None
):
sample_mask = torch.nn.functional.interpolate(
mask,
# pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
# `List[int]`.
size=[self.image_height, self.image_width],
mode="nearest",
)[:, 0]
if self.scene_extent > 0.0:
# Override the min/max depth set in initialization based on the
# input cameras.
min_depth, max_depth = camera_utils.get_min_max_depth_bounds(
cameras, self.scene_center, self.scene_extent
)
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
# torch.Tensor, torch.nn.Module]` is not a function.
ray_bundle = self._raysamplers[evaluation_mode](
cameras=cameras,
mask=sample_mask,
min_depth=float(min_depth[0]) if self.scene_extent > 0.0 else None,
max_depth=float(max_depth[0]) if self.scene_extent > 0.0 else None,
)
return ray_bundle
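# A small numeric sketch (illustrative only; the camera distance and scene radius
# are assumed values) of the adaptive near/far planes described in mode (2) above.
if __name__ == "__main__":
    eps = 1e-4
    scene_center = torch.tensor([0.0, 0.0, 0.0])
    scene_extent = 1.0
    camera_center = torch.tensor([0.0, 0.0, 5.0])
    center_dist = (scene_center - camera_center).norm()
    min_depth = torch.clamp(center_dist - scene_extent, min=eps)  # -> 4.0
    max_depth = center_dist + scene_extent  # -> 6.0
    assert float(min_depth) == 4.0 and float(max_depth) == 6.0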

View File

@ -0,0 +1,573 @@
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/lioryariv/idr
# Copyright (c) 2020 Lior Yariv
from typing import Any, Callable, Tuple
import torch
import torch.nn as nn
from pytorch3d.implicitron.tools.config import Configurable
class RayTracing(Configurable, nn.Module):
"""
Finds the intersection points of rays with the implicit surface defined
by a signed distance function (SDF). The algorithm follows the pipeline:
1. Initialise start and end points on rays by the intersections with
the circumscribing sphere.
2. Run sphere tracing from both ends.
3. Divide the untraced segments of non-convergent rays into uniform
intervals and find the one with the sign transition.
4. Run the secant method to estimate the point of the sign transition.
Args:
object_bounding_sphere: The radius of the initial sphere circumscribing
the object.
sdf_threshold: Absolute SDF value small enough for the sphere tracer
to consider it a surface.
line_search_step: Length of the backward correction on sphere tracing
iterations.
line_step_iters: Number of backward correction iterations.
sphere_tracing_iters: Maximum number of sphere tracing iterations
(the actual number of iterations may be smaller if all ray
intersections are found).
n_steps: Number of intervals sampled for non-convergent rays.
n_secant_steps: Number of iterations in the secant algorithm.
"""
object_bounding_sphere: float = 1.0
sdf_threshold: float = 5.0e-5
line_search_step: float = 0.5
line_step_iters: int = 1
sphere_tracing_iters: int = 10
n_steps: int = 100
n_secant_steps: int = 8
def __post_init__(self):
super().__init__()
def forward(
self,
sdf: Callable[[torch.Tensor], torch.Tensor],
cam_loc: torch.Tensor,
object_mask: torch.BoolTensor,
ray_directions: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
sdf: A callable that takes a (N, 3) tensor of points and returns
a tensor of (N,) SDF values.
cam_loc: A tensor of (B, N, 3) ray origins.
object_mask: A (B, N) tensor indicating whether a sampled pixel
corresponds to the rendered object or the background.
ray_directions: A tensor of (B, N, 3) ray directions.
Returns:
curr_start_points: A (B*N, 3) tensor of the intersection points found
on the implicit surface.
network_object_mask: A tensor of (B*N,) indicators denoting whether
intersections were found.
acc_start_dis: A (B*N,) tensor of distances from the ray origins
to the intersection points.
"""
batch_size, num_pixels, _ = ray_directions.shape
device = cam_loc.device
sphere_intersections, mask_intersect = _get_sphere_intersection(
cam_loc, ray_directions, r=self.object_bounding_sphere
)
(
curr_start_points,
unfinished_mask_start,
acc_start_dis,
acc_end_dis,
min_dis,
max_dis,
) = self.sphere_tracing(
batch_size,
num_pixels,
sdf,
cam_loc,
ray_directions,
mask_intersect,
sphere_intersections,
)
network_object_mask = acc_start_dis < acc_end_dis
# The non-convergent rays should be handled by the sampler
sampler_mask = unfinished_mask_start
sampler_net_obj_mask = torch.zeros_like(
sampler_mask, dtype=torch.bool, device=device
)
if sampler_mask.sum() > 0:
sampler_min_max = torch.zeros((batch_size, num_pixels, 2), device=device)
sampler_min_max.reshape(-1, 2)[sampler_mask, 0] = acc_start_dis[
sampler_mask
]
sampler_min_max.reshape(-1, 2)[sampler_mask, 1] = acc_end_dis[sampler_mask]
sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler(
sdf, cam_loc, object_mask, ray_directions, sampler_min_max, sampler_mask
)
curr_start_points[sampler_mask] = sampler_pts[sampler_mask]
acc_start_dis[sampler_mask] = sampler_dists[sampler_mask]
network_object_mask[sampler_mask] = sampler_net_obj_mask[sampler_mask]
if not self.training:
return curr_start_points, network_object_mask, acc_start_dis
# in case we are training, we also update curr_start_points and acc_start_dis for
# the remaining rays that did not hit the surface (handled below)
ray_directions = ray_directions.reshape(-1, 3)
mask_intersect = mask_intersect.reshape(-1)
object_mask = object_mask.reshape(-1)
in_mask = ~network_object_mask & object_mask & ~sampler_mask
out_mask = ~object_mask & ~sampler_mask
# pyre-fixme[16]: `Tensor` has no attribute `__invert__`.
mask_left_out = (in_mask | out_mask) & ~mask_intersect
if (
mask_left_out.sum() > 0
): # project the origin to the not intersect points on the sphere
cam_left_out = cam_loc.reshape(-1, 3)[mask_left_out]
rays_left_out = ray_directions[mask_left_out]
acc_start_dis[mask_left_out] = -torch.bmm(
rays_left_out.view(-1, 1, 3), cam_left_out.view(-1, 3, 1)
).squeeze()
curr_start_points[mask_left_out] = (
cam_left_out + acc_start_dis[mask_left_out].unsqueeze(1) * rays_left_out
)
mask = (in_mask | out_mask) & mask_intersect
if mask.sum() > 0:
min_dis[network_object_mask & out_mask] = acc_start_dis[
network_object_mask & out_mask
]
min_mask_points, min_mask_dist = self.minimal_sdf_points(
sdf, cam_loc, ray_directions, mask, min_dis, max_dis
)
curr_start_points[mask] = min_mask_points
acc_start_dis[mask] = min_mask_dist
return curr_start_points, network_object_mask, acc_start_dis
def sphere_tracing(
self,
batch_size: int,
num_pixels: int,
sdf: Callable[[torch.Tensor], torch.Tensor],
cam_loc: torch.Tensor,
ray_directions: torch.Tensor,
mask_intersect: torch.Tensor,
sphere_intersections: torch.Tensor,
) -> Tuple[Any, Any, Any, Any, Any, Any]:
"""
Runs the sphere tracing algorithm for at most `sphere_tracing_iters` iterations,
starting from both sides of the bounding-sphere intersection.
Args:
batch_size:
num_pixels:
sdf:
cam_loc:
ray_directions:
mask_intersect:
sphere_intersections:
Returns:
curr_start_points:
unfinished_mask_start:
acc_start_dis:
acc_end_dis:
min_dis:
max_dis:
"""
device = cam_loc.device
sphere_intersections_points = (
cam_loc[..., None, :]
+ sphere_intersections[..., None] * ray_directions[..., None, :]
)
unfinished_mask_start = mask_intersect.reshape(-1).clone()
unfinished_mask_end = mask_intersect.reshape(-1).clone()
# Initialize start current points
curr_start_points = torch.zeros(batch_size * num_pixels, 3, device=device)
curr_start_points[unfinished_mask_start] = sphere_intersections_points[
:, :, 0, :
].reshape(-1, 3)[unfinished_mask_start]
acc_start_dis = torch.zeros(batch_size * num_pixels, device=device)
acc_start_dis[unfinished_mask_start] = sphere_intersections.reshape(-1, 2)[
unfinished_mask_start, 0
]
# Initialize end current points
curr_end_points = torch.zeros(batch_size * num_pixels, 3, device=device)
curr_end_points[unfinished_mask_end] = sphere_intersections_points[
:, :, 1, :
].reshape(-1, 3)[unfinished_mask_end]
acc_end_dis = torch.zeros(batch_size * num_pixels, device=device)
acc_end_dis[unfinished_mask_end] = sphere_intersections.reshape(-1, 2)[
unfinished_mask_end, 1
]
# Initialise min and max depth
min_dis = acc_start_dis.clone()
max_dis = acc_end_dis.clone()
# Iterate on the rays (from both sides) till finding a surface
iters = 0
# TODO: sdf should also pass info about batches
next_sdf_start = torch.zeros_like(acc_start_dis)
next_sdf_start[unfinished_mask_start] = sdf(
curr_start_points[unfinished_mask_start]
)
next_sdf_end = torch.zeros_like(acc_end_dis)
next_sdf_end[unfinished_mask_end] = sdf(curr_end_points[unfinished_mask_end])
while True:
# Update sdf
curr_sdf_start = torch.zeros_like(acc_start_dis)
curr_sdf_start[unfinished_mask_start] = next_sdf_start[
unfinished_mask_start
]
curr_sdf_start[curr_sdf_start <= self.sdf_threshold] = 0
curr_sdf_end = torch.zeros_like(acc_end_dis)
curr_sdf_end[unfinished_mask_end] = next_sdf_end[unfinished_mask_end]
curr_sdf_end[curr_sdf_end <= self.sdf_threshold] = 0
# Update masks
unfinished_mask_start = unfinished_mask_start & (
curr_sdf_start > self.sdf_threshold
)
unfinished_mask_end = unfinished_mask_end & (
curr_sdf_end > self.sdf_threshold
)
if (
unfinished_mask_start.sum() == 0 and unfinished_mask_end.sum() == 0
) or iters == self.sphere_tracing_iters:
break
iters += 1
# Make step
# Update distance
acc_start_dis = acc_start_dis + curr_sdf_start
acc_end_dis = acc_end_dis - curr_sdf_end
# Update points
curr_start_points = (
cam_loc
+ acc_start_dis.reshape(batch_size, num_pixels, 1) * ray_directions
).reshape(-1, 3)
curr_end_points = (
cam_loc
+ acc_end_dis.reshape(batch_size, num_pixels, 1) * ray_directions
).reshape(-1, 3)
# Fix points which wrongly crossed the surface
next_sdf_start = torch.zeros_like(acc_start_dis)
next_sdf_start[unfinished_mask_start] = sdf(
curr_start_points[unfinished_mask_start]
)
next_sdf_end = torch.zeros_like(acc_end_dis)
next_sdf_end[unfinished_mask_end] = sdf(
curr_end_points[unfinished_mask_end]
)
not_projected_start = next_sdf_start < 0
not_projected_end = next_sdf_end < 0
not_proj_iters = 0
while (
not_projected_start.sum() > 0 or not_projected_end.sum() > 0
) and not_proj_iters < self.line_step_iters:
# Step backwards
acc_start_dis[not_projected_start] -= (
(1 - self.line_search_step) / (2 ** not_proj_iters)
) * curr_sdf_start[not_projected_start]
curr_start_points[not_projected_start] = (
cam_loc
+ acc_start_dis.reshape(batch_size, num_pixels, 1) * ray_directions
).reshape(-1, 3)[not_projected_start]
acc_end_dis[not_projected_end] += (
(1 - self.line_search_step) / (2 ** not_proj_iters)
) * curr_sdf_end[not_projected_end]
curr_end_points[not_projected_end] = (
cam_loc
+ acc_end_dis.reshape(batch_size, num_pixels, 1) * ray_directions
).reshape(-1, 3)[not_projected_end]
# Calc sdf
next_sdf_start[not_projected_start] = sdf(
curr_start_points[not_projected_start]
)
next_sdf_end[not_projected_end] = sdf(
curr_end_points[not_projected_end]
)
# Update mask
not_projected_start = next_sdf_start < 0
not_projected_end = next_sdf_end < 0
not_proj_iters += 1
unfinished_mask_start = unfinished_mask_start & (
acc_start_dis < acc_end_dis
)
unfinished_mask_end = unfinished_mask_end & (acc_start_dis < acc_end_dis)
return (
curr_start_points,
unfinished_mask_start,
acc_start_dis,
acc_end_dis,
min_dis,
max_dis,
)
def ray_sampler(
self,
sdf: Callable[[torch.Tensor], torch.Tensor],
cam_loc: torch.Tensor,
object_mask: torch.Tensor,
ray_directions: torch.Tensor,
sampler_min_max: torch.Tensor,
sampler_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sample the ray in a given range and run secant on rays which have sign transition.
Args:
sdf:
cam_loc:
object_mask:
ray_directions:
sampler_min_max:
sampler_mask:
Returns:
"""
batch_size, num_pixels, _ = ray_directions.shape
device = cam_loc.device
n_total_pxl = batch_size * num_pixels
sampler_pts = torch.zeros(n_total_pxl, 3, device=device)
sampler_dists = torch.zeros(n_total_pxl, device=device)
intervals_dist = torch.linspace(0, 1, steps=self.n_steps, device=device).view(
1, 1, -1
)
pts_intervals = sampler_min_max[:, :, 0].unsqueeze(-1) + intervals_dist * (
sampler_min_max[:, :, 1] - sampler_min_max[:, :, 0]
).unsqueeze(-1)
points = (
cam_loc[..., None, :]
+ pts_intervals[..., None] * ray_directions[..., None, :]
)
# Get the non convergent rays
mask_intersect_idx = torch.nonzero(sampler_mask).flatten()
points = points.reshape((-1, self.n_steps, 3))[sampler_mask, :, :]
pts_intervals = pts_intervals.reshape((-1, self.n_steps))[sampler_mask]
sdf_val_all = []
for pnts in torch.split(points.reshape(-1, 3), 100000, dim=0):
sdf_val_all.append(sdf(pnts))
sdf_val = torch.cat(sdf_val_all).reshape(-1, self.n_steps)
tmp = torch.sign(sdf_val) * torch.arange(
self.n_steps, 0, -1, device=device, dtype=torch.float32
).reshape(1, self.n_steps)
# Force argmin to return the first min value
sampler_pts_ind = torch.argmin(tmp, -1)
sampler_pts[mask_intersect_idx] = points[
torch.arange(points.shape[0]), sampler_pts_ind, :
]
sampler_dists[mask_intersect_idx] = pts_intervals[
torch.arange(pts_intervals.shape[0]), sampler_pts_ind
]
true_surface_pts = object_mask.reshape(-1)[sampler_mask]
net_surface_pts = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind] < 0
# take points with minimal SDF value for P_out pixels
p_out_mask = ~(true_surface_pts & net_surface_pts)
n_p_out = p_out_mask.sum()
if n_p_out > 0:
out_pts_idx = torch.argmin(sdf_val[p_out_mask, :], -1)
sampler_pts[mask_intersect_idx[p_out_mask]] = points[p_out_mask, :, :][
torch.arange(n_p_out), out_pts_idx, :
]
sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[
p_out_mask, :
][torch.arange(n_p_out), out_pts_idx]
# Get Network object mask
sampler_net_obj_mask = sampler_mask.clone()
sampler_net_obj_mask[mask_intersect_idx[~net_surface_pts]] = False
# Run Secant method
secant_pts = (
net_surface_pts & true_surface_pts if self.training else net_surface_pts
)
n_secant_pts = secant_pts.sum()
if n_secant_pts > 0:
# Get secant z predictions
z_high = pts_intervals[
torch.arange(pts_intervals.shape[0]), sampler_pts_ind
][secant_pts]
sdf_high = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind][
secant_pts
]
z_low = pts_intervals[secant_pts][
torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1
]
sdf_low = sdf_val[secant_pts][
torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1
]
cam_loc_secant = cam_loc.reshape(-1, 3)[mask_intersect_idx[secant_pts]]
ray_directions_secant = ray_directions.reshape((-1, 3))[
mask_intersect_idx[secant_pts]
]
z_pred_secant = self.secant(
sdf_low,
sdf_high,
z_low,
z_high,
cam_loc_secant,
ray_directions_secant,
# pyre-fixme[6]: For 7th param expected `Module` but got `(Tensor)
# -> Tensor`.
sdf,
)
# Get points
sampler_pts[mask_intersect_idx[secant_pts]] = (
cam_loc_secant + z_pred_secant.unsqueeze(-1) * ray_directions_secant
)
sampler_dists[mask_intersect_idx[secant_pts]] = z_pred_secant
return sampler_pts, sampler_net_obj_mask, sampler_dists
def secant(
self,
sdf_low: torch.Tensor,
sdf_high: torch.Tensor,
z_low: torch.Tensor,
z_high: torch.Tensor,
cam_loc: torch.Tensor,
ray_directions: torch.Tensor,
sdf: nn.Module,
) -> torch.Tensor:
"""
Runs the secant method for interval [z_low, z_high] for n_secant_steps
"""
z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
for _ in range(self.n_secant_steps):
p_mid = cam_loc + z_pred.unsqueeze(-1) * ray_directions
sdf_mid = sdf(p_mid)
ind_low = sdf_mid > 0
if ind_low.sum() > 0:
z_low[ind_low] = z_pred[ind_low]
sdf_low[ind_low] = sdf_mid[ind_low]
ind_high = sdf_mid < 0
if ind_high.sum() > 0:
z_high[ind_high] = z_pred[ind_high]
sdf_high[ind_high] = sdf_mid[ind_high]
z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low
return z_pred
def minimal_sdf_points(
self,
sdf: Callable[[torch.Tensor], torch.Tensor],
cam_loc: torch.Tensor,
ray_directions: torch.Tensor,
mask: torch.Tensor,
min_dis: torch.Tensor,
max_dis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Find points with minimal SDF value on rays for P_out pixels
"""
n_mask_points = mask.sum()
n = self.n_steps
steps = torch.empty(n, device=cam_loc.device).uniform_(0.0, 1.0)
mask_max_dis = max_dis[mask].unsqueeze(-1)
mask_min_dis = min_dis[mask].unsqueeze(-1)
steps = (
steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis)
+ mask_min_dis
)
mask_points = cam_loc.reshape(-1, 3)[mask]
mask_rays = ray_directions[mask, :]
mask_points_all = mask_points.unsqueeze(1).repeat(1, n, 1) + steps.unsqueeze(
-1
) * mask_rays.unsqueeze(1).repeat(1, n, 1)
points = mask_points_all.reshape(-1, 3)
mask_sdf_all = []
for pnts in torch.split(points, 100000, dim=0):
mask_sdf_all.append(sdf(pnts))
mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n)
min_vals, min_idx = mask_sdf_all.min(-1)
min_mask_points = mask_points_all.reshape(-1, n, 3)[
torch.arange(0, n_mask_points), min_idx
]
min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx]
return min_mask_points, min_mask_dist
# TODO: support variable origins
def _get_sphere_intersection(
cam_loc: torch.Tensor, ray_directions: torch.Tensor, r: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
# Input: n_images x 3 ; n_images x n_rays x 3
# Output: n_images * n_rays x 2 (close and far) ; n_images * n_rays
n_imgs, n_pix, _ = ray_directions.shape
device = cam_loc.device
# cam_loc = cam_loc.unsqueeze(-1)
# ray_cam_dot = torch.bmm(ray_directions, cam_loc).squeeze()
ray_cam_dot = (ray_directions * cam_loc).sum(-1) # n_images x n_rays
under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, dim=-1) ** 2 - r ** 2)
under_sqrt = under_sqrt.reshape(-1)
mask_intersect = under_sqrt > 0
sphere_intersections = torch.zeros(n_imgs * n_pix, 2, device=device)
sphere_intersections[mask_intersect] = torch.sqrt(
under_sqrt[mask_intersect]
).unsqueeze(-1) * torch.tensor([-1.0, 1.0], device=device)
sphere_intersections[mask_intersect] -= ray_cam_dot.reshape(-1)[
mask_intersect
].unsqueeze(-1)
sphere_intersections = sphere_intersections.reshape(n_imgs, n_pix, 2)
sphere_intersections = sphere_intersections.clamp_min(0.0)
mask_intersect = mask_intersect.reshape(n_imgs, n_pix)
return sphere_intersections, mask_intersect
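# A minimal sketch (illustrative only; the camera position and ray direction are
# assumed values) of step 1 of the pipeline above: intersecting a ray with the
# circumscribing sphere.
if __name__ == "__main__":
    cam_loc = torch.tensor([[0.0, 0.0, 3.0]])  # (n_images=1, 3)
    ray_dirs = torch.tensor([[[0.0, 0.0, -1.0]]])  # (n_images=1, n_rays=1, 3)
    dists, hit = _get_sphere_intersection(cam_loc, ray_dirs, r=1.0)
    # the ray from (0, 0, 3) towards the origin enters/leaves the unit sphere at 2 and 4
    assert bool(hit.all())
    assert torch.allclose(dists, torch.tensor([[[2.0, 4.0]]]))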

View File

@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, Dict, Tuple, Union
import torch
from pytorch3d.renderer.implicit.raymarching import _check_raymarcher_inputs
_TTensor = torch.Tensor
class GenericRaymarcher(torch.nn.Module):
"""
This generalizes the `pytorch3d.renderer.EmissionAbsorptionRaymarcher`
and NeuralVolumes' Accumulative ray marcher. It additionally returns
the rendering weights that can be used in the NVS pipeline to carry out
the importance ray-sampling in the refining pass.
Different from `EmissionAbsorptionRaymarcher`, it takes raw
(non-exponentiated) densities.
Args:
bg_color: background_color. Must be of shape (1,) or (feature_dim,)
"""
def __init__(
self,
surface_thickness: int = 1,
bg_color: Union[Tuple[float, ...], _TTensor] = (0.0,),
capping_function: str = "exponential", # exponential | cap1
weight_function: str = "product", # product | minimum
background_opacity: float = 0.0,
density_relu: bool = True,
blend_output: bool = True,
):
"""
Args:
surface_thickness: Denotes the overlap between the absorption
function and the density function.
"""
super().__init__()
self.surface_thickness = surface_thickness
self.density_relu = density_relu
self.background_opacity = background_opacity
self.blend_output = blend_output
if not isinstance(bg_color, torch.Tensor):
bg_color = torch.tensor(bg_color)
if bg_color.ndim != 1:
raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor")
self.register_buffer("_bg_color", bg_color, persistent=False)
self._capping_function: Callable[[_TTensor], _TTensor] = {
"exponential": lambda x: 1.0 - torch.exp(-x),
"cap1": lambda x: x.clamp(max=1.0),
}[capping_function]
self._weight_function: Callable[[_TTensor, _TTensor], _TTensor] = {
"product": lambda curr, acc: curr * acc,
"minimum": lambda curr, acc: torch.minimum(curr, acc),
}[weight_function]
def forward(
self,
rays_densities: torch.Tensor,
rays_features: torch.Tensor,
aux: Dict[str, Any],
ray_lengths: torch.Tensor,
density_noise_std: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]:
"""
Args:
rays_densities: Per-ray density values represented with a tensor
of shape `(..., n_points_per_ray, 1)`.
rays_features: Per-ray feature values represented with a tensor
of shape `(..., n_points_per_ray, feature_dim)`.
aux: a dictionary with extra information.
ray_lengths: Per-ray depth values represented with a tensor
of shape `(..., n_points_per_ray)`.
density_noise_std: the magnitude of the noise added to densities.
Returns:
features: A tensor of shape `(..., feature_dim)` containing
the rendered features for each ray.
depth: A tensor of shape `(..., 1)` containing estimated depth.
opacities: A tensor of shape `(..., 1)` containing rendered opacities.
weights: A tensor of shape `(..., n_points_per_ray)` containing
the ray-specific non-negative opacity weights. In general, they
do not sum to 1 but never exceed it, i.e.
`(weights.sum(dim=-1) <= 1.0).all()` holds.
"""
_check_raymarcher_inputs(
rays_densities,
rays_features,
ray_lengths,
z_can_be_none=True,
features_can_be_none=False,
density_1d=True,
)
deltas = torch.cat(
(
ray_lengths[..., 1:] - ray_lengths[..., :-1],
self.background_opacity * torch.ones_like(ray_lengths[..., :1]),
),
dim=-1,
)
rays_densities = rays_densities[..., 0]
if density_noise_std > 0.0:
rays_densities = (
rays_densities + torch.randn_like(rays_densities) * density_noise_std
)
if self.density_relu:
rays_densities = torch.relu(rays_densities)
weighted_densities = deltas * rays_densities
capped_densities = self._capping_function(weighted_densities)
rays_opacities = self._capping_function(
torch.cumsum(weighted_densities, dim=-1)
)
opacities = rays_opacities[..., -1:]
absorption_shifted = (-rays_opacities + 1.0).roll(
self.surface_thickness, dims=-1
)
absorption_shifted[..., : self.surface_thickness] = 1.0
weights = self._weight_function(capped_densities, absorption_shifted)
features = (weights[..., None] * rays_features).sum(dim=-2)
depth = (weights * ray_lengths)[..., None].sum(dim=-2)
alpha = opacities if self.blend_output else 1
if self._bg_color.shape[-1] not in [1, features.shape[-1]]:
raise ValueError("Wrong number of background color channels.")
features = alpha * features + (1 - opacities) * self._bg_color
return features, depth, opacities, weights, aux
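# A self-contained sketch (illustrative only; the shapes, densities and features
# are assumed values) of running the raymarcher directly on two rays of 8 points each.
if __name__ == "__main__":
    raymarcher = GenericRaymarcher(bg_color=(0.0, 0.0, 0.0))
    n_rays, n_pts = 2, 8
    ray_lengths = torch.linspace(0.1, 4.0, n_pts).expand(n_rays, n_pts)
    features, depth, opacities, weights, aux = raymarcher(
        rays_densities=torch.rand(n_rays, n_pts, 1),
        rays_features=torch.rand(n_rays, n_pts, 3),
        aux={},
        ray_lengths=ray_lengths,
    )
    assert features.shape == (n_rays, 3) and depth.shape == (n_rays, 1)
    # as documented above, the per-ray weights never sum above 1
    assert bool((weights.sum(dim=-1) <= 1.0 + 1e-6).all())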

View File

@ -0,0 +1,101 @@
# @lint-ignore-every LICENSELINT
# Adapted from RenderingNetwork from IDR
# https://github.com/lioryariv/idr/
# Copyright (c) 2020 Lior Yariv
import torch
from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle
from torch import nn
class RayNormalColoringNetwork(torch.nn.Module):
def __init__(
self,
feature_vector_size=3,
mode="idr",
d_in=9,
d_out=3,
dims=(512, 512, 512, 512),
weight_norm=True,
n_harmonic_functions_dir=0,
pooled_feature_dim=0,
):
super().__init__()
self.mode = mode
self.output_dimensions = d_out
dims = [d_in + feature_vector_size] + list(dims) + [d_out]
self.embedview_fn = None
if n_harmonic_functions_dir > 0:
self.embedview_fn = HarmonicEmbedding(
n_harmonic_functions_dir, append_input=True
)
dims[0] += self.embedview_fn.get_output_dim() - 3
if pooled_feature_dim > 0:
print("Pooled features in rendering network.")
dims[0] += pooled_feature_dim
self.num_layers = len(dims)
layers = []
for layer_idx in range(self.num_layers - 1):
out_dim = dims[layer_idx + 1]
lin = nn.Linear(dims[layer_idx], out_dim)
if weight_norm:
lin = nn.utils.weight_norm(lin)
layers.append(lin)
self.linear_layers = torch.nn.ModuleList(layers)
self.relu = nn.ReLU()
self.tanh = nn.Tanh()
def forward(
self,
feature_vectors: torch.Tensor,
points,
normals,
ray_bundle: RayBundle,
masks=None,
pooling_fn=None,
):
if masks is not None and not masks.any():
return torch.zeros_like(normals)
view_dirs = ray_bundle.directions
if masks is not None:
# in case of IDR, other outputs are passed here after applying the mask
view_dirs = view_dirs.reshape(view_dirs.shape[0], -1, 3)[
:, masks.reshape(-1)
]
if self.embedview_fn is not None:
view_dirs = self.embedview_fn(view_dirs)
if self.mode == "idr":
rendering_input = torch.cat(
[points, view_dirs, normals, feature_vectors], dim=-1
)
elif self.mode == "no_view_dir":
rendering_input = torch.cat([points, normals, feature_vectors], dim=-1)
elif self.mode == "no_normal":
rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1)
else:
raise ValueError(f"Unsupported rendering mode: {self.mode}")
if pooling_fn is not None:
featspool = pooling_fn(points[None])[0]
rendering_input = torch.cat((rendering_input, featspool), dim=-1)
x = rendering_input
for layer_idx in range(self.num_layers - 1):
x = self.linear_layers[layer_idx](x)
if layer_idx < self.num_layers - 2:
x = self.relu(x)
x = self.tanh(x)
return x
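# A self-contained sketch (illustrative only; the number of points and the random
# inputs are assumed values) of coloring 5 surface points with the default
# IDR-style configuration (inputs: point, view direction, normal, 3-dim feature).
if __name__ == "__main__":
    net = RayNormalColoringNetwork()
    n = 5
    rays = RayBundle(
        origins=torch.zeros(1, n, 3),
        directions=torch.nn.functional.normalize(torch.rand(1, n, 3), dim=-1),
        lengths=torch.ones(1, n, 1),
        xys=torch.zeros(1, n, 2),
    )
    rgb = net(
        feature_vectors=torch.rand(1, n, 3),
        points=torch.rand(1, n, 3),
        normals=torch.rand(1, n, 3),
        ray_bundle=rays,
    )
    assert rgb.shape == (1, n, 3)  # tanh-bounded colors in [-1, 1]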

View File

@ -0,0 +1,253 @@
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/lioryariv/idr/blob/main/code/model/
# implicit_differentiable_renderer.py
# Copyright (c) 2020 Lior Yariv
import functools
import math
from typing import List, Optional, Tuple
import torch
from omegaconf import DictConfig
from pytorch3d.implicitron.tools.config import get_default_args_field, registry
from pytorch3d.implicitron.tools.utils import evaluating
from pytorch3d.renderer import RayBundle
from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput
from .ray_tracing import RayTracing
from .rgb_net import RayNormalColoringNetwork
@registry.register
class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):
render_features_dimensions: int = 3
ray_tracer_args: DictConfig = get_default_args_field(RayTracing)
ray_normal_coloring_network_args: DictConfig = get_default_args_field(
RayNormalColoringNetwork
)
bg_color: Tuple[float, ...] = (0.0,)
soft_mask_alpha: float = 50.0
def __post_init__(
self,
):
super().__init__()
render_features_dimensions = self.render_features_dimensions
if len(self.bg_color) not in [1, render_features_dimensions]:
raise ValueError(
f"Background color should have {render_features_dimensions} entries."
)
self.ray_tracer = RayTracing(**self.ray_tracer_args)
self.object_bounding_sphere = self.ray_tracer_args.get("object_bounding_sphere")
self.ray_normal_coloring_network_args[
"feature_vector_size"
] = render_features_dimensions
self._rgb_network = RayNormalColoringNetwork(
**self.ray_normal_coloring_network_args
)
self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False)
def forward(
self,
ray_bundle: RayBundle,
implicit_functions: List[ImplicitFunctionWrapper],
evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION,
object_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> RendererOutput:
"""
Args:
ray_bundle: A `RayBundle` object containing the parametrizations of the
sampled rendering rays.
implicit_functions: single element list of ImplicitFunctionWrappers which
defines the implicit function to be used.
evaluation_mode: one of EvaluationMode.TRAINING or
EvaluationMode.EVALUATION which determines the settings used for
rendering.
kwargs:
object_mask: BoolTensor, denoting the silhouette of the object.
This is a required keyword argument for SignedDistanceFunctionRenderer
Returns:
instance of RendererOutput
"""
if len(implicit_functions) != 1:
raise ValueError(
"SignedDistanceFunctionRenderer supports only single pass."
)
if object_mask is None:
raise ValueError("Expected object_mask to be provided in the kwargs")
object_mask = object_mask.bool()
implicit_function = implicit_functions[0]
implicit_function_gradient = functools.partial(gradient, implicit_function)
# object_mask: silhouette of the object
batch_size, *spatial_size, _ = ray_bundle.lengths.shape
num_pixels = math.prod(spatial_size)
cam_loc = ray_bundle.origins.reshape(batch_size, -1, 3)
ray_dirs = ray_bundle.directions.reshape(batch_size, -1, 3)
object_mask = object_mask.reshape(batch_size, -1)
with torch.no_grad(), evaluating(implicit_function):
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
points, network_object_mask, dists = self.ray_tracer(
sdf=lambda x: implicit_function(x)[
:, 0
], # TODO: get rid of this wrapper
cam_loc=cam_loc,
object_mask=object_mask,
ray_directions=ray_dirs,
)
# TODO: below, cam_loc might as well be different
depth = dists.reshape(batch_size, num_pixels, 1)
points = (cam_loc + depth * ray_dirs).reshape(-1, 3)
sdf_output = implicit_function(points)[:, 0:1]
# NOTE most of the intermediate variables are flattened for
# no apparent reason (here and in the ray tracer)
ray_dirs = ray_dirs.reshape(-1, 3)
object_mask = object_mask.reshape(-1)
# TODO: move it to loss computation
if evaluation_mode == EvaluationMode.TRAINING:
surface_mask = network_object_mask & object_mask
surface_points = points[surface_mask]
surface_dists = dists[surface_mask].unsqueeze(-1)
surface_ray_dirs = ray_dirs[surface_mask]
surface_cam_loc = cam_loc.reshape(-1, 3)[surface_mask]
surface_output = sdf_output[surface_mask]
N = surface_points.shape[0]
# Sample points for the eikonal loss
# pyre-fixme[9]
eik_bounding_box: float = self.object_bounding_sphere
n_eik_points = batch_size * num_pixels // 2
eikonal_points = torch.empty(
n_eik_points, 3, device=self._bg_color.device
).uniform_(-eik_bounding_box, eik_bounding_box)
eikonal_pixel_points = points.clone()
eikonal_pixel_points = eikonal_pixel_points.detach()
eikonal_points = torch.cat([eikonal_points, eikonal_pixel_points], 0)
points_all = torch.cat([surface_points, eikonal_points], dim=0)
output = implicit_function(surface_points)
surface_sdf_values = output[
:N, 0:1
].detach() # how is it different from sdf_output?
g = implicit_function_gradient(points_all)
surface_points_grad = g[:N, 0, :].clone().detach()
grad_theta = g[N:, 0, :]
differentiable_surface_points = _sample_network(
surface_output,
surface_sdf_values,
surface_points_grad,
surface_dists,
surface_cam_loc,
surface_ray_dirs,
)
else:
surface_mask = network_object_mask
differentiable_surface_points = points[surface_mask]
grad_theta = None
empty_render = differentiable_surface_points.shape[0] == 0
features = implicit_function(differentiable_surface_points)[None, :, 1:]
normals_full = features.new_zeros(
batch_size, *spatial_size, 3, requires_grad=empty_render
)
render_full = (
features.new_ones(
batch_size,
*spatial_size,
self.render_features_dimensions,
requires_grad=empty_render,
)
* self._bg_color
)
mask_full = features.new_ones(
batch_size, *spatial_size, 1, requires_grad=empty_render
)
if not empty_render:
normals = implicit_function_gradient(differentiable_surface_points)[
None, :, 0, :
]
normals_full.view(-1, 3)[surface_mask] = normals
render_full.view(-1, self.render_features_dimensions)[
surface_mask
] = self._rgb_network( # pyre-fixme[29]:
features,
differentiable_surface_points[None],
normals,
ray_bundle,
surface_mask[None, :, None],
pooling_fn=None, # TODO
)
mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid(
-self.soft_mask_alpha * sdf_output[~surface_mask]
)
# scatter points with surface_mask
points_full = ray_bundle.origins.detach().clone()
points_full.view(-1, 3)[surface_mask] = differentiable_surface_points
# TODO: it is sparse here but otherwise dense
return RendererOutput(
features=render_full,
normals=normals_full,
depths=depth.reshape(batch_size, *spatial_size, 1),
masks=mask_full, # this is a differentiable approximation, see (7) in the paper
points=points_full,
aux={"grad_theta": grad_theta}, # TODO: will be moved to eikonal loss
# TODO: do we need sdf_output, grad_theta? Only for loss probably
)
def _sample_network(
surface_output,
surface_sdf_values,
surface_points_grad,
surface_dists,
surface_cam_loc,
surface_ray_dirs,
eps=1e-4,
):
# t -> t(theta)
surface_ray_dirs_0 = surface_ray_dirs.detach()
surface_points_dot = torch.bmm(
surface_points_grad.view(-1, 1, 3), surface_ray_dirs_0.view(-1, 3, 1)
).squeeze(-1)
dot_sign = (surface_points_dot >= 0).to(surface_points_dot) * 2 - 1
surface_dists_theta = surface_dists - (surface_output - surface_sdf_values) / (
surface_points_dot.abs().clip(eps) * dot_sign
)
# t(theta) -> x(theta,c,v)
surface_points_theta_c_v = surface_cam_loc + surface_dists_theta * surface_ray_dirs
return surface_points_theta_c_v
@torch.enable_grad()
def gradient(module, x):
x.requires_grad_(True)
y = module.forward(x)[:, :1]
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
gradients = torch.autograd.grad(
outputs=y,
inputs=x,
grad_outputs=d_output,
create_graph=True,
retain_graph=True,
only_inputs=True,
)[0]
return gradients.unsqueeze(1)
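# A minimal sketch (illustrative only; the analytic sphere SDF below stands in for
# a learned implicit function) of the `gradient` helper: for a true SDF the spatial
# gradient has unit norm away from the singularity at the origin.
if __name__ == "__main__":
    class _SphereSDF(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # signed distance to the unit sphere centered at the origin
            return x.norm(dim=-1, keepdim=True) - 1.0

    pts = torch.rand(10, 3) + 0.5  # points safely away from the origin
    g = gradient(_SphereSDF(), pts)  # (10, 1, 3)
    assert torch.allclose(g.norm(dim=-1), torch.ones(10, 1), atol=1e-5)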

View File

@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
import math
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn.functional as Fu
import torchvision
from pytorch3d.implicitron.tools.config import Configurable
logger = logging.getLogger(__name__)
MASK_FEATURE_NAME = "mask"
IMAGE_FEATURE_NAME = "image"
_FEAT_DIMS = {
"resnet18": (64, 128, 256, 512),
"resnet34": (64, 128, 256, 512),
"resnet50": (256, 512, 1024, 2048),
"resnet101": (256, 512, 1024, 2048),
"resnet152": (256, 512, 1024, 2048),
}
_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]
class ResNetFeatureExtractor(Configurable, torch.nn.Module):
"""
Implements an image feature extractor. Depending on the settings, it can
extract:
- deep features: A CNN ResNet backbone from torchvision (with/without
pretrained weights) which extracts deep features.
- masks: Segmentation masks.
- images: Raw input RGB images.
Settings:
name: name of the resnet backbone (from torchvision)
pretrained: If true, will load the pretrained weights
stages: List of stages from which to extract features.
Features from each stage are returned as key value
pairs in the forward function
normalize_image: If set, normalizes the RGB values of
the image using the ResNet mean/std.
image_rescale: If not 1.0, this rescale factor will be
used to resize the input image.
first_max_pool: If set, a max pool layer is added after the first
convolutional layer.
proj_dim: The number of output channels of the 1x1 projection
convolutional layers applied to the extracted features.
l2_norm: If set, L2 normalization is applied to the extracted features.
add_masks: If set, the masks will be saved in the output dictionary.
add_images: If set, the images will be saved in the output dictionary.
global_average_pool: If set, a global average pooling step is performed.
feature_rescale: If not 1.0, this rescale factor will be used to
rescale the output features.
"""
name: str = "resnet34"
pretrained: bool = True
stages: Tuple[int, ...] = (1, 2, 3, 4)
normalize_image: bool = True
image_rescale: float = 128 / 800.0
first_max_pool: bool = True
proj_dim: int = 32
l2_norm: bool = True
add_masks: bool = True
add_images: bool = True
global_average_pool: bool = False # this can simulate global/non-spatial features
feature_rescale: float = 1.0
def __post_init__(self):
super().__init__()
if self.normalize_image:
# register buffers needed to normalize the image
for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
self.register_buffer(
k,
torch.FloatTensor(v).view(1, 3, 1, 1),
persistent=False,
)
self._feat_dim = {}
if len(self.stages) == 0:
# do not extract any resnet features
pass
else:
net = getattr(torchvision.models, self.name)(pretrained=self.pretrained)
if self.first_max_pool:
self.stem = torch.nn.Sequential(
net.conv1, net.bn1, net.relu, net.maxpool
)
else:
self.stem = torch.nn.Sequential(net.conv1, net.bn1, net.relu)
self.max_stage = max(self.stages)
self.layers = torch.nn.ModuleList()
self.proj_layers = torch.nn.ModuleList()
for stage in range(self.max_stage):
stage_name = f"layer{stage+1}"
feature_name = self._get_resnet_stage_feature_name(stage)
if (stage + 1) in self.stages:
if (
self.proj_dim > 0
and _FEAT_DIMS[self.name][stage] > self.proj_dim
):
proj = torch.nn.Conv2d(
_FEAT_DIMS[self.name][stage],
self.proj_dim,
1,
1,
bias=True,
)
self._feat_dim[feature_name] = self.proj_dim
else:
proj = torch.nn.Identity()
self._feat_dim[feature_name] = _FEAT_DIMS[self.name][stage]
else:
proj = torch.nn.Identity()
self.proj_layers.append(proj)
self.layers.append(getattr(net, stage_name))
if self.add_masks:
self._feat_dim[MASK_FEATURE_NAME] = 1
if self.add_images:
self._feat_dim[IMAGE_FEATURE_NAME] = 3
logger.info(f"Feat extractor total dim = {self.get_feat_dims()}")
self.stages = set(self.stages) # convert to set for faster "in"
def _get_resnet_stage_feature_name(self, stage) -> str:
return f"res_layer_{stage+1}"
def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor:
return (img - self._resnet_mean) / self._resnet_std
def get_feat_dims(self, size_dict: bool = False):
if size_dict:
return copy.deepcopy(self._feat_dim)
# pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute
# `values`.
return sum(self._feat_dim.values())
def forward(
self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None
) -> Dict[Any, torch.Tensor]:
"""
Args:
imgs: A batch of input images of shape `(B, 3, H, W)`.
masks: A batch of input masks of shape `(B, 1, H, W)`.
Returns:
out_feats: A dict `{f_i: t_i}` keyed by predicted feature names `f_i`
and their corresponding tensors `t_i` of shape `(B, dim_i, H_i, W_i)`.
"""
out_feats = {}
imgs_input = imgs
if self.image_rescale != 1.0:
imgs_resized = Fu.interpolate(
imgs_input,
# pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but
# got `float`.
scale_factor=self.image_rescale,
mode="bilinear",
)
else:
imgs_resized = imgs_input
if self.normalize_image:
imgs_normed = self._resnet_normalize_image(imgs_resized)
else:
imgs_normed = imgs_resized
if len(self.stages) > 0:
# pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]`
# is not a function.
feats = self.stem(imgs_normed)
# pyre-fixme[6]: For 1st param expected `Iterable[Variable[_T1]]` but
# got `Union[Tensor, Module]`.
# pyre-fixme[6]: For 2nd param expected `Iterable[Variable[_T2]]` but
# got `Union[Tensor, Module]`.
for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
feats = layer(feats)
# just a sanity check below
assert feats.shape[1] == _FEAT_DIMS[self.name][stage]
if (stage + 1) in self.stages:
f = proj(feats)
if self.global_average_pool:
# global average pooling over the spatial dims, keeping a 1x1 map
f = f.mean(dim=(2, 3), keepdim=True)
if self.l2_norm:
normfac = 1.0 / math.sqrt(len(self.stages))
f = Fu.normalize(f, dim=1) * normfac
feature_name = self._get_resnet_stage_feature_name(stage)
out_feats[feature_name] = f
if self.add_masks:
assert masks is not None
out_feats[MASK_FEATURE_NAME] = masks
if self.add_images:
assert imgs_input is not None
out_feats[IMAGE_FEATURE_NAME] = imgs_resized
if self.feature_rescale != 1.0:
out_feats = {k: self.feature_rescale * f for k, f in out_feats.items()}
# pyre-fixme[7]: Incompatible return type, expected `Dict[typing.Any, Tensor]`
# but got `Dict[typing.Any, float]`
return out_feats
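
For context, a hedged usage sketch (not part of the file above): `ResNetFeatureExtractor` is a `Configurable`, so the sketch assumes the class is first expanded into a dataclass with `expand_args_fields` from `pytorch3d.implicitron.tools.config`; the shapes and settings are illustrative only.

```python
import torch
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(ResNetFeatureExtractor)
extractor = ResNetFeatureExtractor(
    name="resnet34", pretrained=False, stages=(1, 2, 3), proj_dim=16
)

imgs = torch.rand(2, 3, 800, 800)   # RGB images in [0, 1]
masks = torch.ones(2, 1, 800, 800)  # foreground masks
feats = extractor(imgs, masks)      # keys: res_layer_1..3, "mask", "image"
print(extractor.get_feat_dims())    # total channel dimension across features
```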


@ -0,0 +1,666 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, Optional, Sequence, Union
import torch
import torch.nn.functional as F
from pytorch3d.implicitron.models.view_pooling.view_sampling import (
cameras_points_cartesian_product,
)
from pytorch3d.implicitron.tools.config import ReplaceableBase, registry
from pytorch3d.ops import wmean
from pytorch3d.renderer.cameras import CamerasBase
class ReductionFunction(Enum):
AVG = "avg" # simple average
MAX = "max" # maximum
STD = "std" # standard deviation
STD_AVG = "std_avg" # average of per-dimension standard deviations
class FeatureAggregatorBase(ABC, ReplaceableBase):
"""
Base class for aggregating features.
Typically, the features to be aggregated and their masks are output by `ViewSampler`,
which samples feature tensors extracted from a set of source images.
Settings:
exclude_target_view: If `True`/`False`, enables/disables pooling
from target view to itself.
exclude_target_view_mask_features: If `True`,
mask the features from the target view before aggregation
concatenate_output: If `True`,
concatenate the aggregated features into a single tensor,
otherwise return a dictionary mapping feature names to tensors.
"""
exclude_target_view: bool = True
exclude_target_view_mask_features: bool = True
concatenate_output: bool = True
@abstractmethod
def forward(
self,
feats_sampled: Dict[str, torch.Tensor],
masks_sampled: torch.Tensor,
camera: Optional[CamerasBase] = None,
pts: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
where each `t_i` is a tensor of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
masks_sampled: A binary mask represented as a tensor of shape
`(minibatch, n_source_views, n_samples, 1)` denoting valid
sampled features.
camera: A batch of `n_source_views` `CamerasBase` objects corresponding
to the source view cameras.
pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
3D points whose 2D projections to source views were sampled in
order to generate `feats_sampled` and `masks_sampled`.
Returns:
feats_aggregated: If `concatenate_output==True`, a tensor
of shape `(minibatch, reduce_dim, n_samples, sum(dim_1, ... dim_N))`
containing the concatenation of the aggregated features `feats_sampled`.
`reduce_dim` depends on the specific feature aggregator
implementation and typically equals 1 or `n_source_views`.
If `concatenate_output==False`, the aggregator does not concatenate
the aggregated features and returns a dictionary of per-feature
aggregations `{f_i: t_i_aggregated}` instead. Each `t_i_aggregated`
is of shape `(minibatch, reduce_dim, n_samples, aggr_dim_i)`.
"""
raise NotImplementedError()
@registry.register
class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
"""
This aggregator does not perform any feature aggregation. Depending on the
settings, it can mask the target-view features and concatenate the outputs.
"""
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, [])
def forward(
self,
feats_sampled: Dict[str, torch.Tensor],
masks_sampled: torch.Tensor,
camera: Optional[CamerasBase] = None,
pts: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
where each `t_i` is a tensor of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
masks_sampled: A binary mask represented as a tensor of shape
`(minibatch, n_source_views, n_samples, 1)` denoting valid
sampled features.
camera: A batch of `n_source_views` `CamerasBase` objects
corresponding to the source view cameras.
pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
3D points whose 2D projections to source views were sampled in
order to generate `feats_sampled` and `masks_sampled`.
Returns:
feats_aggregated: If `concatenate_output==True`, a tensor
of shape `(minibatch, n_source_views, n_samples, sum(dim_1, ... dim_N))`.
If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
with each `t_i_aggregated` of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
"""
if self.exclude_target_view_mask_features:
feats_sampled = _mask_target_view_features(feats_sampled)
feats_aggregated = feats_sampled
if self.concatenate_output:
feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
return feats_aggregated
@registry.register
class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
"""
Aggregates using a set of predefined `reduction_functions` and concatenates
the results of each aggregation function along the
channel dimension. The reduction functions singularize the second dimension
of the sampled features which stacks the source views.
Settings:
reduction_functions: A list of `ReductionFunction`s that reduce the
stack of source-view-specific features to a single feature.
"""
reduction_functions: Sequence[ReductionFunction] = (
ReductionFunction.AVG,
ReductionFunction.STD,
)
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
def forward(
self,
feats_sampled: Dict[str, torch.Tensor],
masks_sampled: torch.Tensor,
camera: Optional[CamerasBase] = None,
pts: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
where each `t_i` is a tensor of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
masks_sampled: A binary mask represented as a tensor of shape
`(minibatch, n_source_views, n_samples, 1)` denoting valid
sampled features.
camera: A batch of `n_source_views` `CamerasBase` objects corresponding
to the source view cameras.
pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
3D points whose 2D projections to source views were sampled in
order to generate `feats_sampled` and `masks_sampled`.
Returns:
feats_aggregated: If `concatenate_output==True`, a tensor
of shape `(minibatch, 1, n_samples, sum(dim_1, ... dim_N))`.
If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
with each `t_i_aggregated` of shape `(minibatch, 1, n_samples, aggr_dim_i)`.
"""
pts_batch, n_cameras = masks_sampled.shape[:2]
if self.exclude_target_view_mask_features:
feats_sampled = _mask_target_view_features(feats_sampled)
sampling_mask = _get_view_sampling_mask(
n_cameras,
pts_batch,
masks_sampled.device,
self.exclude_target_view,
)
# per-(point-batch, view, sample) aggregation weights
aggr_weights = masks_sampled[..., 0] * sampling_mask[..., None]
feats_aggregated = {
k: _avgmaxstd_reduction_function(
f,
aggr_weights,
dim=1,
reduction_functions=self.reduction_functions,
)
for k, f in feats_sampled.items()
}
if self.concatenate_output:
feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
return feats_aggregated
@registry.register
class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
"""
Performs a weighted aggregation using a set of predefined `reduction_functions`
and concatenates the results of each aggregation function along the
channel dimension. The weights are proportional to the cosine of the
angle between the target ray and the source ray:
```
weight = (
dot(target_ray, source_ray) * 0.5 + 0.5 + self.min_ray_angle_weight
)**self.weight_by_ray_angle_gamma
```
The reduction functions singularize the second dimension
of the sampled features which stacks the source views.
Settings:
reduction_functions: A list of `ReductionFunction`s that reduce the
stack of source-view-specific features to a single feature.
min_ray_angle_weight: The minimum possible aggregation weight
before raising to the power of `self.weight_by_ray_angle_gamma`.
weight_by_ray_angle_gamma: The exponent of the cosine of the ray angles
used when calculating the angle-based aggregation weights.
"""
reduction_functions: Sequence[ReductionFunction] = (
ReductionFunction.AVG,
ReductionFunction.STD,
)
weight_by_ray_angle_gamma: float = 1.0
min_ray_angle_weight: float = 0.1
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions)
def forward(
self,
feats_sampled: Dict[str, torch.Tensor],
masks_sampled: torch.Tensor,
camera: Optional[CamerasBase] = None,
pts: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
where each `t_i` is a tensor of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
masks_sampled: A binary mask represented as a tensor of shape
`(minibatch, n_source_views, n_samples, 1)` denoting valid
sampled features.
camera: A batch of `n_source_views` `CamerasBase` objects
corresponding to the source view cameras.
pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
3D points whose 2D projections to source views were sampled in
order to generate `feats_sampled` and `masks_sampled`.
Returns:
feats_aggregated: If `concatenate_output==True`, a tensor
of shape `(minibatch, 1, n_samples, sum(dim_1, ... dim_N))`.
If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
with each `t_i_aggregated` of shape
`(minibatch, 1, n_samples, aggr_dim_i)`.
"""
if camera is None:
raise ValueError("camera cannot be None for angle weighted aggregation")
if pts is None:
raise ValueError("Points cannot be None for angle weighted aggregation")
pts_batch, n_cameras = masks_sampled.shape[:2]
if self.exclude_target_view_mask_features:
feats_sampled = _mask_target_view_features(feats_sampled)
view_sampling_mask = _get_view_sampling_mask(
n_cameras,
pts_batch,
masks_sampled.device,
self.exclude_target_view,
)
aggr_weights = _get_angular_reduction_weights(
view_sampling_mask,
masks_sampled,
camera,
pts,
self.min_ray_angle_weight,
self.weight_by_ray_angle_gamma,
)
assert torch.isfinite(aggr_weights).all()
feats_aggregated = {
k: _avgmaxstd_reduction_function(
f,
aggr_weights,
dim=1,
reduction_functions=self.reduction_functions,
)
for k, f in feats_sampled.items()
}
if self.concatenate_output:
feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
return feats_aggregated
@registry.register
class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase):
"""
This aggregator does not perform any feature aggregation. It only scales
the features by weights proportional to the cosine of the
angle between the target ray and the source ray:
```
weight = (
dot(target_ray, source_ray) * 0.5 + 0.5 + self.min_ray_angle_weight
)**self.weight_by_ray_angle_gamma
```
Settings:
min_ray_angle_weight: The minimum possible aggregation weight
before raising to the power of `self.weight_by_ray_angle_gamma`.
weight_by_ray_angle_gamma: The exponent of the cosine of the ray angles
used when calculating the angle-based aggregation weights.
Additionally, the aggregator can mask the target-view features and concatenate
the outputs.
"""
weight_by_ray_angle_gamma: float = 1.0
min_ray_angle_weight: float = 0.1
def __post_init__(self):
super().__init__()
def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]):
return _get_reduction_aggregator_feature_dim(feats, [])
def forward(
self,
feats_sampled: Dict[str, torch.Tensor],
masks_sampled: torch.Tensor,
camera: Optional[CamerasBase] = None,
pts: Optional[torch.Tensor] = None,
**kwargs,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`,
where each `t_i` is a tensor of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
masks_sampled: A binary mask represented as a tensor of shape
`(minibatch, n_source_views, n_samples, 1)` denoting valid
sampled features.
camera: A batch of `n_source_views` `CamerasBase` objects corresponding
to the source view cameras.
pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the
3D points whose 2D projections to source views were sampled in
order to generate `feats_sampled` and `masks_sampled`.
Returns:
feats_aggregated: If `concatenate_output==True`, a tensor
of shape `(minibatch, n_source_views, n_samples, sum(dim_1, ... dim_N))`.
If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}`
with each `t_i_aggregated` of shape
`(minibatch, n_source_views, n_samples, dim_i)`.
"""
if camera is None:
raise ValueError("camera cannot be None for angle weighted aggregation")
if pts is None:
raise ValueError("Points cannot be None for angle weighted aggregation")
pts_batch, n_cameras = masks_sampled.shape[:2]
if self.exclude_target_view_mask_features:
feats_sampled = _mask_target_view_features(feats_sampled)
view_sampling_mask = _get_view_sampling_mask(
n_cameras,
pts_batch,
masks_sampled.device,
self.exclude_target_view,
)
aggr_weights = _get_angular_reduction_weights(
view_sampling_mask,
masks_sampled,
camera,
pts,
self.min_ray_angle_weight,
self.weight_by_ray_angle_gamma,
)
feats_aggregated = {
k: f * aggr_weights[..., None] for k, f in feats_sampled.items()
}
if self.concatenate_output:
feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1)
return feats_aggregated
def _get_reduction_aggregator_feature_dim(
feats_or_feats_dim: Union[Dict[str, torch.Tensor], int],
reduction_functions: Sequence[ReductionFunction],
):
if isinstance(feats_or_feats_dim, int):
feat_dim = feats_or_feats_dim
else:
feat_dim = int(sum(f.shape[1] for f in feats_or_feats_dim.values()))
if len(reduction_functions) == 0:
return feat_dim
return sum(
_get_reduction_function_output_dim(
reduction_function,
feat_dim,
)
for reduction_function in reduction_functions
)
def _get_reduction_function_output_dim(
reduction_function: ReductionFunction,
feat_dim: int,
) -> int:
if reduction_function == ReductionFunction.STD_AVG:
return 1
else:
return feat_dim
def _get_view_sampling_mask(
n_cameras: int,
pts_batch: int,
device: Union[str, torch.device],
exclude_target_view: bool,
):
return (
-torch.eye(n_cameras, device=device, dtype=torch.float32)
* float(exclude_target_view)
+ 1.0
)[:pts_batch]
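# Worked example (illustrative comment, not part of the original source):
# with n_cameras=3, pts_batch=2 and exclude_target_view=True the mask is
#   -eye(3) + 1 = [[0., 1., 1.],
#                  [1., 0., 1.],
#                  [1., 1., 0.]]  -> keep the first 2 rows,
# so point batch b never pools features from its own b-th source view; with
# exclude_target_view=False the mask is all ones and every view contributes.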
def _mask_target_view_features(
feats_sampled: Dict[str, torch.Tensor],
):
# mask out the sampled features to be sure we don't use them
# anywhere later
one_feature_sampled = next(iter(feats_sampled.values()))
pts_batch, n_cameras = one_feature_sampled.shape[:2]
view_sampling_mask = _get_view_sampling_mask(
n_cameras,
pts_batch,
one_feature_sampled.device,
True,
)
view_sampling_mask = view_sampling_mask.view(
pts_batch, n_cameras, *([1] * (one_feature_sampled.ndim - 2))
)
return {k: f * view_sampling_mask for k, f in feats_sampled.items()}
def _get_angular_reduction_weights(
view_sampling_mask: torch.Tensor,
masks_sampled: torch.Tensor,
camera: CamerasBase,
pts: torch.Tensor,
min_ray_angle_weight: float,
weight_by_ray_angle_gamma: float,
):
aggr_weights = masks_sampled.clone()[..., 0]
assert not any(v is None for v in [camera, pts])
angle_weight = _get_ray_angle_weights(
camera,
pts,
min_ray_angle_weight,
weight_by_ray_angle_gamma,
)
assert torch.isfinite(angle_weight).all()
# multiply the final aggr weights with ray angles
view_sampling_mask = view_sampling_mask.view(
*view_sampling_mask.shape[:2], *([1] * (aggr_weights.ndim - 2))
)
aggr_weights = (
aggr_weights * angle_weight.reshape_as(aggr_weights) * view_sampling_mask
)
return aggr_weights
def _get_ray_dir_dot_prods(camera: CamerasBase, pts: torch.Tensor):
n_cameras = camera.R.shape[0]
pts_batch = pts.shape[0]
camera_rep, pts_rep = cameras_points_cartesian_product(camera, pts)
# unlike the commented-out get_camera_center() below, this does not randomly produce NaNs
cam_centers_rep = -torch.bmm(
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
# torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
# torch.Tensor, torch.nn.modules.module.Module]` is not a function.
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N...
camera_rep.T[:, None],
# pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N...
camera_rep.R.permute(0, 2, 1),
).reshape(-1, *([1] * (pts.ndim - 2)), 3)
# cam_centers_rep = camera_rep.get_camera_center().reshape(
# -1, *([1]*(pts.ndim - 2)), 3
# )
ray_dirs = F.normalize(pts_rep - cam_centers_rep, dim=-1)
# camera_rep = [ pts_rep = [
# camera[0] pts[0],
# camera[0] pts[1],
# camera[0] ...,
# ... pts[batch_pts-1],
# camera[1] pts[0],
# camera[1] pts[1],
# camera[1] ...,
# ... pts[batch_pts-1],
# ... ...,
# camera[n_cameras-1] pts[0],
# camera[n_cameras-1] pts[1],
# camera[n_cameras-1] ...,
# ... pts[batch_pts-1],
# ] ]
ray_dirs_reshape = ray_dirs.view(n_cameras, pts_batch, -1, 3)
# [
# [pts_0 in cam_0, pts_1 in cam_0, ..., pts_m in cam_0],
# [pts_0 in cam_1, pts_1 in cam_1, ..., pts_m in cam_1],
# ...
# [pts_0 in cam_n, pts_1 in cam_n, ..., pts_m in cam_n],
# ]
ray_dirs_pts = torch.stack([ray_dirs_reshape[i, i] for i in range(pts_batch)])
ray_dir_dot_prods = (ray_dirs_pts[None] * ray_dirs_reshape).sum(
dim=-1
) # pts_batch x n_cameras x n_pts
return ray_dir_dot_prods.transpose(0, 1)
def _get_ray_angle_weights(
camera: CamerasBase,
pts: torch.Tensor,
min_ray_angle_weight: float,
weight_by_ray_angle_gamma: float,
):
ray_dir_dot_prods = _get_ray_dir_dot_prods(
camera, pts
)  # pts_batch x n_cameras x n_pts
angle_weight_01 = ray_dir_dot_prods * 0.5 + 0.5 # [-1, 1] to [0, 1]
angle_weight = (angle_weight_01 + min_ray_angle_weight) ** weight_by_ray_angle_gamma
return angle_weight
def _avgmaxstd_reduction_function(
x: torch.Tensor,
w: torch.Tensor,
reduction_functions: Sequence[ReductionFunction],
dim: int = 1,
):
"""
Args:
x: Features to aggregate. Tensor of shape `(batch, n_views, ..., dim)`.
w: Aggregation weights. Tensor of shape `(batch, n_views, ...,)`.
dim: the dimension along which to aggregate.
reduction_functions: The set of reduction functions.
Returns:
x_aggr: Aggregation of `x` to a tensor of shape `(batch, 1, ..., dim_aggregate)`.
"""
pooled_features = []
mu = None
std = None
if ReductionFunction.AVG in reduction_functions:
# average pool
mu = _avg_reduction_function(x, w, dim=dim)
pooled_features.append(mu)
if ReductionFunction.STD in reduction_functions:
# standard-dev pool
std = _std_reduction_function(x, w, dim=dim, mu=mu)
pooled_features.append(std)
if ReductionFunction.STD_AVG in reduction_functions:
# average-of-standard-dev pool
stdavg = _std_avg_reduction_function(x, w, dim=dim, mu=mu, std=std)
pooled_features.append(stdavg)
if ReductionFunction.MAX in reduction_functions:
max_ = _max_reduction_function(x, w, dim=dim)
pooled_features.append(max_)
# cat all results along the feature dimension (the last dim)
x_aggr = torch.cat(pooled_features, dim=-1)
# zero out features that were all masked out
any_active = (w.max(dim=dim, keepdim=True).values > 1e-4).type_as(x_aggr)
x_aggr = x_aggr * any_active[..., None]
# some asserts to check that everything was done right
assert torch.isfinite(x_aggr).all()
assert x_aggr.shape[1] == 1
return x_aggr
def _avg_reduction_function(
x: torch.Tensor,
w: torch.Tensor,
dim: int = 1,
):
mu = wmean(x, w, dim=dim, eps=1e-2)
return mu
def _std_reduction_function(
x: torch.Tensor,
w: torch.Tensor,
dim: int = 1,
mu: Optional[torch.Tensor] = None, # pre-computed mean
):
if mu is None:
mu = _avg_reduction_function(x, w, dim=dim)
std = wmean((x - mu) ** 2, w, dim=dim, eps=1e-2).clamp(1e-4).sqrt()
# FIXME: somehow this is extremely heavy in mem?
return std
def _std_avg_reduction_function(
x: torch.Tensor,
w: torch.Tensor,
dim: int = 1,
mu: Optional[torch.Tensor] = None, # pre-computed mean
std: Optional[torch.Tensor] = None, # pre-computed std
):
if std is None:
std = _std_reduction_function(x, w, dim=dim, mu=mu)
stdmean = std.mean(dim=-1, keepdim=True)
return stdmean
def _max_reduction_function(
x: torch.Tensor,
w: torch.Tensor,
dim: int = 1,
big_M_factor: float = 10.0,
):
big_M = x.max(dim=dim, keepdim=True).values.abs() * big_M_factor
max_ = (x * w - ((1 - w) * big_M)).max(dim=dim, keepdim=True).values
return max_
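
As a rough illustration of the reductions above (not part of the file), the following sketch computes a masked weighted mean and standard deviation over the source-view dimension with `wmean`, mirroring what `ReductionFunction.AVG` and `ReductionFunction.STD` produce before concatenation; the shapes are illustrative.

```python
import torch
from pytorch3d.ops import wmean

minibatch, n_views, n_samples, dim = 2, 4, 5, 8
feats = torch.rand(minibatch, n_views, n_samples, dim)
# binary validity weights, one per (batch, view, sample)
weights = torch.randint(0, 2, (minibatch, n_views, n_samples)).float()

avg = wmean(feats, weights, dim=1, eps=1e-2)  # (minibatch, 1, n_samples, dim)
std = wmean((feats - avg) ** 2, weights, dim=1, eps=1e-2).clamp(1e-4).sqrt()
aggregated = torch.cat([avg, std], dim=-1)    # (minibatch, 1, n_samples, 2 * dim)
```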


@ -0,0 +1,291 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict, List, Optional, Tuple, Union
import torch
from pytorch3d.implicitron.tools.config import Configurable
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.renderer.utils import ndc_grid_sample
class ViewSampler(Configurable, torch.nn.Module):
"""
Implements sampling of image-based features at the 2d projections of a set
of 3D points.
Args:
masked_sampling: If `True`, the `sampled_masks` output of `self.forward`
contains the input `masks` sampled at the 2d projections. Otherwise,
all entries of `sampled_masks` are set to 1.
sampling_mode: Controls the mode of the `torch.nn.functional.grid_sample`
function used to interpolate the sampled feature tensors at the
locations of the 2d projections.
"""
masked_sampling: bool = False
sampling_mode: str = "bilinear"
def __post_init__(self):
super().__init__()
def forward(
self,
*, # force kw args
pts: torch.Tensor,
seq_id_pts: Union[List[int], List[str], torch.LongTensor],
camera: CamerasBase,
seq_id_camera: Union[List[int], List[str], torch.LongTensor],
feats: Dict[str, torch.Tensor],
masks: Optional[torch.Tensor],
**kwargs,
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
"""
Project each point cloud from a batch of point clouds to corresponding
input cameras and sample features at the 2D projection locations.
Args:
pts: A tensor of shape `[pts_batch x n_pts x 3]` in world coords.
seq_id_pts: LongTensor of shape `[pts_batch]` denoting the ids of the scenes
from which `pts` were extracted, or a list of string names.
camera: 'n_cameras' cameras, each corresponding to a batch element of `feats`.
seq_id_camera: LongTensor of shape `[n_cameras]` denoting the ids of the scenes
corresponding to cameras in `camera`, or a list of string names.
feats: a dict of tensors of per-image features `{feat_i: T_i}`.
Each tensor `T_i` is of shape `[n_cameras x dim_i x H_i x W_i]`.
masks: `[n_cameras x 1 x H x W]`, define valid image regions
for sampling `feats`.
Returns:
sampled_feats: Dict of sampled features `{feat_i: sampled_T_i}`.
Each `sampled_T_i` of shape `[pts_batch, n_cameras, n_pts, dim_i]`.
sampled_masks: A tensor with mask of the sampled features
of shape `(pts_batch, n_cameras, n_pts, 1)`.
"""
# convert sequence ids to long tensors
seq_id_pts, seq_id_camera = [
handle_seq_id(seq_id, pts.device) for seq_id in [seq_id_pts, seq_id_camera]
]
if self.masked_sampling and masks is None:
raise ValueError(
"Masks have to be provided for `self.masked_sampling==True`"
)
# project pts to all cameras and sample feats from the locations of
# the 2D projections
sampled_feats_all_cams, sampled_masks_all_cams = project_points_and_sample(
pts,
feats,
camera,
masks if self.masked_sampling else None,
sampling_mode=self.sampling_mode,
)
# generate the mask that invalidates features sampled from
# non-corresponding cameras
camera_pts_mask = (seq_id_camera[None] == seq_id_pts[:, None])[
..., None, None
].to(pts)
# mask the sampled features and masks
sampled_feats = {
k: f * camera_pts_mask for k, f in sampled_feats_all_cams.items()
}
sampled_masks = sampled_masks_all_cams * camera_pts_mask
return sampled_feats, sampled_masks
def project_points_and_sample(
pts: torch.Tensor,
feats: Dict[str, torch.Tensor],
camera: CamerasBase,
masks: Optional[torch.Tensor],
eps: float = 1e-2,
sampling_mode: str = "bilinear",
) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]:
"""
Project each point cloud from a batch of point clouds to all input cameras
and sample features at the 2D projection locations.
Args:
pts: `(pts_batch, n_pts, 3)` tensor containing a batch of 3D point clouds.
feats: A dict `{feat_i: feat_T_i}` of features to sample,
where each `feat_T_i` is a tensor of shape
`(n_cameras, feat_i_dim, feat_i_H, feat_i_W)`
of `feat_i_dim`-dimensional features extracted from `n_cameras`
source views.
camera: A batch of `n_cameras` cameras corresponding to their feature
tensors `feat_T_i` from `feats`.
masks: A tensor of shape `(n_cameras, 1, mask_H, mask_W)` denoting
valid locations for sampling.
eps: A small constant controlling the minimum depth of projections
of `pts` to avoid divisions by zero in the projection operation.
sampling_mode: Sampling mode of the grid sampler.
Returns:
sampled_feats: Dict of sampled features `{feat_i: sampled_T_i}`.
Each `sampled_T_i` is of shape
`(pts_batch, n_cameras, n_pts, feat_i_dim)`.
sampled_masks: A tensor with the mask of the sampled features
of shape `(pts_batch, n_cameras, n_pts, 1)`.
If `masks` is `None`, the returned `sampled_masks` will be
filled with 1s.
"""
n_cameras = camera.R.shape[0]
pts_batch = pts.shape[0]
n_pts = pts.shape[1:-1]
camera_rep, pts_rep = cameras_points_cartesian_product(camera, pts)
# The eps here is super-important to avoid NaNs in backprop!
proj_rep = camera_rep.transform_points(
pts_rep.reshape(n_cameras * pts_batch, -1, 3), eps=eps
)[..., :2]
# [ pts1 in cam1, pts2 in cam1, pts3 in cam1,
# pts1 in cam2, pts2 in cam2, pts3 in cam2,
# pts1 in cam3, pts2 in cam3, pts3 in cam3 ]
# reshape for the grid sampler
sampling_grid_ndc = proj_rep.view(n_cameras, pts_batch, -1, 2)
# [ [pts1 in cam1, pts2 in cam1, pts3 in cam1],
# [pts1 in cam2, pts2 in cam2, pts3 in cam2],
# [pts1 in cam3, pts2 in cam3, pts3 in cam3] ]
# n_cameras x pts_batch x n_pts x 2
# sample both feats
feats_sampled = {
k: ndc_grid_sample(
f,
sampling_grid_ndc,
mode=sampling_mode,
align_corners=False,
)
.permute(2, 0, 3, 1)
.reshape(pts_batch, n_cameras, *n_pts, -1)
for k, f in feats.items()
} # {k: pts_batch x n_cameras x *n_pts x dim} for each feat type "k"
if masks is not None:
# sample masks
masks_sampled = (
ndc_grid_sample(
masks,
sampling_grid_ndc,
mode=sampling_mode,
align_corners=False,
)
.permute(2, 0, 3, 1)
.reshape(pts_batch, n_cameras, *n_pts, 1)
)
else:
masks_sampled = sampling_grid_ndc.new_ones(pts_batch, n_cameras, *n_pts, 1)
return feats_sampled, masks_sampled
def handle_seq_id(
seq_id: Union[torch.LongTensor, List[str], List[int]],
device,
) -> torch.LongTensor:
"""
Converts the input sequence id to a LongTensor.
Args:
seq_id: A sequence of sequence ids.
device: The target device of the output.
Returns:
long_seq_id: `seq_id` converted to a `LongTensor` and moved to `device`.
"""
if not torch.is_tensor(seq_id):
if isinstance(seq_id[0], str):
seq_id = [hash(s) for s in seq_id]
seq_id = torch.tensor(seq_id, dtype=torch.long, device=device)
return seq_id.to(device)
def cameras_points_cartesian_product(
camera: CamerasBase, pts: torch.Tensor
) -> Tuple[CamerasBase, torch.Tensor]:
"""
Generates all pairs of elements from 'camera' and 'pts' and returns
`camera_rep` and `pts_rep` such that:
```
camera_rep = [ pts_rep = [
camera[0] pts[0],
camera[0] pts[1],
camera[0] ...,
... pts[batch_pts-1],
camera[1] pts[0],
camera[1] pts[1],
camera[1] ...,
... pts[batch_pts-1],
... ...,
camera[n_cameras-1] pts[0],
camera[n_cameras-1] pts[1],
camera[n_cameras-1] ...,
... pts[batch_pts-1],
] ]
```
Args:
camera: A batch of `n_cameras` cameras.
pts: A batch of `batch_pts` points of shape `(batch_pts, ..., dim)`
Returns:
camera_rep: A batch of batch_pts*n_cameras cameras such that:
```
camera_rep = [
camera[0]
camera[0]
camera[0]
...
camera[1]
camera[1]
camera[1]
...
...
camera[n_cameras-1]
camera[n_cameras-1]
camera[n_cameras-1]
]
```
pts_rep: Repeated `pts` of shape `(batch_pts*n_cameras, ..., dim)`,
such that:
```
pts_rep = [
pts[0],
pts[1],
...,
pts[batch_pts-1],
pts[0],
pts[1],
...,
pts[batch_pts-1],
...,
pts[0],
pts[1],
...,
pts[batch_pts-1],
]
```
"""
n_cameras = camera.R.shape[0]
batch_pts = pts.shape[0]
pts_rep = pts.repeat(n_cameras, *[1 for _ in pts.shape[1:]])
idx_cams = (
torch.arange(n_cameras)[:, None]
.expand(
n_cameras,
batch_pts,
)
.reshape(batch_pts * n_cameras)
)
camera_rep = camera[idx_cams]
return camera_rep, pts_rep
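
A small sketch (not part of the file above) of the index pattern this helper relies on: camera indices are repeated `batch_pts` times while the points are tiled `n_cameras` times, so row `k` of the product pairs `camera[k // batch_pts]` with `pts[k % batch_pts]`.

```python
import torch

n_cameras, batch_pts = 3, 2
idx_cams = (
    torch.arange(n_cameras)[:, None]
    .expand(n_cameras, batch_pts)
    .reshape(batch_pts * n_cameras)
)
print(idx_cams.tolist())  # [0, 0, 1, 1, 2, 2]

pts_idx = torch.arange(batch_pts).repeat(n_cameras)
print(pts_idx.tolist())   # [0, 1, 0, 1, 0, 1]
```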


@ -0,0 +1,254 @@
# a copy-paste from https://github.com/vsitzmann/scene-representation-networks/blob/master/hyperlayers.py
# fmt: off
# flake8: noqa
'''Pytorch implementations of hyper-network modules.
isort:skip_file
'''
import functools
import torch
import torch.nn as nn
from . import pytorch_prototyping
def partialclass(cls, *args, **kwds):
class NewCls(cls):
__init__ = functools.partialmethod(cls.__init__, *args, **kwds)
return NewCls
class LookupLayer(nn.Module):
def __init__(self, in_ch, out_ch, num_objects):
super().__init__()
self.out_ch = out_ch
self.lookup_lin = LookupLinear(in_ch, out_ch, num_objects=num_objects)
self.norm_nl = nn.Sequential(
nn.LayerNorm([self.out_ch], elementwise_affine=False), nn.ReLU(inplace=True)
)
def forward(self, obj_idx):
net = nn.Sequential(self.lookup_lin(obj_idx), self.norm_nl)
return net
class LookupFC(nn.Module):
def __init__(
self,
hidden_ch,
num_hidden_layers,
num_objects,
in_ch,
out_ch,
outermost_linear=False,
):
super().__init__()
self.layers = nn.ModuleList()
self.layers.append(
LookupLayer(in_ch=in_ch, out_ch=hidden_ch, num_objects=num_objects)
)
for i in range(num_hidden_layers):
self.layers.append(
LookupLayer(in_ch=hidden_ch, out_ch=hidden_ch, num_objects=num_objects)
)
if outermost_linear:
self.layers.append(
LookupLinear(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)
)
else:
self.layers.append(
LookupLayer(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects)
)
def forward(self, obj_idx):
net = []
for i in range(len(self.layers)):
net.append(self.layers[i](obj_idx))
return nn.Sequential(*net)
class LookupLinear(nn.Module):
def __init__(self, in_ch, out_ch, num_objects):
super().__init__()
self.in_ch = in_ch
self.out_ch = out_ch
self.hypo_params = nn.Embedding(num_objects, in_ch * out_ch + out_ch)
for i in range(num_objects):
nn.init.kaiming_normal_(
self.hypo_params.weight.data[i, : self.in_ch * self.out_ch].view(
self.out_ch, self.in_ch
),
a=0.0,
nonlinearity="relu",
mode="fan_in",
)
self.hypo_params.weight.data[i, self.in_ch * self.out_ch :].fill_(0.0)
def forward(self, obj_idx):
hypo_params = self.hypo_params(obj_idx)
# Indices explicit to catch errors in shape of output layer
weights = hypo_params[..., : self.in_ch * self.out_ch]
biases = hypo_params[
..., self.in_ch * self.out_ch : (self.in_ch * self.out_ch) + self.out_ch
]
biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch)
weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch)
return BatchLinear(weights=weights, biases=biases)
class HyperLayer(nn.Module):
"""A hypernetwork that predicts a single Dense Layer, including LayerNorm and a ReLU."""
def __init__(
self, in_ch, out_ch, hyper_in_ch, hyper_num_hidden_layers, hyper_hidden_ch
):
super().__init__()
self.hyper_linear = HyperLinear(
in_ch=in_ch,
out_ch=out_ch,
hyper_in_ch=hyper_in_ch,
hyper_num_hidden_layers=hyper_num_hidden_layers,
hyper_hidden_ch=hyper_hidden_ch,
)
self.norm_nl = nn.Sequential(
nn.LayerNorm([out_ch], elementwise_affine=False), nn.ReLU(inplace=True)
)
def forward(self, hyper_input):
"""
:param hyper_input: input to hypernetwork.
:return: nn.Module; predicted fully connected network.
"""
return nn.Sequential(self.hyper_linear(hyper_input), self.norm_nl)
class HyperFC(nn.Module):
"""Builds a hypernetwork that predicts a fully connected neural network."""
def __init__(
self,
hyper_in_ch,
hyper_num_hidden_layers,
hyper_hidden_ch,
hidden_ch,
num_hidden_layers,
in_ch,
out_ch,
outermost_linear=False,
):
super().__init__()
PreconfHyperLinear = partialclass(
HyperLinear,
hyper_in_ch=hyper_in_ch,
hyper_num_hidden_layers=hyper_num_hidden_layers,
hyper_hidden_ch=hyper_hidden_ch,
)
PreconfHyperLayer = partialclass(
HyperLayer,
hyper_in_ch=hyper_in_ch,
hyper_num_hidden_layers=hyper_num_hidden_layers,
hyper_hidden_ch=hyper_hidden_ch,
)
self.layers = nn.ModuleList()
self.layers.append(PreconfHyperLayer(in_ch=in_ch, out_ch=hidden_ch))
for i in range(num_hidden_layers):
self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=hidden_ch))
if outermost_linear:
self.layers.append(PreconfHyperLinear(in_ch=hidden_ch, out_ch=out_ch))
else:
self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=out_ch))
def forward(self, hyper_input):
"""
:param hyper_input: Input to hypernetwork.
:return: nn.Module; Predicted fully connected neural network.
"""
net = []
for i in range(len(self.layers)):
net.append(self.layers[i](hyper_input))
return nn.Sequential(*net)
class BatchLinear(nn.Module):
def __init__(self, weights, biases):
"""Implements a batch linear layer.
:param weights: Shape: (batch, out_ch, in_ch)
:param biases: Shape: (batch, 1, out_ch)
"""
super().__init__()
self.weights = weights
self.biases = biases
def __repr__(self):
return "BatchLinear(in_ch=%d, out_ch=%d)" % (
self.weights.shape[-1],
self.weights.shape[-2],
)
def forward(self, input):
output = input.matmul(
self.weights.permute(
*[i for i in range(len(self.weights.shape) - 2)], -1, -2
)
)
output += self.biases
return output
def last_hyper_layer_init(m) -> None:
if type(m) == nn.Linear:
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity="relu", mode="fan_in")
# pyre-fixme[41]: `data` cannot be reassigned. It is a read-only property.
m.weight.data *= 1e-1
class HyperLinear(nn.Module):
"""A hypernetwork that predicts a single linear layer (weights & biases)."""
def __init__(
self, in_ch, out_ch, hyper_in_ch, hyper_num_hidden_layers, hyper_hidden_ch
):
super().__init__()
self.in_ch = in_ch
self.out_ch = out_ch
self.hypo_params = pytorch_prototyping.FCBlock(
in_features=hyper_in_ch,
hidden_ch=hyper_hidden_ch,
num_hidden_layers=hyper_num_hidden_layers,
out_features=(in_ch * out_ch) + out_ch,
outermost_linear=True,
)
self.hypo_params[-1].apply(last_hyper_layer_init)
def forward(self, hyper_input):
hypo_params = self.hypo_params(hyper_input)
# Indices explicit to catch errors in shape of output layer
weights = hypo_params[..., : self.in_ch * self.out_ch]
biases = hypo_params[
..., self.in_ch * self.out_ch : (self.in_ch * self.out_ch) + self.out_ch
]
biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch)
weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch)
return BatchLinear(weights=weights, biases=biases)
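
A minimal sketch (not part of the file above) of how the predicted `BatchLinear` modules behave, assuming `BatchLinear` as defined above is in scope: each batch element applies its own weight matrix and bias, i.e. a per-sample linear layer.

```python
import torch

batch, in_ch, out_ch, n_points = 4, 8, 16, 10
weights = torch.randn(batch, out_ch, in_ch)
biases = torch.randn(batch, 1, out_ch)
layer = BatchLinear(weights=weights, biases=biases)

x = torch.randn(batch, n_points, in_ch)
y = layer(x)  # (batch, n_points, out_ch)
assert torch.allclose(y, torch.bmm(x, weights.transpose(1, 2)) + biases, atol=1e-6)
```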


@ -0,0 +1,772 @@
# a copy-paste from https://raw.githubusercontent.com/vsitzmann/pytorch_prototyping/10f49b1e7df38a58fd78451eac91d7ac1a21df64/pytorch_prototyping.py
# fmt: off
# flake8: noqa
'''A number of custom pytorch modules with sane defaults that I find useful for model prototyping.
isort:skip_file
'''
import torch
import torch.nn as nn
import torchvision.utils
from torch.nn import functional as F
class FCLayer(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.net = nn.Sequential(
nn.Linear(in_features, out_features),
nn.LayerNorm([out_features]),
nn.ReLU(inplace=True),
)
def forward(self, input):
return self.net(input)
# From https://gist.github.com/wassname/ecd2dac6fc8f9918149853d17e3abf02
class LayerNormConv2d(nn.Module):
def __init__(self, num_features, eps=1e-5, affine=True):
super().__init__()
self.num_features = num_features
self.affine = affine
self.eps = eps
if self.affine:
self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
self.beta = nn.Parameter(torch.zeros(num_features))
def forward(self, x):
shape = [-1] + [1] * (x.dim() - 1)
mean = x.view(x.size(0), -1).mean(1).view(*shape)
std = x.view(x.size(0), -1).std(1).view(*shape)
y = (x - mean) / (std + self.eps)
if self.affine:
shape = [1, -1] + [1] * (x.dim() - 2)
y = self.gamma.view(*shape) * y + self.beta.view(*shape)
return y
class FCBlock(nn.Module):
def __init__(
self,
hidden_ch,
num_hidden_layers,
in_features,
out_features,
outermost_linear=False,
):
super().__init__()
self.net = []
self.net.append(FCLayer(in_features=in_features, out_features=hidden_ch))
for i in range(num_hidden_layers):
self.net.append(FCLayer(in_features=hidden_ch, out_features=hidden_ch))
if outermost_linear:
self.net.append(nn.Linear(in_features=hidden_ch, out_features=out_features))
else:
self.net.append(FCLayer(in_features=hidden_ch, out_features=out_features))
self.net = nn.Sequential(*self.net)
self.net.apply(self.init_weights)
def __getitem__(self, item):
return self.net[item]
def init_weights(self, m):
if type(m) == nn.Linear:
nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity="relu", mode="fan_in")
def forward(self, input):
return self.net(input)
class DownBlock3D(nn.Module):
"""A 3D convolutional downsampling block."""
def __init__(self, in_channels, out_channels, norm=nn.BatchNorm3d):
super().__init__()
self.net = [
nn.ReplicationPad3d(1),
nn.Conv3d(
in_channels,
out_channels,
kernel_size=4,
padding=0,
stride=2,
bias=False if norm is not None else True,
),
]
if norm is not None:
self.net += [norm(out_channels, affine=True)]
self.net += [nn.LeakyReLU(0.2, True)]
self.net = nn.Sequential(*self.net)
def forward(self, x):
return self.net(x)
class UpBlock3D(nn.Module):
"""A 3D convolutional upsampling block."""
def __init__(self, in_channels, out_channels, norm=nn.BatchNorm3d):
super().__init__()
self.net = [
nn.ConvTranspose3d(
in_channels,
out_channels,
kernel_size=4,
stride=2,
padding=1,
bias=False if norm is not None else True,
),
]
if norm is not None:
self.net += [norm(out_channels, affine=True)]
self.net += [nn.ReLU(True)]
self.net = nn.Sequential(*self.net)
def forward(self, x, skipped=None):
if skipped is not None:
input = torch.cat([skipped, x], dim=1)
else:
input = x
return self.net(input)
class Conv3dSame(torch.nn.Module):
"""3D convolution that pads to keep spatial dimensions equal.
Cannot deal with stride. Only quadratic kernels (=scalar kernel_size).
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
bias=True,
padding_layer=nn.ReplicationPad3d,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param kernel_size: Scalar. Spatial dimensions of kernel (only quadratic kernels supported).
:param bias: Whether or not to use bias.
:param padding_layer: Which padding to use. Default is replication padding.
"""
super().__init__()
ka = kernel_size // 2
kb = ka - 1 if kernel_size % 2 == 0 else ka
self.net = nn.Sequential(
padding_layer((ka, kb, ka, kb, ka, kb)),
nn.Conv3d(in_channels, out_channels, kernel_size, bias=bias, stride=1),
)
def forward(self, x):
return self.net(x)
class Conv2dSame(torch.nn.Module):
"""2D convolution that pads to keep spatial dimensions equal.
Cannot deal with stride. Only quadratic kernels (=scalar kernel_size).
"""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
bias=True,
padding_layer=nn.ReflectionPad2d,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param kernel_size: Scalar. Spatial dimensions of kernel (only quadratic kernels supported).
:param bias: Whether or not to use bias.
:param padding_layer: Which padding to use. Default is reflection padding.
"""
super().__init__()
ka = kernel_size // 2
kb = ka - 1 if kernel_size % 2 == 0 else ka
self.net = nn.Sequential(
padding_layer((ka, kb, ka, kb)),
nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias, stride=1),
)
self.weight = self.net[1].weight
self.bias = self.net[1].bias
def forward(self, x):
return self.net(x)
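# Illustrative note (not in the original source): for kernel_size=3 the pad is
# (ka, kb) = (1, 1) and for kernel_size=4 it is (2, 1); in both cases the
# stride-1 convolution output size is H_in + ka + kb - kernel_size + 1 = H_in,
# which is what keeps the spatial dimensions unchanged.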
class UpBlock(nn.Module):
"""A 2d-conv upsampling block with a variety of options for upsampling, and following best practices / with
reasonable defaults. (LeakyReLU, kernel size multiple of stride)
"""
def __init__(
self,
in_channels,
out_channels,
post_conv=True,
use_dropout=False,
dropout_prob=0.1,
norm=nn.BatchNorm2d,
upsampling_mode="transpose",
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param post_conv: Whether to have another convolutional layer after the upsampling layer.
:param use_dropout: bool. Whether to use dropout or not.
:param dropout_prob: Float. The dropout probability (if use_dropout is True)
:param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity.
:param upsampling_mode: Which upsampling mode:
transpose: Upsampling with stride-2, kernel size 4 transpose convolutions.
bilinear: Feature map is upsampled with bilinear upsampling, then a conv layer.
nearest: Feature map is upsampled with nearest neighbor upsampling, then a conv layer.
shuffle: Feature map is upsampled with pixel shuffling, then a conv layer.
"""
super().__init__()
net = list()
if upsampling_mode == "transpose":
net += [
nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=4,
stride=2,
padding=1,
bias=True if norm is None else False,
)
]
elif upsampling_mode == "bilinear":
net += [nn.UpsamplingBilinear2d(scale_factor=2)]
net += [
Conv2dSame(
in_channels,
out_channels,
kernel_size=3,
bias=True if norm is None else False,
)
]
elif upsampling_mode == "nearest":
net += [nn.UpsamplingNearest2d(scale_factor=2)]
net += [
Conv2dSame(
in_channels,
out_channels,
kernel_size=3,
bias=True if norm is None else False,
)
]
elif upsampling_mode == "shuffle":
net += [nn.PixelShuffle(upscale_factor=2)]
net += [
Conv2dSame(
in_channels // 4,
out_channels,
kernel_size=3,
bias=True if norm is None else False,
)
]
else:
raise ValueError("Unknown upsampling mode!")
if norm is not None:
net += [norm(out_channels, affine=True)]
net += [nn.ReLU(True)]
if use_dropout:
net += [nn.Dropout2d(dropout_prob, False)]
if post_conv:
net += [
Conv2dSame(
out_channels,
out_channels,
kernel_size=3,
bias=True if norm is None else False,
)
]
if norm is not None:
net += [norm(out_channels, affine=True)]
net += [nn.ReLU(True)]
if use_dropout:
net += [nn.Dropout2d(0.1, False)]
self.net = nn.Sequential(*net)
def forward(self, x, skipped=None):
if skipped is not None:
input = torch.cat([skipped, x], dim=1)
else:
input = x
return self.net(input)
class DownBlock(nn.Module):
"""A 2D-conv downsampling block following best practices / with reasonable defaults
(LeakyReLU, kernel size multiple of stride)
"""
def __init__(
self,
in_channels,
out_channels,
prep_conv=True,
middle_channels=None,
use_dropout=False,
dropout_prob=0.1,
norm=nn.BatchNorm2d,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param prep_conv: Whether to have another convolutional layer before the downsampling layer.
:param middle_channels: If prep_conv is true, this sets the number of channels between the prep and downsampling
convs.
:param use_dropout: bool. Whether to use dropout or not.
:param dropout_prob: Float. The dropout probability (if use_dropout is True)
:param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity.
"""
super().__init__()
if middle_channels is None:
middle_channels = in_channels
net = list()
if prep_conv:
net += [
nn.ReflectionPad2d(1),
nn.Conv2d(
in_channels,
middle_channels,
kernel_size=3,
padding=0,
stride=1,
bias=True if norm is None else False,
),
]
if norm is not None:
net += [norm(middle_channels, affine=True)]
net += [nn.LeakyReLU(0.2, True)]
if use_dropout:
net += [nn.Dropout2d(dropout_prob, False)]
net += [
nn.ReflectionPad2d(1),
nn.Conv2d(
middle_channels,
out_channels,
kernel_size=4,
padding=0,
stride=2,
bias=True if norm is None else False,
),
]
if norm is not None:
net += [norm(out_channels, affine=True)]
net += [nn.LeakyReLU(0.2, True)]
if use_dropout:
net += [nn.Dropout2d(dropout_prob, False)]
self.net = nn.Sequential(*net)
def forward(self, x):
return self.net(x)
class Unet3d(nn.Module):
"""A 3d-Unet implementation with sane defaults."""
def __init__(
self,
in_channels,
out_channels,
nf0,
num_down,
max_channels,
norm=nn.BatchNorm3d,
outermost_linear=False,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param nf0: Number of features at highest level of U-Net
:param num_down: Number of downsampling stages.
:param max_channels: Maximum number of channels (channels multiply by 2 with every downsampling stage)
:param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity.
:param outermost_linear: Whether the output layer should be a linear layer or a nonlinear one.
"""
super().__init__()
assert num_down > 0, "Need at least one downsampling layer in UNet3d."
# Define the in block
self.in_layer = [Conv3dSame(in_channels, nf0, kernel_size=3, bias=False)]
if norm is not None:
self.in_layer += [norm(nf0, affine=True)]
self.in_layer += [nn.LeakyReLU(0.2, True)]
self.in_layer = nn.Sequential(*self.in_layer)
# Define the center UNet block. The feature map has height and width 1 --> no batchnorm.
self.unet_block = UnetSkipConnectionBlock3d(
int(min(2 ** (num_down - 1) * nf0, max_channels)),
int(min(2 ** (num_down - 1) * nf0, max_channels)),
norm=None,
)
for i in list(range(0, num_down - 1))[::-1]:
self.unet_block = UnetSkipConnectionBlock3d(
int(min(2 ** i * nf0, max_channels)),
int(min(2 ** (i + 1) * nf0, max_channels)),
submodule=self.unet_block,
norm=norm,
)
# Define the out layer. Each unet block concatenates its inputs with its outputs - so the output layer
# automatically receives the output of the in_layer and the output of the last unet layer.
self.out_layer = [
Conv3dSame(2 * nf0, out_channels, kernel_size=3, bias=outermost_linear)
]
if not outermost_linear:
if norm is not None:
self.out_layer += [norm(out_channels, affine=True)]
self.out_layer += [nn.ReLU(True)]
self.out_layer = nn.Sequential(*self.out_layer)
def forward(self, x):
in_layer = self.in_layer(x)
unet = self.unet_block(in_layer)
out_layer = self.out_layer(unet)
return out_layer
class UnetSkipConnectionBlock3d(nn.Module):
"""Helper class for building a 3D unet."""
def __init__(self, outer_nc, inner_nc, norm=nn.BatchNorm3d, submodule=None):
super().__init__()
if submodule is None:
model = [
DownBlock3D(outer_nc, inner_nc, norm=norm),
UpBlock3D(inner_nc, outer_nc, norm=norm),
]
else:
model = [
DownBlock3D(outer_nc, inner_nc, norm=norm),
submodule,
UpBlock3D(2 * inner_nc, outer_nc, norm=norm),
]
self.model = nn.Sequential(*model)
def forward(self, x):
forward_passed = self.model(x)
return torch.cat([x, forward_passed], 1)
class UnetSkipConnectionBlock(nn.Module):
"""Helper class for building a 2D unet."""
def __init__(
self,
outer_nc,
inner_nc,
upsampling_mode,
norm=nn.BatchNorm2d,
submodule=None,
use_dropout=False,
dropout_prob=0.1,
):
super().__init__()
if submodule is None:
model = [
DownBlock(
outer_nc,
inner_nc,
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=norm,
),
UpBlock(
inner_nc,
outer_nc,
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=norm,
upsampling_mode=upsampling_mode,
),
]
else:
model = [
DownBlock(
outer_nc,
inner_nc,
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=norm,
),
submodule,
UpBlock(
2 * inner_nc,
outer_nc,
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=norm,
upsampling_mode=upsampling_mode,
),
]
self.model = nn.Sequential(*model)
def forward(self, x):
forward_passed = self.model(x)
return torch.cat([x, forward_passed], 1)
class Unet(nn.Module):
"""A 2d-Unet implementation with sane defaults."""
def __init__(
self,
in_channels,
out_channels,
nf0,
num_down,
max_channels,
use_dropout,
upsampling_mode="transpose",
dropout_prob=0.1,
norm=nn.BatchNorm2d,
outermost_linear=False,
):
"""
:param in_channels: Number of input channels
:param out_channels: Number of output channels
:param nf0: Number of features at highest level of U-Net
:param num_down: Number of downsampling stages.
:param max_channels: Maximum number of channels (channels multiply by 2 with every downsampling stage)
:param use_dropout: Whether to use dropout or not.
:param dropout_prob: Dropout probability if use_dropout=True.
:param upsampling_mode: Which type of upsampling should be used. See "UpBlock" for documentation.
:param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity.
:param outermost_linear: Whether the output layer should be a linear layer or a nonlinear one.
"""
super().__init__()
assert num_down > 0, "Need at least one downsampling layer in UNet."
# Define the in block
self.in_layer = [
Conv2dSame(
in_channels, nf0, kernel_size=3, bias=True if norm is None else False
)
]
if norm is not None:
self.in_layer += [norm(nf0, affine=True)]
self.in_layer += [nn.LeakyReLU(0.2, True)]
if use_dropout:
self.in_layer += [nn.Dropout2d(dropout_prob)]
self.in_layer = nn.Sequential(*self.in_layer)
# Define the center UNet block
self.unet_block = UnetSkipConnectionBlock(
min(2 ** (num_down - 1) * nf0, max_channels),
min(2 ** (num_down - 1) * nf0, max_channels),
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=None, # Innermost has no norm (spatial dimension 1)
upsampling_mode=upsampling_mode,
)
for i in list(range(0, num_down - 1))[::-1]:
self.unet_block = UnetSkipConnectionBlock(
min(2 ** i * nf0, max_channels),
min(2 ** (i + 1) * nf0, max_channels),
use_dropout=use_dropout,
dropout_prob=dropout_prob,
submodule=self.unet_block,
norm=norm,
upsampling_mode=upsampling_mode,
)
# Define the out layer. Each unet block concatenates its inputs with its outputs - so the output layer
# automatically receives the output of the in_layer and the output of the last unet layer.
self.out_layer = [
Conv2dSame(
2 * nf0,
out_channels,
kernel_size=3,
bias=outermost_linear or (norm is None),
)
]
if not outermost_linear:
if norm is not None:
self.out_layer += [norm(out_channels, affine=True)]
self.out_layer += [nn.ReLU(True)]
if use_dropout:
self.out_layer += [nn.Dropout2d(dropout_prob)]
self.out_layer = nn.Sequential(*self.out_layer)
self.out_layer_weight = self.out_layer[0].weight
def forward(self, x):
in_layer = self.in_layer(x)
unet = self.unet_block(in_layer)
out_layer = self.out_layer(unet)
return out_layer
class Identity(nn.Module):
"""Helper module to allow Downsampling and Upsampling nets to default to identity if they receive an empty list."""
def __init__(self):
super().__init__()
def forward(self, input):
return input
class DownsamplingNet(nn.Module):
"""A subnetwork that downsamples a 2D feature map with strided convolutions."""
def __init__(
self,
per_layer_out_ch,
in_channels,
use_dropout,
dropout_prob=0.1,
last_layer_one=False,
norm=nn.BatchNorm2d,
):
"""
:param per_layer_out_ch: python list of integers. Defines the number of output channels per layer. Length of
list defines number of downsampling steps (each step downsamples by factor of 2.)
:param in_channels: Number of input channels.
:param use_dropout: Whether or not to use dropout.
:param dropout_prob: Dropout probability.
:param last_layer_one: Whether the output of the last layer will have a spatial size of 1. In that case,
the last layer will not have batchnorm, else, it will.
:param norm: Which norm to use. Defaults to BatchNorm.
"""
super().__init__()
if not len(per_layer_out_ch):
self.downs = Identity()
else:
self.downs = list()
self.downs.append(
DownBlock(
in_channels,
per_layer_out_ch[0],
use_dropout=use_dropout,
dropout_prob=dropout_prob,
middle_channels=per_layer_out_ch[0],
norm=norm,
)
)
for i in range(0, len(per_layer_out_ch) - 1):
if last_layer_one and (i == len(per_layer_out_ch) - 2):
norm = None
self.downs.append(
DownBlock(
per_layer_out_ch[i],
per_layer_out_ch[i + 1],
dropout_prob=dropout_prob,
use_dropout=use_dropout,
norm=norm,
)
)
self.downs = nn.Sequential(*self.downs)
def forward(self, input):
return self.downs(input)
class UpsamplingNet(nn.Module):
"""A subnetwork that upsamples a 2D feature map with a variety of upsampling options."""
def __init__(
self,
per_layer_out_ch,
in_channels,
upsampling_mode,
use_dropout,
dropout_prob=0.1,
first_layer_one=False,
norm=nn.BatchNorm2d,
):
"""
:param per_layer_out_ch: python list of integers. Defines the number of output channels per layer. Length of
list defines number of upsampling steps (each step upsamples by factor of 2.)
:param in_channels: Number of input channels.
:param upsampling_mode: Mode of upsampling. For documentation, see class "UpBlock"
:param use_dropout: Whether or not to use dropout.
:param dropout_prob: Dropout probability.
:param first_layer_one: Whether the input to the first layer will have a spatial size of 1. In that case,
the first layer will not have a norm, else, it will.
:param norm: Which norm to use. Defaults to BatchNorm.
"""
super().__init__()
if not len(per_layer_out_ch):
self.ups = Identity()
else:
self.ups = list()
self.ups.append(
UpBlock(
in_channels,
per_layer_out_ch[0],
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=None if first_layer_one else norm,
upsampling_mode=upsampling_mode,
)
)
for i in range(0, len(per_layer_out_ch) - 1):
self.ups.append(
UpBlock(
per_layer_out_ch[i],
per_layer_out_ch[i + 1],
use_dropout=use_dropout,
dropout_prob=dropout_prob,
norm=norm,
upsampling_mode=upsampling_mode,
)
)
self.ups = nn.Sequential(*self.ups)
def forward(self, input):
return self.ups(input)

@@ -0,0 +1,142 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# TODO: all this potentially goes to PyTorch3D
import math
from typing import Tuple
import pytorch3d as pt3d
import torch
from pytorch3d.renderer.cameras import CamerasBase
def jitter_extrinsics(
R: torch.Tensor,
T: torch.Tensor,
max_angle: float = (math.pi * 2.0),
translation_std: float = 1.0,
scale_std: float = 0.3,
):
"""
Jitter the extrinsic camera parameters `R` and `T` with a random similarity
transformation. The transformation rotates by a random angle between [0, max_angle];
    scales by a random factor exp(N(0, scale_std)), where N(0, scale_std) is
    a random sample from a normal distribution with zero mean and standard
    deviation scale_std; and translates by a 3D offset sampled from N(0, translation_std).
"""
assert all(x >= 0.0 for x in (max_angle, translation_std, scale_std))
N = R.shape[0]
R_jit = pt3d.transforms.random_rotations(1, device=R.device)
R_jit = pt3d.transforms.so3_exponential_map(
pt3d.transforms.so3_log_map(R_jit) * max_angle
)
T_jit = torch.randn_like(R_jit[:1, :, 0]) * translation_std
rigid_transform = pt3d.ops.eyes(dim=4, N=N, device=R.device)
rigid_transform[:, :3, :3] = R_jit.expand(N, 3, 3)
rigid_transform[:, 3, :3] = T_jit.expand(N, 3)
scale_jit = torch.exp(torch.randn_like(T_jit[:, 0]) * scale_std).expand(N)
return apply_camera_alignment(R, T, rigid_transform, scale_jit)
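
# Illustrative usage sketch (the `_example_*` helper below is hypothetical and
# only demonstrates the call): jitter the extrinsics of a small random camera
# batch built from standard PyTorch3D utilities.
def _example_jitter_extrinsics() -> None:
    from pytorch3d.renderer import PerspectiveCameras
    from pytorch3d.transforms import random_rotations

    cameras = PerspectiveCameras(R=random_rotations(4), T=torch.randn(4, 3))
    R_jit, T_jit = jitter_extrinsics(cameras.R, cameras.T, max_angle=0.1)
    # The jittered batch keeps the original shapes: (4, 3, 3) and (4, 3).
    print(R_jit.shape, T_jit.shape)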
def apply_camera_alignment(
R: torch.Tensor,
T: torch.Tensor,
rigid_transform: torch.Tensor,
scale: torch.Tensor,
):
"""
Args:
R: Camera rotation matrix of shape (N, 3, 3).
T: Camera translation of shape (N, 3).
rigid_transform: A tensor of shape (N, 4, 4) representing a batch of
N 4x4 tensors that map the scene pointcloud from misaligned coords
to the aligned space.
scale: A list of N scaling factors. A tensor of shape (N,)
Returns:
R_aligned: The aligned rotations R.
T_aligned: The aligned translations T.
"""
R_rigid = rigid_transform[:, :3, :3]
T_rigid = rigid_transform[:, 3:, :3]
R_aligned = R_rigid.permute(0, 2, 1).bmm(R)
T_aligned = scale[:, None] * (T - (T_rigid @ R_aligned)[:, 0])
return R_aligned, T_aligned
def get_min_max_depth_bounds(cameras, scene_center, scene_extent):
"""
    Estimate the near/far depth planes as:
        near = dist(cam_center, scene_center) - scene_extent
        far  = dist(cam_center, scene_center) + scene_extent
"""
cam_center = cameras.get_camera_center()
center_dist = (
((cam_center - scene_center.to(cameras.R)[None]) ** 2)
.sum(dim=-1)
.clamp(0.001)
.sqrt()
)
center_dist = center_dist.clamp(scene_extent + 1e-3)
min_depth = center_dist - scene_extent
max_depth = center_dist + scene_extent
return min_depth, max_depth
def volumetric_camera_overlaps(
cameras: CamerasBase,
scene_extent: float = 8.0,
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
resol: int = 16,
weigh_by_ray_angle: bool = True,
):
"""
    Compute the overlaps between the viewing frustums of all pairs of cameras
in `cameras`.
"""
device = cameras.device
ba = cameras.R.shape[0]
n_vox = int(resol ** 3)
grid = pt3d.structures.Volumes(
densities=torch.zeros([1, 1, resol, resol, resol], device=device),
volume_translation=-torch.FloatTensor(scene_center)[None].to(device),
voxel_size=2.0 * scene_extent / resol,
).get_coord_grid(world_coordinates=True)
grid = grid.view(1, n_vox, 3).expand(ba, n_vox, 3)
gridp = cameras.transform_points(grid, eps=1e-2)
proj_in_camera = (
torch.prod((gridp[..., :2].abs() <= 1.0), dim=-1)
* (gridp[..., 2] > 0.0).float()
) # ba x n_vox
if weigh_by_ray_angle:
rays = torch.nn.functional.normalize(
grid - cameras.get_camera_center()[:, None], dim=-1
)
rays_masked = rays * proj_in_camera[..., None]
# - slow and readable:
# inter = torch.zeros(ba, ba)
# for i1 in range(ba):
# for i2 in range(ba):
# inter[i1, i2] = (
# 1 + (rays_masked[i1] * rays_masked[i2]
# ).sum(dim=-1)).sum()
# - fast:
rays_masked = rays_masked.view(ba, n_vox * 3)
inter = n_vox + (rays_masked @ rays_masked.t())
else:
inter = proj_in_camera @ proj_in_camera.t()
mass = torch.diag(inter)
iou = inter / (mass[:, None] + mass[None, :] - inter).clamp(0.1)
return iou
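
# Illustrative usage sketch (hypothetical helper): pairwise view-frustum overlap
# of a random camera batch; the result is a symmetric (N, N) IoU-like matrix.
def _example_volumetric_camera_overlaps() -> None:
    from pytorch3d.renderer import PerspectiveCameras
    from pytorch3d.transforms import random_rotations

    cameras = PerspectiveCameras(R=random_rotations(5), T=torch.randn(5, 3))
    iou = volumetric_camera_overlaps(cameras, scene_extent=4.0, resol=8)
    print(iou.shape)  # torch.Size([5, 5])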

@@ -0,0 +1,231 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from dataclasses import dataclass
from math import pi
from typing import Optional
import torch
from pytorch3d.common.compat import eigh, lstsq
def _get_rotation_to_best_fit_xy(
points: torch.Tensor, centroid: torch.Tensor
) -> torch.Tensor:
"""
Returns a rotation r such that points @ r has a best fit plane
parallel to the xy plane
Args:
points: (N, 3) tensor of points in 3D
centroid: (3,) their centroid
Returns:
(3,3) tensor rotation matrix
"""
points_centered = points - centroid[None]
return eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]]
def _signed_area(path: torch.Tensor) -> torch.Tensor:
"""
Calculates the signed area / Lévy area of a 2D path. If the path is closed,
i.e. ends where it starts, this is the integral of the winding number over
the whole plane. If not, consider a closed path made by adding a straight
line from the end to the start; the signed area is the integral of the
winding number (also over the plane) with respect to that closed path.
If this number is positive, it indicates in some sense that the path
turns anticlockwise more than clockwise, and vice versa.
Args:
path: N x 2 tensor of points.
Returns:
signed area, shape ()
"""
# This calculation is a sum of areas of triangles of the form
# (path[0], path[i], path[i+1]), where each triangle is half a
# parallelogram.
x, y = (path[1:] - path[:1]).unbind(1)
return (y[1:] * x[:-1] - x[1:] * y[:-1]).sum() * 0.5
@dataclass(frozen=True)
class Circle2D:
"""
Contains details of a circle in a plane.
Members
center: tensor shape (2,)
radius: tensor shape ()
generated_points: points around the circle, shape (n_points, 2)
"""
center: torch.Tensor
radius: torch.Tensor
generated_points: torch.Tensor
def fit_circle_in_2d(
points2d, *, n_points: int = 0, angles: Optional[torch.Tensor] = None
) -> Circle2D:
"""
Simple best fitting of a circle to 2D points. In particular, the circle which
minimizes the sum of the squares of the squared-distances to the circle.
Finds (a,b) and r to minimize the sum of squares (over the x,y pairs) of
r**2 - [(x-a)**2+(y-b)**2]
i.e.
(2*a)*x + (2*b)*y + (r**2 - a**2 - b**2)*1 - (x**2 + y**2)
In addition, generates points along the circle. If angles is None (default)
then n_points around the circle equally spaced are given. These begin at the
point closest to the first input point. They continue in the direction which
seems to match the movement of points in points2d, as judged by its
signed area. If `angles` are provided, then n_points is ignored, and points
along the circle at the given angles are returned, with the starting point
and direction as before.
(Note that `generated_points` is affected by the order of the points in
points2d, but the other outputs are not.)
Args:
points2d: N x 2 tensor of 2D points
n_points: number of points to generate on the circle, if angles not given
angles: optional angles in radians of points to generate.
Returns:
Circle2D object
"""
design = torch.cat([points2d, torch.ones_like(points2d[:, :1])], dim=1)
rhs = (points2d ** 2).sum(1)
n_provided = points2d.shape[0]
if n_provided < 3:
raise ValueError(f"{n_provided} points are not enough to determine a circle")
solution = lstsq(design, rhs)
center = solution[:2] / 2
radius = torch.sqrt(solution[2] + (center ** 2).sum())
if n_points > 0:
if angles is not None:
warnings.warn("n_points ignored because angles provided")
else:
angles = torch.linspace(0, 2 * pi, n_points, device=points2d.device)
if angles is not None:
initial_direction_xy = (points2d[0] - center).unbind()
initial_angle = torch.atan2(initial_direction_xy[1], initial_direction_xy[0])
with torch.no_grad():
anticlockwise = _signed_area(points2d) > 0
if anticlockwise:
use_angles = initial_angle + angles
else:
use_angles = initial_angle - angles
generated_points = center[None] + radius * torch.stack(
[torch.cos(use_angles), torch.sin(use_angles)], dim=-1
)
else:
generated_points = points2d.new_zeros(0, 2)
return Circle2D(center=center, radius=radius, generated_points=generated_points)
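
# Illustrative usage sketch (hypothetical helper): recover an approximately unit
# circle from noisy 2D samples and generate 4 points along it.
def _example_fit_circle_in_2d() -> None:
    theta = torch.linspace(0, 2 * pi, 100)
    noisy = torch.stack([theta.cos(), theta.sin()], dim=1) + 0.01 * torch.randn(100, 2)
    circle = fit_circle_in_2d(noisy, n_points=4)
    # center ~ (0, 0), radius ~ 1, generated_points has shape (4, 2)
    print(circle.center, circle.radius, circle.generated_points.shape)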
@dataclass(frozen=True)
class Circle3D:
"""
Contains details of a circle in 3D.
Members
center: tensor shape (3,)
radius: tensor shape ()
normal: tensor shape (3,)
generated_points: points around the circle, shape (n_points, 3)
"""
center: torch.Tensor
radius: torch.Tensor
normal: torch.Tensor
generated_points: torch.Tensor
def fit_circle_in_3d(
points,
*,
n_points: int = 0,
angles: Optional[torch.Tensor] = None,
offset: Optional[torch.Tensor] = None,
up: Optional[torch.Tensor] = None,
) -> Circle3D:
"""
Simple best fit circle to 3D points. Uses circle_2d in the
least-squares best fit plane.
In addition, generates points along the circle. If angles is None (default)
then n_points around the circle equally spaced are given. These begin at the
point closest to the first input point. They continue in the direction which
    seems to match the movement of points. If angles is provided, then n_points
is ignored, and points along the circle at the given angles are returned,
with the starting point and direction as before.
Further, an offset can be given to add to the generated points; this is
interpreted in a rotated coordinate system where (0, 0, 1) is normal to the
circle, specifically the normal which is approximately in the direction of a
given `up` vector. The remaining rotation is disambiguated in an unspecified
but deterministic way.
(Note that `generated_points` is affected by the order of the points in
points, but the other outputs are not.)
Args:
        points: N x 3 tensor of 3D points
n_points: number of points to generate on the circle
angles: optional angles in radians of points to generate.
offset: optional tensor (3,), a displacement expressed in a "canonical"
coordinate system to add to the generated points.
up: optional tensor (3,), a vector which helps define the
"canonical" coordinate system for interpretting `offset`.
Required if offset is used.
Returns:
Circle3D object
"""
centroid = points.mean(0)
r = _get_rotation_to_best_fit_xy(points, centroid)
normal = r[:, 2]
rotated_points = (points - centroid) @ r
result_2d = fit_circle_in_2d(
rotated_points[:, :2], n_points=n_points, angles=angles
)
center_3d = result_2d.center @ r[:, :2].t() + centroid
n_generated_points = result_2d.generated_points.shape[0]
if n_generated_points > 0:
generated_points_in_plane = torch.cat(
[
result_2d.generated_points,
torch.zeros_like(result_2d.generated_points[:, :1]),
],
dim=1,
)
if offset is not None:
if up is None:
raise ValueError("Missing `up` input for interpreting offset")
with torch.no_grad():
swap = torch.dot(up, normal) < 0
if swap:
# We need some rotation which takes +z to -z. Here's one.
generated_points_in_plane += offset * offset.new_tensor([1, -1, -1])
else:
generated_points_in_plane += offset
generated_points = generated_points_in_plane @ r.t() + centroid
else:
generated_points = points.new_zeros(0, 3)
return Circle3D(
radius=result_2d.radius,
center=center_3d,
normal=normal,
generated_points=generated_points,
)
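
# Illustrative usage sketch (hypothetical helper): fit a circle to 3D points lying
# on a randomly rotated and translated unit circle.
def _example_fit_circle_in_3d() -> None:
    from pytorch3d.transforms import random_rotations

    theta = torch.linspace(0, 2 * pi, 50)
    planar = torch.stack([theta.cos(), theta.sin(), torch.zeros_like(theta)], dim=1)
    points = planar @ random_rotations(1)[0] + torch.tensor([1.0, 2.0, 3.0])
    circle = fit_circle_in_3d(points, n_points=8)
    # center ~ (1, 2, 3), radius ~ 1, normal is a unit normal of the fitted plane
    print(circle.center, circle.radius, circle.normal)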

@@ -0,0 +1,714 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
import dataclasses
import inspect
import warnings
from collections import Counter, defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast
from omegaconf import DictConfig, OmegaConf, open_dict
"""
This functionality allows a configurable system to be determined in a dataclass-type
way. It is a generalization of omegaconf's "structured", in the dataclass case.
Core functionality:
- Configurable -- A base class used to label a class as being one which uses this
system. Uses class members and __post_init__ like a dataclass.
- expand_args_fields -- Expands a class like `dataclasses.dataclass`. Runs automatically.
- get_default_args -- gets an omegaconf.DictConfig for initializing
a given class or calling a given function.
- run_auto_creation -- Initialises nested members. To be called in __post_init__.
In addition, a Configurable may contain members whose type is decided at runtime.
- ReplaceableBase -- As a base instead of Configurable, labels a class to say that
any child class can be used instead.
- registry -- A global store of named child classes of ReplaceableBase classes.
Used as `@registry.register` decorator on class definition.
Additional utility functions:
- remove_unused_components -- used for simplifying a DictConfig instance.
- get_default_args_field -- default for DictConfig member of another configurable.
1. The simplest usage of this functionality is as follows. First a schema is defined
in dataclass style.
class A(Configurable):
n: int = 9
class B(Configurable):
a: A
def __post_init__(self):
run_auto_creation(self)
It can be used like
b_args = get_default_args(B)
b = B(**b_args)
In this case, get_default_args(B) returns an omegaconf.DictConfig with the right
members {"a_args": {"n": 9}}. It also modifies the definitions of the classes to
something like the following. (The modification itself is done by the function
`expand_args_fields`, which is called inside `get_default_args`.)
@dataclasses.dataclass
class A:
n: int = 9
@dataclasses.dataclass
class B:
a_args: DictConfig = dataclasses.field(default_factory=lambda: DictConfig({"n": 9}))
def __post_init__(self):
self.a = A(**self.a_args)
2. Pluggability. Instead of a dataclass-style member being given a concrete class,
you can give a base class and the implementation is looked up by name in the global
`registry` in this module. E.g.
class A(ReplaceableBase):
k: int = 1
@registry.register
class A1(A):
m: int = 3
@registry.register
class A2(A):
n: str = "2"
class B(Configurable):
a: A
a_class_type: str = "A2"
def __post_init__(self):
run_auto_creation(self)
will expand to
@dataclasses.dataclass
class A:
k: int = 1
@dataclasses.dataclass
class A1(A):
m: int = 3
@dataclasses.dataclass
class A2(A):
n: str = "2"
@dataclasses.dataclass
class B:
a_class_type: str = "A2"
        a_A1_args: DictConfig = dataclasses.field(
            default_factory=lambda: DictConfig({"k": 1, "m": 3})
        )
        a_A2_args: DictConfig = dataclasses.field(
            default_factory=lambda: DictConfig({"k": 1, "n": "2"})
        )
def __post_init__(self):
if self.a_class_type == "A1":
self.a = A1(**self.a_A1_args)
elif self.a_class_type == "A2":
self.a = A2(**self.a_A2_args)
else:
raise ValueError(...)
3. Aside from these classes, the members of these classes should be things
which DictConfig is happy with: e.g. (bool, int, str, None, float) and what
can be built from them with DictConfigs and lists of them.
In addition, you can call get_default_args on a function or class to get
the DictConfig of its defaulted arguments, assuming those are all things
which DictConfig is happy with. If you want to use such a thing as a member
of another configured class, `get_default_args_field` is a helper.
"""
_unprocessed_warning: str = (
" must be processed before it can be used."
+ " This is done by calling expand_args_fields "
+ "or get_default_args on it."
)
TYPE_SUFFIX: str = "_class_type"
ARGS_SUFFIX: str = "_args"
class ReplaceableBase:
"""
Base class for dataclass-style classes which
can be stored in the registry.
"""
def __new__(cls, *args, **kwargs):
"""
This function only exists to raise a
warning if class construction is attempted
without processing.
"""
obj = super().__new__(cls)
if cls is not ReplaceableBase and not _is_actually_dataclass(cls):
warnings.warn(cls.__name__ + _unprocessed_warning)
return obj
class Configurable:
"""
This indicates a class which is not ReplaceableBase
but still needs to be
expanded into a dataclass with expand_args_fields.
This expansion is delayed.
"""
def __new__(cls, *args, **kwargs):
"""
This function only exists to raise a
warning if class construction is attempted
without processing.
"""
obj = super().__new__(cls)
if cls is not Configurable and not _is_actually_dataclass(cls):
warnings.warn(cls.__name__ + _unprocessed_warning)
return obj
_X = TypeVar("X", bound=ReplaceableBase)
class _Registry:
"""
Register from names to classes. In particular, we say that direct subclasses of
ReplaceableBase are "base classes" and we register subclasses of each base class
in a separate namespace.
"""
def __init__(self) -> None:
self._mapping: Dict[
Type[ReplaceableBase], Dict[str, Type[ReplaceableBase]]
] = defaultdict(dict)
def register(self, some_class: Type[_X]) -> Type[_X]:
"""
A class decorator, to register a class in self.
"""
name = some_class.__name__
self._register(some_class, name=name)
return some_class
def _register(
self,
some_class: Type[ReplaceableBase],
*,
base_class: Optional[Type[ReplaceableBase]] = None,
name: str,
) -> None:
"""
Register a new member.
Args:
            some_class: the new member
base_class: (optional) what the new member is a type for
name: name for the new member
"""
if base_class is None:
base_class = self._base_class_from_class(some_class)
if base_class is None:
raise ValueError(
f"Cannot register {some_class}. Cannot tell what it is."
)
if some_class is base_class:
raise ValueError(f"Attempted to register the base class {some_class}")
self._mapping[base_class][name] = some_class
def get(
self, base_class_wanted: Type[ReplaceableBase], name: str
) -> Type[ReplaceableBase]:
"""
Retrieve a class from the registry by name
Args:
base_class_wanted: parent type of type we are looking for.
It determines the namespace.
This will typically be a direct subclass of ReplaceableBase.
name: what to look for
Returns:
class type
"""
if self._is_base_class(base_class_wanted):
base_class = base_class_wanted
else:
base_class = self._base_class_from_class(base_class_wanted)
if base_class is None:
raise ValueError(
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
)
result = self._mapping[base_class].get(name)
if result is None:
raise ValueError(f"{name} has not been registered.")
if not issubclass(result, base_class_wanted):
raise ValueError(
f"{name} resolves to {result} which does not subclass {base_class_wanted}"
)
return result
def get_all(
self, base_class_wanted: Type[ReplaceableBase]
) -> List[Type[ReplaceableBase]]:
"""
Retrieve all registered implementations from the registry
Args:
base_class_wanted: parent type of type we are looking for.
It determines the namespace.
This will typically be a direct subclass of ReplaceableBase.
Returns:
list of class types
"""
if self._is_base_class(base_class_wanted):
return list(self._mapping[base_class_wanted].values())
base_class = self._base_class_from_class(base_class_wanted)
if base_class is None:
raise ValueError(
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
)
return [
class_
for class_ in self._mapping[base_class].values()
if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted
]
@staticmethod
def _is_base_class(some_class: Type[ReplaceableBase]) -> bool:
"""
Return whether the given type is a direct subclass of ReplaceableBase
and so gets used as a namespace.
"""
return ReplaceableBase in some_class.__bases__
@staticmethod
def _base_class_from_class(
some_class: Type[ReplaceableBase],
) -> Optional[Type[ReplaceableBase]]:
"""
Find the parent class of some_class which inherits ReplaceableBase, or None
"""
for base in some_class.mro()[-3::-1]:
if base is not ReplaceableBase and issubclass(base, ReplaceableBase):
return base
return None
# Global instance of the registry
registry = _Registry()
def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any], None]:
"""
Return the default creation function for a member. This is a function which
could be called in __post_init__ to initialise the member, and will be called
from run_auto_creation.
Args:
name: name of the member
type_: declared type of the member
pluggable: True if the member's declared type inherits ReplaceableBase,
in which case the actual type to be created is decided at
runtime.
Returns:
Function taking one argument, the object whose member should be
initialized.
"""
def inner(self):
expand_args_fields(type_)
args = getattr(self, name + ARGS_SUFFIX)
setattr(self, name, type_(**args))
def inner_pluggable(self):
type_name = getattr(self, name + TYPE_SUFFIX)
chosen_class = registry.get(type_, type_name)
if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
# If this warning is raised, it means that a new definition of
# the chosen class has been registered since our class was processed
# (i.e. expanded). A DictConfig which comes from our get_default_args
# (which might have triggered the processing) will contain the old default
# values for the members of the chosen class. Changes to those defaults which
# were made in the redefinition will not be reflected here.
warnings.warn(f"New implementation of {type_name} is being chosen.")
expand_args_fields(chosen_class)
args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}")
setattr(self, name, chosen_class(**args))
return inner_pluggable if pluggable else inner
def run_auto_creation(self: Any) -> None:
"""
Run all the functions named in self._creation_functions.
"""
for create_function in self._creation_functions:
getattr(self, create_function)()
def _is_configurable_class(C) -> bool:
return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase))
def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig:
"""
Get the DictConfig of args to call C - which might be a type or a function.
If C is a subclass of Configurable or ReplaceableBase, we make sure
it has been processed with expand_args_fields. If C is a dataclass,
including a subclass of Configurable or ReplaceableBase, the output
will be a typed DictConfig.
Args:
C: the class or function to be processed
_do_not_process: (internal use) When this function is called from
expand_args_fields, we specify any class currently being
processed, to make sure we don't try to process a class
while it is already being processed.
Returns:
new DictConfig object
"""
if C is None:
return DictConfig({})
if _is_configurable_class(C):
if C in _do_not_process:
raise ValueError(
f"Internal recursion error. Need processed {C},"
f" but cannot get it. _do_not_process={_do_not_process}"
)
# This is safe to run multiple times. It will return
# straight away if C has already been processed.
expand_args_fields(C, _do_not_process=_do_not_process)
kwargs = {}
if dataclasses.is_dataclass(C):
# Note that if get_default_args_field is used somewhere in C,
# this call is recursive. No special care is needed,
        # because in practice get_default_args_field is used for
        # types other than the outer type.
out = OmegaConf.structured(C)
exclude = getattr(C, "_processed_members", ())
with open_dict(out):
for field in exclude:
out.pop(field, None)
return out
if _is_configurable_class(C):
raise ValueError(f"Failed to process {C}")
# returns dict of keyword args of a callable C
sig = inspect.signature(C)
for pname, defval in dict(sig.parameters).items():
if defval.default == inspect.Parameter.empty:
# print('skipping %s' % pname)
continue
else:
kwargs[pname] = copy.deepcopy(defval.default)
return DictConfig(kwargs)
def _is_actually_dataclass(some_class) -> bool:
# Return whether the class some_class has been processed with
# the dataclass annotation. This is more specific than
# dataclasses.is_dataclass which returns True on anything
# deriving from a dataclass.
# Checking for __init__ would also work for our purpose.
return "__dataclass_fields__" in some_class.__dict__
def expand_args_fields(
some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = ()
) -> Type[_X]:
"""
This expands a class which inherits Configurable or ReplaceableBase classes,
including dataclass processing. some_class is modified in place by this function.
For classes of type ReplaceableBase, you can add some_class to the registry before
or after calling this function. But potential inner classes need to be registered
before this function is run on the outer class.
The transformations this function makes, before the concluding
dataclasses.dataclass, are as follows. if X is a base class with registered
subclasses Y and Z, replace
x: X
and optionally
x_class_type: str = "Y"
def create_x(self):...
with
x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
def create_x(self):
            self.x = registry.get(X, self.x_class_type)(
                **getattr(self, f"x_{self.x_class_type}_args")
            )
x_class_type: str = "UNDEFAULTED"
without adding the optional things if they are already there.
Similarly, if X is a subclass of Configurable,
x: X
and optionally
def create_x(self):...
will be replaced with
x_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig())
def create_x(self):
            self.x = X(**self.x_args)
Also adds the following class members, unannotated so that dataclass
ignores them.
- _creation_functions: Tuple[str] of all the create_ functions,
including those from base classes.
- _known_implementations: Dict[str, Type] containing the classes which
have been found from the registry.
        (used only to raise a warning if one has been overwritten)
- _processed_members: a Set[str] of all the members which have been transformed.
Args:
some_class: the class to be processed
_do_not_process: Internal use for get_default_args: Because get_default_args calls
and is called by this function, we let it specify any class currently
being processed, to make sure we don't try to process a class while
it is already being processed.
Returns:
some_class itself, which has been modified in place. This
allows this function to be used as a class decorator.
"""
if _is_actually_dataclass(some_class):
return some_class
# The functions this class's run_auto_creation will run.
creation_functions: List[str] = []
# The classes which this type knows about from the registry
# We could use a weakref.WeakValueDictionary here which would mean
# that we don't warn if the class we should have expected is elsewhere
# unused.
known_implementations: Dict[str, Type] = {}
# Names of members which have been processed.
processed_members: Set[str] = set()
# For all bases except ReplaceableBase and Configurable and object,
# we need to process them before our own processing. This is
# because dataclasses expect to inherit dataclasses and not unprocessed
# dataclasses.
for base in some_class.mro()[-3:0:-1]:
if base is ReplaceableBase:
continue
if base is Configurable:
continue
if not issubclass(base, (Configurable, ReplaceableBase)):
continue
expand_args_fields(base, _do_not_process=_do_not_process)
if "_creation_functions" in base.__dict__:
creation_functions.extend(base._creation_functions)
if "_known_implementations" in base.__dict__:
known_implementations.update(base._known_implementations)
if "_processed_members" in base.__dict__:
processed_members.update(base._processed_members)
to_process: List[Tuple[str, Type, bool]] = []
if "__annotations__" in some_class.__dict__:
for name, type_ in some_class.__annotations__.items():
if not isinstance(type_, type):
# type_ could be something like typing.Tuple
continue
if (
issubclass(type_, ReplaceableBase)
and ReplaceableBase in type_.__bases__
):
to_process.append((name, type_, True))
elif issubclass(type_, Configurable):
to_process.append((name, type_, False))
for name, type_, pluggable in to_process:
_process_member(
name=name,
type_=type_,
pluggable=pluggable,
some_class=cast(type, some_class),
creation_functions=creation_functions,
_do_not_process=_do_not_process,
known_implementations=known_implementations,
)
processed_members.add(name)
for key, count in Counter(creation_functions).items():
if count > 1:
warnings.warn(f"Clash with {key} in a base class.")
some_class._creation_functions = tuple(creation_functions)
some_class._processed_members = processed_members
some_class._known_implementations = known_implementations
dataclasses.dataclass(eq=False)(some_class)
return some_class
def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()):
"""
Get a dataclass field which defaults to get_default_args(...)
Args:
As for get_default_args.
Returns:
function to return new DictConfig object
"""
def create():
return get_default_args(C, _do_not_process=_do_not_process)
return dataclasses.field(default_factory=create)
def _process_member(
*,
name: str,
type_: Type,
pluggable: bool,
some_class: Type,
creation_functions: List[str],
_do_not_process: Tuple[type, ...],
known_implementations: Dict[str, Type],
) -> None:
"""
Make the modification (of expand_args_fields) to some_class for a single member.
Args:
name: member name
type_: member declared type
        pluggable: whether the member has a dynamic type
some_class: (MODIFIED IN PLACE) the class being processed
creation_functions: (MODIFIED IN PLACE) the names of the create functions
_do_not_process: as for expand_args_fields.
known_implementations: (MODIFIED IN PLACE) known types from the registry
"""
# Because we are adding defaultable members, make
# sure they go at the end of __annotations__ in case
# there are non-defaulted standard class members.
del some_class.__annotations__[name]
if pluggable:
type_name = name + TYPE_SUFFIX
if type_name not in some_class.__annotations__:
some_class.__annotations__[type_name] = str
setattr(some_class, type_name, "UNDEFAULTED")
for derived_type in registry.get_all(type_):
if derived_type in _do_not_process:
continue
if issubclass(derived_type, some_class):
# When derived_type is some_class we have a simple
# recursion to avoid. When it's a strict subclass the
# situation is even worse.
continue
known_implementations[derived_type.__name__] = derived_type
args_name = f"{name}_{derived_type.__name__}{ARGS_SUFFIX}"
if args_name in some_class.__annotations__:
raise ValueError(
f"Cannot generate {args_name} because it is already present."
)
some_class.__annotations__[args_name] = DictConfig
setattr(
some_class,
args_name,
get_default_args_field(
derived_type, _do_not_process=_do_not_process + (some_class,)
),
)
else:
args_name = name + ARGS_SUFFIX
if args_name in some_class.__annotations__:
raise ValueError(
f"Cannot generate {args_name} because it is already present."
)
if issubclass(type_, some_class) or type_ in _do_not_process:
raise ValueError(f"Cannot process {type_} inside {some_class}")
some_class.__annotations__[args_name] = DictConfig
setattr(
some_class,
args_name,
get_default_args_field(
type_,
_do_not_process=_do_not_process + (some_class,),
),
)
creation_function_name = f"create_{name}"
if not hasattr(some_class, creation_function_name):
setattr(
some_class,
creation_function_name,
_default_create(name, type_, pluggable),
)
creation_functions.append(creation_function_name)
def remove_unused_components(dict_: DictConfig) -> None:
"""
Assuming dict_ represents the state of a configurable,
modify it to remove all the portions corresponding to
pluggable parts which are not in use.
For example, if renderer_class_type is SignedDistanceFunctionRenderer,
the renderer_MultiPassEmissionAbsorptionRenderer_args will be
removed.
Args:
dict_: (MODIFIED IN PLACE) a DictConfig instance
"""
keys = [key for key in dict_ if isinstance(key, str)]
suffix_length = len(TYPE_SUFFIX)
replaceables = [key[:-suffix_length] for key in keys if key.endswith(TYPE_SUFFIX)]
args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)]
for replaceable in replaceables:
selected_type = dict_[replaceable + TYPE_SUFFIX]
expect = replaceable + "_" + selected_type + ARGS_SUFFIX
with open_dict(dict_):
for key in args_keys:
if key.startswith(replaceable + "_") and key != expect:
del dict_[key]
for key in dict_:
if isinstance(dict_.get(key), DictConfig):
remove_unused_components(dict_[key])
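
# Illustrative usage sketch (the classes and helper below are hypothetical): the
# dataclass-style configuration flow described in the module docstring above,
# using only Configurable, get_default_args and run_auto_creation.
def _example_configurable_usage() -> None:
    class Inner(Configurable):
        n: int = 9

    class Outer(Configurable):
        inner: Inner
        scale: float = 2.0

        def __post_init__(self):
            run_auto_creation(self)

    args = get_default_args(Outer)  # roughly {"inner_args": {"n": 9}, "scale": 2.0}
    outer = Outer(**args)
    assert outer.inner.n == 9 and outer.scale == 2.0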

@@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as Fu
from pytorch3d.ops import wmean
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
def cleanup_eval_depth(
point_cloud: Pointclouds,
camera: CamerasBase,
depth: torch.Tensor,
mask: torch.Tensor,
sigma: float = 0.01,
image=None,
):
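    """
    Compute a per-pixel mask of `depth` marking the pixels covered by at least one
    point of `point_cloud` whose depth in the `camera` view agrees with the depth
    map value sampled at its projection, up to `sigma` times the (masked) standard
    deviation of `depth`. Returns a float mask of shape (ba, 1, H, W).
    """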
ba, _, H, W = depth.shape
pcl = point_cloud.points_padded()
n_pts = point_cloud.num_points_per_cloud()
pcl_mask = (
torch.arange(pcl.shape[1], dtype=torch.int64, device=pcl.device)[None]
< n_pts[:, None]
).type_as(pcl)
pcl_proj = camera.transform_points(pcl, eps=1e-2)[..., :-1]
pcl_depth = camera.get_world_to_view_transform().transform_points(pcl)[..., -1]
depth_and_idx = torch.cat(
(
depth,
torch.arange(H * W).view(1, 1, H, W).expand(ba, 1, H, W).type_as(depth),
),
dim=1,
)
depth_and_idx_sampled = Fu.grid_sample(
depth_and_idx, -pcl_proj[:, None], mode="nearest"
)[:, :, 0].view(ba, 2, -1)
depth_sampled, idx_sampled = depth_and_idx_sampled.split([1, 1], dim=1)
df = (depth_sampled[:, 0] - pcl_depth).abs()
# the threshold is a sigma-multiple of the standard deviation of the depth
mu = wmean(depth.view(ba, -1, 1), mask.view(ba, -1)).view(ba, 1)
std = (
wmean((depth.view(ba, -1) - mu).view(ba, -1, 1) ** 2, mask.view(ba, -1))
.clamp(1e-4)
.sqrt()
.view(ba, -1)
)
good_df_thr = std * sigma
good_depth = (df <= good_df_thr).float() * pcl_mask
perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1)
# print(f'Kept {100.0 * perc_kept.mean():1.3f} % points')
good_depth_raster = torch.zeros_like(depth).view(ba, -1)
# pyre-ignore[16]: scatter_add_
good_depth_raster.scatter_add_(1, torch.round(idx_sampled[:, 0]).long(), good_depth)
good_depth_mask = (good_depth_raster.view(ba, 1, H, W) > 0).float()
# if float(torch.rand(1)) > 0.95:
# depth_ok = depth * good_depth_mask
# # visualize
# visdom_env = 'depth_cleanup_dbg'
# from visdom import Visdom
# # from tools.vis_utils import make_depth_image
# from pytorch3d.vis.plotly_vis import plot_scene
# viz = Visdom()
# show_pcls = {
# 'pointclouds': point_cloud,
# }
# for d, nm in zip(
# (depth, depth_ok),
# ('pointclouds_unproj', 'pointclouds_unproj_ok'),
# ):
# pointclouds_unproj = get_rgbd_point_cloud(
# camera, image, d,
# )
# if int(pointclouds_unproj.num_points_per_cloud()) > 0:
# show_pcls[nm] = pointclouds_unproj
# scene_dict = {'1': {
# **show_pcls,
# 'cameras': camera,
# }}
# scene = plot_scene(
# scene_dict,
# pointcloud_max_points=5000,
# pointcloud_marker_size=1.5,
# camera_scale=1.0,
# )
# viz.plotlyplot(scene, env=visdom_env, win='scene')
# # depth_image_ok = make_depth_image(depths_ok, masks)
# # viz.images(depth_image_ok, env=visdom_env, win='depth_ok')
# # depth_image = make_depth_image(depths, masks)
# # viz.images(depth_image, env=visdom_env, win='depth')
# # # viz.images(rgb_rendered, env=visdom_env, win='images_render')
# # viz.images(images, env=visdom_env, win='images')
# import pdb; pdb.set_trace()
return good_depth_mask

@@ -0,0 +1,226 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional, Tuple
import torch
from pytorch3d.common.compat import eigh
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform
from pytorch3d.transforms import Scale
def generate_eval_video_cameras(
train_cameras,
n_eval_cams: int = 100,
trajectory_type: str = "figure_eight",
trajectory_scale: float = 0.2,
scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0),
up: Tuple[float, float, float] = (0.0, 0.0, 1.0),
focal_length: Optional[torch.FloatTensor] = None,
principal_point: Optional[torch.FloatTensor] = None,
time: Optional[torch.FloatTensor] = None,
infer_up_as_plane_normal: bool = True,
traj_offset: Optional[Tuple[float, float, float]] = None,
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
) -> PerspectiveCameras:
"""
Generate a camera trajectory rendering a scene from multiple viewpoints.
Args:
        train_cameras: The batch of cameras from the training dataset.
n_eval_cams: Number of cameras in the trajectory.
trajectory_type: The type of the camera trajectory. Can be one of:
circular_lsq_fit: Camera centers follow a trajectory obtained
by fitting a 3D circle to train_cameras centers.
All cameras are looking towards scene_center.
figure_eight: Figure-of-8 trajectory around the center of the
central camera of the training dataset.
trefoil_knot: Same as 'figure_eight', but the trajectory has a shape
of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot).
figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape
of a figure-eight knot
(https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)).
trajectory_scale: The extent of the trajectory.
up: The "up" vector of the scene (=the normal of the scene floor).
            Active for `trajectory_type="circular_lsq_fit"`.
scene_center: The center of the scene in world coordinates which all
the cameras from the generated trajectory look at.
Returns:
        A batch of PerspectiveCameras tracing the evaluation trajectory, which can be used as the test dataset.
"""
if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
cam_centers = train_cameras.get_camera_center()
# get the nearest camera center to the mean of centers
mean_camera_idx = (
((cam_centers - cam_centers.mean(dim=0)[None]) ** 2)
.sum(dim=1)
.min(dim=0)
.indices
)
# generate the knot trajectory in canonical coords
if time is None:
time = torch.linspace(0, 2 * math.pi, n_eval_cams + 1)[:n_eval_cams]
else:
assert time.numel() == n_eval_cams
if trajectory_type == "trefoil_knot":
traj = _trefoil_knot(time)
elif trajectory_type == "figure_eight_knot":
traj = _figure_eight_knot(time)
elif trajectory_type == "figure_eight":
traj = _figure_eight(time)
else:
raise ValueError(f"bad trajectory type: {trajectory_type}")
traj[:, 2] -= traj[:, 2].max()
# transform the canonical knot to the coord frame of the mean camera
mean_camera = PerspectiveCameras(
**{
k: getattr(train_cameras, k)[[int(mean_camera_idx)]]
for k in ("focal_length", "principal_point", "R", "T")
}
)
traj_trans = Scale(cam_centers.std(dim=0).mean() * trajectory_scale).compose(
mean_camera.get_world_to_view_transform().inverse()
)
if traj_offset_canonical is not None:
traj_trans = traj_trans.translate(
torch.FloatTensor(traj_offset_canonical)[None].to(traj)
)
traj = traj_trans.transform_points(traj)
plane_normal = _fit_plane(cam_centers)[:, 0]
if infer_up_as_plane_normal:
up = _disambiguate_normal(plane_normal, up)
elif trajectory_type == "circular_lsq_fit":
### fit plane to the camera centers
# get the center of the plane as the median of the camera centers
cam_centers = train_cameras.get_camera_center()
if time is not None:
angle = time
else:
angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams).to(cam_centers)
fit = fit_circle_in_3d(
cam_centers,
angles=angle,
offset=angle.new_tensor(traj_offset_canonical)
if traj_offset_canonical is not None
else None,
up=angle.new_tensor(up),
)
traj = fit.generated_points
        # scale the trajectory
_t_mu = traj.mean(dim=0, keepdim=True)
traj = (traj - _t_mu) * trajectory_scale + _t_mu
plane_normal = fit.normal
if infer_up_as_plane_normal:
up = _disambiguate_normal(plane_normal, up)
else:
raise ValueError(f"Uknown trajectory_type {trajectory_type}.")
if traj_offset is not None:
traj = traj + torch.FloatTensor(traj_offset)[None].to(traj)
# point all cameras towards the center of the scene
R, T = look_at_view_transform(
eye=traj,
at=(scene_center,), # (1, 3)
up=(up,), # (1, 3)
device=traj.device,
)
# get the average focal length and principal point
if focal_length is None:
focal_length = train_cameras.focal_length.mean(dim=0).repeat(n_eval_cams, 1)
if principal_point is None:
principal_point = train_cameras.principal_point.mean(dim=0).repeat(
n_eval_cams, 1
)
test_cameras = PerspectiveCameras(
focal_length=focal_length,
principal_point=principal_point,
R=R,
T=T,
device=focal_length.device,
)
# _visdom_plot_scene(
# train_cameras,
# test_cameras,
# )
return test_cameras
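
# Illustrative usage sketch (hypothetical helper): build a circular evaluation
# trajectory around a given batch of training cameras (at least three cameras
# are assumed so that the 3D circle fit is well defined).
def _example_eval_trajectory(train_cameras: PerspectiveCameras) -> PerspectiveCameras:
    return generate_eval_video_cameras(
        train_cameras,
        n_eval_cams=40,
        trajectory_type="circular_lsq_fit",
        trajectory_scale=1.0,
    )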
def _disambiguate_normal(normal, up):
up_t = torch.tensor(up).to(normal)
flip = (up_t * normal).sum().sign()
up = normal * flip
up = up.tolist()
return up
def _fit_plane(x):
x = x - x.mean(dim=0)[None]
cov = (x.t() @ x) / x.shape[0]
_, e_vec = eigh(cov)
return e_vec
def _visdom_plot_scene(
train_cameras,
test_cameras,
) -> None:
from pytorch3d.vis.plotly_vis import plot_scene
p = plot_scene(
{
"scene": {
"train_cams": train_cameras,
"test_cams": test_cameras,
}
}
)
from visdom import Visdom
viz = Visdom()
viz.plotlyplot(p, env="cam_traj_dbg", win="cam_trajs")
import pdb
pdb.set_trace()
def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5):
x = (2 + (2 * t).cos()) * (3 * t).cos()
y = (2 + (2 * t).cos()) * (3 * t).sin()
z = (4 * t).sin() * z_scale
return torch.stack((x, y, z), dim=-1)
def _trefoil_knot(t: torch.Tensor, z_scale: float = 0.5):
x = t.sin() + 2 * (2 * t).sin()
y = t.cos() - 2 * (2 * t).cos()
z = -(3 * t).sin() * z_scale
return torch.stack((x, y, z), dim=-1)
def _figure_eight(t: torch.Tensor, z_scale: float = 0.5):
x = t.cos()
y = (2 * t).sin() / 2
z = t.sin() * z_scale
return torch.stack((x, y, z), dim=-1)

@@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Union
import torch
def mask_background(
image_rgb: torch.Tensor,
mask_fg: torch.Tensor,
dim_color: int = 1,
bg_color: Union[torch.Tensor, str, float] = 0.0,
) -> torch.Tensor:
"""
Mask the background input image tensor `image_rgb` with `bg_color`.
The background regions are obtained from the binary foreground segmentation
mask `mask_fg`.
"""
tgt_view = [1, 1, 1, 1]
tgt_view[dim_color] = 3
# obtain the background color tensor
if isinstance(bg_color, torch.Tensor):
bg_color_t = bg_color.view(1, 3, 1, 1).clone().to(image_rgb)
elif isinstance(bg_color, float):
bg_color_t = torch.tensor(
[bg_color] * 3, device=image_rgb.device, dtype=image_rgb.dtype
).view(*tgt_view)
elif isinstance(bg_color, str):
if bg_color == "white":
bg_color_t = image_rgb.new_ones(tgt_view)
elif bg_color == "black":
bg_color_t = image_rgb.new_zeros(tgt_view)
else:
raise ValueError(_invalid_color_error_msg(bg_color))
else:
raise ValueError(_invalid_color_error_msg(bg_color))
# cast to the image_rgb's type
mask_fg = mask_fg.type_as(image_rgb)
# mask the bg
image_masked = mask_fg * image_rgb + (1 - mask_fg) * bg_color_t
return image_masked
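
# Illustrative usage sketch (hypothetical helper): paint the background of a
# batch of random images white using a random binary foreground mask.
def _example_mask_background() -> torch.Tensor:
    image_rgb = torch.rand(2, 3, 8, 8)
    mask_fg = (torch.rand(2, 1, 8, 8) > 0.5).float()
    return mask_background(image_rgb, mask_fg, bg_color="white")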
def _invalid_color_error_msg(bg_color) -> str:
return (
f"Invalid bg_color={bg_color}. Plese set bg_color to a 3-element"
+ " tensor. or a string (white | black), or a float."
)

@@ -0,0 +1,231 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import Optional, Tuple
import torch
from torch.nn import functional as F
def eval_depth(
pred: torch.Tensor,
gt: torch.Tensor,
crop: int = 1,
mask: Optional[torch.Tensor] = None,
get_best_scale: bool = True,
mask_thr: float = 0.5,
best_scale_clamp_thr: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Evaluate the depth error between the prediction `pred` and the ground
truth `gt`.
Args:
pred: A tensor of shape (N, 1, H, W) denoting the predicted depth maps.
gt: A tensor of shape (N, 1, H, W) denoting the ground truth depth maps.
crop: The number of pixels to crop from the border.
mask: A mask denoting the valid regions of the gt depth.
get_best_scale: If `True`, estimates a scaling factor of the predicted depth
that yields the best mean squared error between `pred` and `gt`.
This is typically enabled for cases where predicted reconstructions
are inherently defined up to an arbitrary scaling factor.
mask_thr: A constant used to threshold the `mask` to specify the valid
regions.
best_scale_clamp_thr: The threshold for clamping the divisor in best
scale estimation.
Returns:
mse_depth: Mean squared error between `pred` and `gt`.
abs_depth: Mean absolute difference between `pred` and `gt`.
"""
# chuck out the border
if crop > 0:
gt = gt[:, :, crop:-crop, crop:-crop]
pred = pred[:, :, crop:-crop, crop:-crop]
if mask is not None:
# mult gt by mask
if crop > 0:
mask = mask[:, :, crop:-crop, crop:-crop]
gt = gt * (mask > mask_thr).float()
dmask = (gt > 0.0).float()
dmask_mass = torch.clamp(dmask.sum((1, 2, 3)), 1e-4)
if get_best_scale:
# mult preds by a scalar "scale_best"
# s.t. we get best possible mse error
scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr)
pred = pred * scale_best[:, None, None, None]
df = gt - pred
mse_depth = (dmask * (df ** 2)).sum((1, 2, 3)) / dmask_mass
abs_depth = (dmask * df.abs()).sum((1, 2, 3)) / dmask_mass
return mse_depth, abs_depth
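
# Illustrative usage sketch (hypothetical helper): with get_best_scale=True, a
# prediction that is a scaled copy of the ground truth yields ~zero error.
def _example_eval_depth() -> None:
    gt = torch.rand(2, 1, 16, 16) + 0.5
    pred = 3.0 * gt
    mse, mae = eval_depth(pred, gt, crop=1, get_best_scale=True)
    print(mse, mae)  # both close to zero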
def estimate_depth_scale_factor(pred, gt, mask, clamp_thr):
xy = pred * gt * mask
xx = pred * pred * mask
scale_best = xy.mean((1, 2, 3)) / torch.clamp(xx.mean((1, 2, 3)), clamp_thr)
return scale_best
def calc_psnr(
x: torch.Tensor,
y: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Calculates the Peak-signal-to-noise ratio between tensors `x` and `y`.
"""
mse = calc_mse(x, y, mask=mask)
psnr = torch.log10(mse.clamp(1e-10)) * (-10.0)
return psnr
def calc_mse(
x: torch.Tensor,
y: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Calculates the mean square error between tensors `x` and `y`.
"""
if mask is None:
return torch.mean((x - y) ** 2)
else:
return (((x - y) ** 2) * mask).sum() / mask.expand_as(x).sum().clamp(1e-5)
def calc_bce(
pred: torch.Tensor,
gt: torch.Tensor,
equal_w: bool = True,
pred_eps: float = 0.01,
mask: Optional[torch.Tensor] = None,
lerp_bound: Optional[float] = None,
) -> torch.Tensor:
"""
Calculates the binary cross entropy.
"""
if pred_eps > 0.0:
# up/low bound the predictions
pred = torch.clamp(pred, pred_eps, 1.0 - pred_eps)
if mask is None:
mask = torch.ones_like(gt)
if equal_w:
mask_fg = (gt > 0.5).float() * mask
mask_bg = (1 - mask_fg) * mask
weight = mask_fg / mask_fg.sum().clamp(1.0) + mask_bg / mask_bg.sum().clamp(1.0)
# weight sum should be at this point ~2
weight = weight * (weight.numel() / weight.sum().clamp(1.0))
else:
weight = torch.ones_like(gt) * mask
if lerp_bound is not None:
return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound)
else:
return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight)
def binary_cross_entropy_lerp(
pred: torch.Tensor,
gt: torch.Tensor,
weight: torch.Tensor,
lerp_bound: float,
):
"""
Binary cross entropy which avoids exploding gradients by linearly
    extrapolating the log function for log(1-pred) and log(pred) whenever
pred or 1-pred is smaller than lerp_bound.
"""
loss = log_lerp(1 - pred, lerp_bound) * (1 - gt) + log_lerp(pred, lerp_bound) * gt
loss_reduced = -(loss * weight).sum() / weight.sum().clamp(1e-4)
return loss_reduced
def log_lerp(x: torch.Tensor, b: float):
"""
Linearly extrapolated log for x < b.
"""
assert b > 0
return torch.where(x >= b, x.log(), math.log(b) + (x - b) / b)
def rgb_l1(
pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Calculates the mean absolute error between the predicted colors `pred`
and ground truth colors `target`.
"""
if mask is None:
mask = torch.ones_like(pred[:, :1])
return ((pred - target).abs() * mask).sum(dim=(1, 2, 3)) / mask.sum(
dim=(1, 2, 3)
).clamp(1)
def huber(dfsq: torch.Tensor, scaling: float = 0.03) -> torch.Tensor:
"""
Calculates the huber function of the input squared error `dfsq`.
The function smoothly transitions from a region with unit gradient
to a hyperbolic function at `dfsq=scaling`.
"""
loss = (safe_sqrt(1 + dfsq / (scaling * scaling), eps=1e-4) - 1) * scaling
return loss
def neg_iou_loss(
predict: torch.Tensor,
target: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
    Computes the negative IoU loss, which emphasizes the active regions
    of `predict` and `target`.
"""
return 1.0 - iou(predict, target, mask=mask)
def safe_sqrt(A: torch.Tensor, eps: float = float(1e-4)) -> torch.Tensor:
"""
performs safe differentiable sqrt
"""
return (torch.clamp(A, float(0)) + eps).sqrt()
def iou(
predict: torch.Tensor,
target: torch.Tensor,
mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
    Computes the intersection-over-union of `predict` and `target`,
    optionally restricted to `mask`.
"""
dims = tuple(range(predict.dim())[1:])
if mask is not None:
predict = predict * mask
target = target * mask
intersect = (predict * target).sum(dims)
union = (predict + target - predict * target).sum(dims) + 1e-4
return (intersect / union).sum() / intersect.numel()
def beta_prior(pred: torch.Tensor, cap: float = 0.1) -> torch.Tensor:
if cap <= 0.0:
raise ValueError("capping should be positive to avoid unbound loss")
min_value = math.log(cap) + math.log(cap + 1.0)
return (torch.log(pred + cap) + torch.log(1.0 - pred + cap)).mean() - min_value

@@ -0,0 +1,163 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import glob
import os
import shutil
import tempfile
import torch
def load_stats(flstats):
from pytorch3d.implicitron.tools.stats import Stats
try:
stats = Stats.load(flstats)
    except Exception:
        print("Cannot load stats! %s" % flstats)
stats = None
return stats
def get_model_path(fl) -> str:
fl = os.path.splitext(fl)[0]
flmodel = "%s.pth" % fl
return flmodel
def get_optimizer_path(fl) -> str:
fl = os.path.splitext(fl)[0]
flopt = "%s_opt.pth" % fl
return flopt
def get_stats_path(fl, eval_results: bool = False):
fl = os.path.splitext(fl)[0]
if eval_results:
for postfix in ("_2", ""):
flstats = os.path.join(os.path.dirname(fl), f"stats_test{postfix}.jgz")
if os.path.isfile(flstats):
break
else:
flstats = "%s_stats.jgz" % fl
# pyre-fixme[61]: `flstats` is undefined, or not always defined.
return flstats
def safe_save_model(model, stats, fl, optimizer=None, cfg=None) -> None:
"""
    This function stores model files safely so that no partially written model
    files are left on the file system if the saving procedure is interrupted.
    This is done by first saving the model files to a temporary directory, followed
    by (atomic) moves to the target location. Note that this can still result in
    a corrupt set of model files if an interruption happens while performing the
    moves; it is, however, quite improbable that a crash would occur at exactly
    that time.
"""
print(f"saving model files safely to {fl}")
# first store everything to a tmpdir
with tempfile.TemporaryDirectory() as tmpdir:
tmpfl = os.path.join(tmpdir, os.path.split(fl)[-1])
stored_tmp_fls = save_model(model, stats, tmpfl, optimizer=optimizer, cfg=cfg)
tgt_fls = [
(
os.path.join(os.path.split(fl)[0], os.path.split(tmpfl)[-1])
if (tmpfl is not None)
else None
)
for tmpfl in stored_tmp_fls
]
# then move from the tmpdir to the right location
for tmpfl, tgt_fl in zip(stored_tmp_fls, tgt_fls):
if tgt_fl is None:
continue
# print(f'Moving {tmpfl} --> {tgt_fl}\n')
shutil.move(tmpfl, tgt_fl)
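
# Illustrative usage sketch (hypothetical helper): checkpoint a model at a given
# epoch; `stats` is assumed to be a pytorch3d.implicitron.tools.stats.Stats
# instance tracking the training history.
def _example_safe_save(model: torch.nn.Module, stats, exp_dir: str, epoch: int) -> None:
    fl = get_checkpoint(exp_dir, epoch)  # "<exp_dir>/model_epoch_<8-digit epoch>.pth"
    safe_save_model(model, stats, fl)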
def save_model(model, stats, fl, optimizer=None, cfg=None):
flstats = get_stats_path(fl)
flmodel = get_model_path(fl)
print("saving model to %s" % flmodel)
torch.save(model.state_dict(), flmodel)
flopt = None
if optimizer is not None:
flopt = get_optimizer_path(fl)
print("saving optimizer to %s" % flopt)
torch.save(optimizer.state_dict(), flopt)
print("saving model stats to %s" % flstats)
stats.save(flstats)
return flstats, flmodel, flopt
def load_model(fl):
flstats = get_stats_path(fl)
flmodel = get_model_path(fl)
flopt = get_optimizer_path(fl)
model_state_dict = torch.load(flmodel)
stats = load_stats(flstats)
if os.path.isfile(flopt):
optimizer = torch.load(flopt)
else:
optimizer = None
return model_state_dict, stats, optimizer
def parse_epoch_from_model_path(model_path) -> int:
return int(
os.path.split(model_path)[-1].replace(".pth", "").replace("model_epoch_", "")
)
def get_checkpoint(exp_dir, epoch):
fl = os.path.join(exp_dir, "model_epoch_%08d.pth" % epoch)
return fl
def find_last_checkpoint(
exp_dir, any_path: bool = False, all_checkpoints: bool = False
):
if any_path:
exts = [".pth", "_stats.jgz", "_opt.pth"]
else:
exts = [".pth"]
for ext in exts:
fls = sorted(
glob.glob(
os.path.join(glob.escape(exp_dir), "model_epoch_" + "[0-9]" * 8 + ext)
)
)
if len(fls) > 0:
break
# pyre-fixme[61]: `fls` is undefined, or not always defined.
if len(fls) == 0:
fl = None
else:
if all_checkpoints:
# pyre-fixme[61]: `fls` is undefined, or not always defined.
fl = [f[0 : -len(ext)] + ".pth" for f in fls]
else:
fl = fls[-1][0 : -len(ext)] + ".pth"
return fl
def purge_epoch(exp_dir, epoch) -> None:
model_path = get_checkpoint(exp_dir, epoch)
for file_path in [
model_path,
get_optimizer_path(model_path),
get_stats_path(model_path),
]:
if os.path.isfile(file_path):
print("deleting %s" % file_path)
os.remove(file_path)

@@ -0,0 +1,168 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple, cast
import torch
import torch.nn.functional as Fu
from pytorch3d.renderer import (
AlphaCompositor,
NDCMultinomialRaysampler,
PointsRasterizationSettings,
PointsRasterizer,
ray_bundle_to_ray_points,
)
from pytorch3d.renderer.cameras import CamerasBase
from pytorch3d.structures import Pointclouds
def get_rgbd_point_cloud(
camera: CamerasBase,
image_rgb: torch.Tensor,
depth_map: torch.Tensor,
mask: Optional[torch.Tensor] = None,
mask_thr: float = 0.5,
mask_points: bool = True,
) -> Pointclouds:
"""
    Given a batch of images, depths, masks and cameras, generate a colored
    point cloud by unprojecting the depth maps into 3D and coloring the points
    with the corresponding source pixel colors.
"""
imh, imw = image_rgb.shape[2:]
# convert the depth maps to point clouds using the grid ray sampler
pts_3d = ray_bundle_to_ray_points(
NDCMultinomialRaysampler(
image_width=imw,
image_height=imh,
n_pts_per_ray=1,
min_depth=1.0,
max_depth=1.0,
)(camera)._replace(lengths=depth_map[:, 0, ..., None])
)
pts_mask = depth_map > 0.0
if mask is not None:
pts_mask *= mask > mask_thr
pts_mask = pts_mask.reshape(-1)
pts_3d = pts_3d.reshape(-1, 3)[pts_mask]
pts_colors = torch.nn.functional.interpolate(
image_rgb,
# pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got
# `List[typing.Any]`.
size=[imh, imw],
mode="bilinear",
align_corners=False,
)
pts_colors = pts_colors.permute(0, 2, 3, 1).reshape(-1, 3)[pts_mask]
return Pointclouds(points=pts_3d[None], features=pts_colors[None])
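
# Illustrative usage sketch (hypothetical helper): unproject a constant depth map
# of a random image into a colored point cloud, using a default camera.
def _example_rgbd_point_cloud() -> Pointclouds:
    from pytorch3d.renderer import PerspectiveCameras

    camera = PerspectiveCameras()
    image_rgb = torch.rand(1, 3, 16, 16)
    depth_map = torch.ones(1, 1, 16, 16)
    return get_rgbd_point_cloud(camera, image_rgb, depth_map)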
def render_point_cloud_pytorch3d(
camera,
point_cloud,
render_size: Tuple[int, int],
point_radius: float = 0.03,
topk: int = 10,
eps: float = 1e-2,
bg_color=None,
**kwargs
):
# feature dimension
featdim = point_cloud.features_packed().shape[-1]
# move to the camera coordinates; using identity cameras in the renderer
point_cloud = _transform_points(camera, point_cloud, eps, **kwargs)
camera_trivial = camera.clone()
camera_trivial.R[:] = torch.eye(3)
camera_trivial.T *= 0.0
rasterizer = PointsRasterizer(
cameras=camera_trivial,
raster_settings=PointsRasterizationSettings(
image_size=render_size,
radius=point_radius,
points_per_pixel=topk,
bin_size=64 if int(max(render_size)) > 1024 else None,
),
)
fragments = rasterizer(point_cloud, **kwargs)
# Construct weights based on the distance of a point to the true point.
# However, this could be done differently: e.g. predicted as opposed
# to a function of the weights.
r = rasterizer.raster_settings.radius
# set up the blending weights
dists2 = fragments.dists
weights = 1 - dists2 / (r * r)
ok = cast(torch.BoolTensor, (fragments.idx >= 0)).float()
weights = weights * ok
fragments_prm = fragments.idx.long().permute(0, 3, 1, 2)
weights_prm = weights.permute(0, 3, 1, 2)
images = AlphaCompositor()(
fragments_prm,
weights_prm,
point_cloud.features_packed().permute(1, 0),
background_color=bg_color if bg_color is not None else [0.0] * featdim,
**kwargs,
)
# get the depths ...
# weighted_fs[b,c,i,j] = sum_k cum_alpha_k * features[c,pointsidx[b,k,i,j]]
# cum_alpha_k = alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j])
cumprod = torch.cumprod(1 - weights, dim=-1)
cumprod = torch.cat((torch.ones_like(cumprod[..., :1]), cumprod[..., :-1]), dim=-1)
depths = (weights * cumprod * fragments.zbuf).sum(dim=-1)
# add the rendering mask
render_mask = -torch.prod(1.0 - weights, dim=-1) + 1.0
# cat depths and render mask
rendered_blob = torch.cat((images, depths[:, None], render_mask[:, None]), dim=1)
# reshape back
rendered_blob = Fu.interpolate(
rendered_blob,
# pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got `Tuple[int,
# ...]`.
size=tuple(render_size),
mode="bilinear",
)
data_rendered, depth_rendered, render_mask = rendered_blob.split(
[rendered_blob.shape[1] - 2, 1, 1],
dim=1,
)
return data_rendered, render_mask, depth_rendered
def _signed_clamp(x, eps):
sign = x.sign() + (x == 0.0).type_as(x)
x_clamp = sign * torch.clamp(x.abs(), eps)
return x_clamp
def _transform_points(cameras, point_clouds, eps, **kwargs):
pts_world = point_clouds.points_padded()
pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points(
pts_world, eps=eps
)
# it is crucial to actually clamp the points as well ...
pts_view = torch.cat(
(pts_view[..., :-1], _signed_clamp(pts_view[..., -1:], eps)), dim=-1
)
point_clouds = point_clouds.update_padded(pts_view)
return point_clouds
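
A hedged round-trip sketch of the two entry points above: unproject a synthetic RGB-D frame into a colored point cloud and render it back with the same camera. The import path assumes this file sits under `pytorch3d.implicitron.tools` like the other utilities in this commit; the camera intrinsics and tensor shapes are arbitrary.

```python
import torch
from pytorch3d.renderer import PerspectiveCameras

# Assumed module location for the functions defined above.
from pytorch3d.implicitron.tools.point_cloud_utils import (
    get_rgbd_point_cloud,
    render_point_cloud_pytorch3d,
)

B, H, W = 1, 64, 64
camera = PerspectiveCameras(focal_length=1.0)  # identity extrinsics
image_rgb = torch.rand(B, 3, H, W)             # toy colors
depth_map = torch.full((B, 1, H, W), 2.0)      # constant depth of 2

point_cloud = get_rgbd_point_cloud(camera, image_rgb, depth_map)
images, masks, depths = render_point_cloud_pytorch3d(
    camera, point_cloud, render_size=(H, W), point_radius=0.05
)
print(images.shape, masks.shape, depths.shape)
# (1, 3, 64, 64), (1, 1, 64, 64), (1, 1, 64, 64)
```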

View File

@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Tuple
import torch
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.structures import Pointclouds
from .point_cloud_utils import render_point_cloud_pytorch3d
def rasterize_mc_samples(
xys: torch.Tensor,
feats: torch.Tensor,
image_size_hw: Tuple[int, int],
radius: float = 0.03,
topk: int = 5,
masks: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Rasterizes Monte-Carlo sampled features back onto the image.
Specifically, the code uses the PyTorch3D point rasterizer to render
a z-flat point cloud composed of the xy MC locations and their features.
Args:
xys: B x N x 2 2D point locations in PyTorch3D NDC convention
feats: B x N x dim tensor containing per-point rendered features.
image_size_hw: Tuple[image_height, image_width] containing
the size of rasterized image.
radius: Rasterization point radius.
topk: The maximum z-buffer size for the PyTorch3D point cloud rasterizer.
masks: B x N x 1 tensor containing the alpha mask of the
rendered features.
"""
if masks is None:
masks = torch.ones_like(xys[..., :1])
feats = torch.cat((feats, masks), dim=-1)
pointclouds = Pointclouds(
points=torch.cat([xys, torch.ones_like(xys[..., :1])], dim=-1),
features=feats,
)
data_rendered, render_mask, _ = render_point_cloud_pytorch3d(
PerspectiveCameras(device=feats.device),
pointclouds,
render_size=image_size_hw,
point_radius=radius,
topk=topk,
)
data_rendered, masks_pt = data_rendered.split(
[data_rendered.shape[1] - 1, 1], dim=1
)
render_mask = masks_pt * render_mask
return data_rendered, render_mask
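
A short usage sketch, assuming `rasterize_mc_samples` defined above is in scope; the sample count, feature dimension and resolution are arbitrary.

```python
import torch

B, N, dim, H, W = 2, 1024, 3, 128, 128
xys = torch.rand(B, N, 2) * 2.0 - 1.0   # random locations in the [-1, 1] NDC square
feats = torch.rand(B, N, dim)           # per-sample features, e.g. rendered colors

images, alpha = rasterize_mc_samples(xys, feats, image_size_hw=(H, W), radius=0.01)
print(images.shape, alpha.shape)        # (2, 3, 128, 128), (2, 1, 128, 128)
```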

View File

@ -0,0 +1,491 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import gzip
import json
import time
import warnings
from collections.abc import Iterable
from itertools import cycle
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors as mcolors
from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.history = []
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1, epoch=0):
# make sure the history is of the same len as epoch
while len(self.history) <= epoch:
self.history.append([])
self.history[epoch].append(val / n)
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def get_epoch_averages(self, epoch=-1):
if len(self.history) == 0: # no stats here
return None
elif epoch == -1:
return [
(float(np.array(x).mean()) if len(x) > 0 else float("NaN"))
for x in self.history
]
else:
return float(np.array(self.history[epoch]).mean())
def get_all_values(self):
all_vals = [np.array(x) for x in self.history]
all_vals = np.concatenate(all_vals)
return all_vals
def get_epoch(self):
return len(self.history)
@staticmethod
def from_json_str(json_str):
self = AverageMeter()
self.__dict__.update(json.loads(json_str))
return self
class Stats(object):
# TODO: update this with context manager
"""
A stats logging object useful for gathering statistics when training a deep network in PyTorch.
Example:
# init stats structure that logs statistics 'objective' and 'top1e'
stats = Stats( ('objective','top1e') )
network = init_net() # init a pytorch module (= neural network)
dataloader = init_dataloader() # init a dataloader
for epoch in range(10):
# start of epoch -> call new_epoch
stats.new_epoch()
# iterate over batches
for batch in dataloader:
output = network(batch) # run the network and collect the outputs in the dict 'output'
# stats.update() automatically parses the 'objective' and 'top1e' from
# the "output" dict and stores this into the db
stats.update(output)
stats.print() # prints the averages over the given epoch
# stores the training plots into '/tmp/epoch_stats.pdf'
# and plots into a visdom server running at localhost (if running)
stats.plot_stats(plot_file='/tmp/epoch_stats.pdf')
"""
def __init__(
self,
log_vars,
verbose=False,
epoch=-1,
visdom_env="main",
do_plot=True,
plot_file=None,
visdom_server="http://localhost",
visdom_port=8097,
):
self.verbose = verbose
self.log_vars = log_vars
self.visdom_env = visdom_env
self.visdom_server = visdom_server
self.visdom_port = visdom_port
self.plot_file = plot_file
self.do_plot = do_plot
self.hard_reset(epoch=epoch)
@staticmethod
def from_json_str(json_str):
self = Stats([])
# load the global state
self.__dict__.update(json.loads(json_str))
# recover the AverageMeters
for stat_set in self.stats:
self.stats[stat_set] = {
log_var: AverageMeter.from_json_str(log_vals_json_str)
for log_var, log_vals_json_str in self.stats[stat_set].items()
}
return self
@staticmethod
def load(flpath, postfix=".jgz"):
flpath = _get_postfixed_filename(flpath, postfix)
with gzip.open(flpath, "r") as fin:
data = json.loads(fin.read().decode("utf-8"))
return Stats.from_json_str(data)
def save(self, flpath, postfix=".jgz"):
flpath = _get_postfixed_filename(flpath, postfix)
# store into a gzipped-json
with gzip.open(flpath, "w") as fout:
fout.write(json.dumps(self, cls=StatsJSONEncoder).encode("utf-8"))
# some sugar to be used with "with stats:" at the beginning of the epoch
def __enter__(self):
if self.do_plot and self.epoch >= 0:
self.plot_stats(self.visdom_env)
self.new_epoch()
def __exit__(self, type, value, traceback):
iserr = type is not None and issubclass(type, Exception)
iserr = iserr or (type is KeyboardInterrupt)
if iserr:
print("error inside 'with' block")
return
if self.do_plot:
self.plot_stats(self.visdom_env)
def reset(self): # to be called after each epoch
stat_sets = list(self.stats.keys())
if self.verbose:
print("stats: epoch %d - reset" % self.epoch)
self.it = {k: -1 for k in stat_sets}
for stat_set in stat_sets:
for stat in self.stats[stat_set]:
self.stats[stat_set][stat].reset()
def hard_reset(self, epoch=-1): # to be called during object __init__
self.epoch = epoch
if self.verbose:
print("stats: epoch %d - hard reset" % self.epoch)
self.stats = {}
# reset
self.reset()
def new_epoch(self):
if self.verbose:
print("stats: new epoch %d" % (self.epoch + 1))
self.epoch += 1
self.reset() # zero the stats + increase epoch counter
def gather_value(self, val):
if isinstance(val, (float, int)):
val = float(val)
else:
val = val.data.cpu().numpy()
val = float(val.sum())
return val
def add_log_vars(self, added_log_vars, verbose=True):
for add_log_var in added_log_vars:
if add_log_var not in self.stats:
if verbose:
print(f"Adding {add_log_var}")
self.log_vars.append(add_log_var)
# self.synchronize_logged_vars(self.log_vars, verbose=verbose)
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
if self.epoch == -1: # uninitialized
print(
"warning: epoch==-1 means uninitialized stats structure -> new_epoch() called"
)
self.new_epoch()
if stat_set not in self.stats:
self.stats[stat_set] = {}
self.it[stat_set] = -1
if not freeze_iter:
self.it[stat_set] += 1
epoch = self.epoch
it = self.it[stat_set]
for stat in self.log_vars:
if stat not in self.stats[stat_set]:
self.stats[stat_set][stat] = AverageMeter()
if stat == "sec/it": # compute speed
if time_start is None:
elapsed = 0.0
else:
elapsed = time.time() - time_start
time_per_it = float(elapsed) / float(it + 1)
val = time_per_it
# self.stats[stat_set]['sec/it'].update(time_per_it,epoch=epoch,n=1)
else:
if stat in preds:
try:
val = self.gather_value(preds[stat])
except KeyError:
raise ValueError(
"could not extract prediction %s from the prediction dictionary" % stat
)
else:
val = None
if val is not None:
self.stats[stat_set][stat].update(val, epoch=epoch, n=1)
def get_epoch_averages(self, epoch=None):
stat_sets = list(self.stats.keys())
if epoch is None:
epoch = self.epoch
if epoch == -1:
epoch = list(range(self.epoch))
outvals = {}
for stat_set in stat_sets:
outvals[stat_set] = {
"epoch": epoch,
"it": self.it[stat_set],
"epoch_max": self.epoch,
}
for stat in self.stats[stat_set].keys():
if self.stats[stat_set][stat].count == 0:
continue
if isinstance(epoch, Iterable):
avgs = self.stats[stat_set][stat].get_epoch_averages()
avgs = [avgs[e] for e in epoch]
else:
avgs = self.stats[stat_set][stat].get_epoch_averages(epoch=epoch)
outvals[stat_set][stat] = avgs
return outvals
def print(
self,
max_it=None,
stat_set="train",
vars_print=None,
get_str=False,
skip_nan=False,
stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"),
):
epoch = self.epoch
stats = self.stats
str_out = ""
it = self.it[stat_set]
stat_str = ""
stats_print = sorted(stats[stat_set].keys())
for stat in stats_print:
if stats[stat_set][stat].count == 0:
continue
if skip_nan and not np.isfinite(stats[stat_set][stat].avg):
continue
stat_str += " {0:.12}: {1:1.3f} |".format(
stat_format(stat), stats[stat_set][stat].avg
)
head_str = "[%s] | epoch %3d | it %5d" % (stat_set, epoch, it)
if max_it:
head_str += "/ %d" % max_it
str_out = "%s | %s" % (head_str, stat_str)
if get_str:
return str_out
else:
print(str_out)
def plot_stats(
self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None
):
# use the cached visdom env if none supplied
if visdom_env is None:
visdom_env = self.visdom_env
if visdom_server is None:
visdom_server = self.visdom_server
if visdom_port is None:
visdom_port = self.visdom_port
if plot_file is None:
plot_file = self.plot_file
stat_sets = list(self.stats.keys())
print(
"printing charts to visdom env '%s' (%s:%d)"
% (visdom_env, visdom_server, visdom_port)
)
novisdom = False
viz = get_visdom_connection(server=visdom_server, port=visdom_port)
if not viz.check_connection():
print("no visdom server! -> skipping visdom plots")
novisdom = True
lines = []
# plot metrics
if not novisdom:
viz.close(env=visdom_env, win=None)
for stat in self.log_vars:
vals = []
stat_sets_now = []
for stat_set in stat_sets:
val = self.stats[stat_set][stat].get_epoch_averages()
if val is None:
continue
else:
val = np.array(val).reshape(-1)
stat_sets_now.append(stat_set)
vals.append(val)
if len(vals) == 0:
continue
lines.append((stat_sets_now, stat, vals))
if not novisdom:
for tmodes, stat, vals in lines:
title = "%s" % stat
opts = {"title": title, "legend": list(tmodes)}
for i, (tmode, val) in enumerate(zip(tmodes, vals)):
update = "append" if i > 0 else None
valid = np.where(np.isfinite(val))[0]
if len(valid) == 0:
continue
x = np.arange(len(val))
viz.line(
Y=val[valid],
X=x[valid],
env=visdom_env,
opts=opts,
win=f"stat_plot_{title}",
name=tmode,
update=update,
)
if plot_file:
print("exporting stats to %s" % plot_file)
ncol = 3
nrow = int(np.ceil(float(len(lines)) / ncol))
matplotlib.rcParams.update({"font.size": 5})
color = cycle(plt.cm.tab10(np.linspace(0, 1, 10)))
fig = plt.figure(1)
plt.clf()
for idx, (tmodes, stat, vals) in enumerate(lines):
c = next(color)
plt.subplot(nrow, ncol, idx + 1)
plt.gca()
for vali, vals_ in enumerate(vals):
c_ = c * (1.0 - float(vali) * 0.3)
valid = np.where(np.isfinite(vals_))[0]
if len(valid) == 0:
continue
x = np.arange(len(vals_))
plt.plot(x[valid], vals_[valid], c=c_, linewidth=1)
plt.ylabel(stat)
plt.xlabel("epoch")
plt.gca().yaxis.label.set_color(c[0:3] * 0.75)
plt.legend(tmodes)
gcolor = np.array(mcolors.to_rgba("lightgray"))
plt.grid(
b=True, which="major", color=gcolor, linestyle="-", linewidth=0.4
)
plt.grid(
b=True, which="minor", color=gcolor, linestyle="--", linewidth=0.2
)
plt.minorticks_on()
plt.tight_layout()
plt.show()
try:
fig.savefig(plot_file)
except PermissionError:
warnings.warn("Cant dump stats due to insufficient permissions!")
def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=True):
stat_sets = list(self.stats.keys())
# remove the additional log_vars
for stat_set in stat_sets:
for stat in self.stats[stat_set].keys():
if stat not in log_vars:
print("additional stat %s:%s -> removing" % (stat_set, stat))
self.stats[stat_set] = {
stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars
}
self.log_vars = log_vars # !!!
for stat_set in stat_sets:
reference_stat = list(self.stats[stat_set].keys())[0]
for stat in log_vars:
if stat not in self.stats[stat_set]:
if verbose:
print(
"missing stat %s:%s -> filling with default values (%1.2f)"
% (stat_set, stat, default_val)
)
elif len(self.stats[stat_set][stat].history) != self.epoch + 1:
h = self.stats[stat_set][stat].history
if len(h) == 0: # just never updated stat ... skip
continue
else:
if verbose:
print(
"incomplete stat %s:%s -> reseting with default values (%1.2f)"
% (stat_set, stat, default_val)
)
else:
continue
self.stats[stat_set][stat] = AverageMeter()
self.stats[stat_set][stat].reset()
lastep = self.epoch + 1
for ep in range(lastep):
self.stats[stat_set][stat].update(default_val, n=1, epoch=ep)
epoch_self = self.stats[stat_set][reference_stat].get_epoch()
epoch_generated = self.stats[stat_set][stat].get_epoch()
assert (
epoch_self == epoch_generated
), "bad epoch of synchronized log_var! %d vs %d" % (
epoch_self,
epoch_generated,
)
class StatsJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, (AverageMeter, Stats)):
enc = self.encode(o.__dict__)
return enc
else:
raise TypeError(
f"Object of type {o.__class__.__name__} " f"is not JSON serializable"
)
def _get_postfixed_filename(fl, postfix):
return fl if fl.endswith(postfix) else fl + postfix
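
Beyond the in-class docstring example, the save/load round trip implemented above looks roughly like this; a minimal sketch with a placeholder path:

```python
stats = Stats(("objective", "sec/it"))
stats.new_epoch()
stats.update({"objective": 0.5})
stats.print()                             # prints the per-epoch averages

stats.save("/tmp/train_stats")            # writes /tmp/train_stats.jgz
restored = Stats.load("/tmp/train_stats")
assert restored.epoch == stats.epoch
```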

View File

@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import collections
import dataclasses
import time
from contextlib import contextmanager
from typing import Any, Callable, Dict
import torch
@contextmanager
def evaluating(net: torch.nn.Module):
"""Temporarily switch to evaluation mode."""
istrain = net.training
try:
net.eval()
yield net
finally:
if istrain:
net.train()
def try_to_cuda(t: Any) -> Any:
"""
Try to move the input variable `t` to a cuda device.
Args:
t: Input.
Returns:
t_cuda: `t` moved to a cuda device, if supported.
"""
try:
t = t.cuda()
except AttributeError:
pass
return t
def try_to_cpu(t: Any) -> Any:
"""
Try to move the input variable `t` to a cpu device.
Args:
t: Input.
Returns:
t_cpu: `t` moved to a cpu device, if supported.
"""
try:
t = t.cpu()
except AttributeError:
pass
return t
def dict_to_cuda(batch: Dict[Any, Any]) -> Dict[Any, Any]:
"""
Move all values in a dictionary to cuda if supported.
Args:
batch: Input dict.
Returns:
batch_cuda: `batch` moved to a cuda device, if supported.
"""
return {k: try_to_cuda(v) for k, v in batch.items()}
def dict_to_cpu(batch):
"""
Move all values in a dictionary to cpu if supported.
Args:
batch: Input dict.
Returns:
batch_cpu: `batch` moved to a cpu device, if supported.
"""
return {k: try_to_cpu(v) for k, v in batch.items()}
def dataclass_to_cuda_(obj):
"""
Move all contents of a dataclass to cuda inplace if supported.
Args:
obj: Input dataclass.
Returns:
obj: `obj` with its fields moved to a cuda device in place, if supported.
"""
for f in dataclasses.fields(obj):
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
return obj
def dataclass_to_cpu_(obj):
"""
Move all contents of a dataclass to cpu inplace if supported.
Args:
obj: Input dataclass.
Returns:
obj: `obj` with its fields moved to a cpu device in place, if supported.
"""
for f in dataclasses.fields(obj):
setattr(obj, f.name, try_to_cpu(getattr(obj, f.name)))
return obj
# TODO: test it
def cat_dataclass(batch, tensor_collator: Callable):
"""
Concatenate all fields of a list of dataclasses `batch` to a single
dataclass object using `tensor_collator`.
Args:
batch: Input list of dataclasses.
tensor_collator: The function used to concatenate tensor fields.
Returns:
concatenated_batch: All elements of `batch` concatenated to a single
dataclass object.
"""
elem = batch[0]
collated = {}
for f in dataclasses.fields(elem):
elem_f = getattr(elem, f.name)
if elem_f is None:
collated[f.name] = None
elif torch.is_tensor(elem_f):
collated[f.name] = tensor_collator([getattr(e, f.name) for e in batch])
elif dataclasses.is_dataclass(elem_f):
collated[f.name] = cat_dataclass(
[getattr(e, f.name) for e in batch], tensor_collator
)
elif isinstance(elem_f, collections.abc.Mapping):
collated[f.name] = {
k: tensor_collator([getattr(e, f.name)[k] for e in batch])
if elem_f[k] is not None
else None
for k in elem_f
}
else:
raise ValueError("Unsupported field type for concatenation")
return type(elem)(**collated)
class Timer:
"""
A simple class for timing execution.
Example:
```
with Timer():
print("This print statement is timed.")
```
"""
def __init__(self, name="timer", quiet=False):
self.name = name
self.quiet = quiet
def __enter__(self):
self.start = time.time()
return self
def __exit__(self, *args):
self.end = time.time()
self.interval = self.end - self.start
if not self.quiet:
print("%20s: %1.6f sec" % (self.name, self.interval))

View File

@ -0,0 +1,149 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import shutil
import tempfile
import warnings
from typing import Optional, Tuple, Union
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
matplotlib.use("Agg")
class VideoWriter:
"""
A class for exporting videos.
"""
def __init__(
self,
cache_dir: Optional[str] = None,
ffmpeg_bin: str = "ffmpeg",
out_path: str = "/tmp/video.mp4",
fps: int = 20,
output_format: str = "visdom",
rmdir_allowed: bool = False,
**kwargs,
):
"""
Args:
cache_dir: A directory for storing the video frames. If `None`,
a temporary directory will be used.
ffmpeg_bin: The path to an `ffmpeg` executable.
out_path: The path to the output video.
fps: The speed of the generated video in frames-per-second.
output_format: Format of the output video. Currently only `"visdom"`
is supported.
rmdir_allowed: If `True`, delete and re-create `cache_dir` in case
it is not empty.
"""
self.rmdir_allowed = rmdir_allowed
self.output_format = output_format
self.fps = fps
self.out_path = out_path
self.cache_dir = cache_dir
self.ffmpeg_bin = ffmpeg_bin
self.frames = []
self.regexp = "frame_%08d.png"
self.frame_num = 0
if self.cache_dir is not None:
self.tmp_dir = None
if os.path.isdir(self.cache_dir):
if rmdir_allowed:
shutil.rmtree(self.cache_dir)
else:
warnings.warn(
f"Warning: cache directory not empty ({self.cache_dir})."
)
os.makedirs(self.cache_dir, exist_ok=True)
else:
self.tmp_dir = tempfile.TemporaryDirectory()
self.cache_dir = self.tmp_dir.name
def write_frame(
self,
frame: Union[matplotlib.figure.Figure, np.ndarray, Image.Image, str],
resize: Optional[Union[float, Tuple[int, int]]] = None,
):
"""
Write a frame to the video.
Args:
frame: An object containing the frame image.
resize: Either a float defining the image rescaling factor
or a 2-tuple defining the size of the output image.
"""
outfile = os.path.join(self.cache_dir, self.regexp % self.frame_num)
if isinstance(frame, matplotlib.figure.Figure):
plt.savefig(outfile)
im = Image.open(outfile)
elif isinstance(frame, np.ndarray):
if frame.dtype in (np.float64, np.float32, float):
frame = (np.transpose(frame, (1, 2, 0)) * 255.0).astype(np.uint8)
im = Image.fromarray(frame)
elif isinstance(frame, Image.Image):
im = frame
elif isinstance(frame, str):
im = Image.open(frame).convert("RGB")
else:
raise ValueError("Cant convert type %s" % str(type(frame)))
if im is not None:
if resize is not None:
if isinstance(resize, float):
resize = [int(resize * s) for s in im.size]
else:
resize = im.size
# make sure size is divisible by 2
resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)])
im = im.resize(resize, Image.ANTIALIAS)
im.save(outfile)
self.frames.append(outfile)
self.frame_num += 1
def get_video(self, quiet: bool = True):
"""
Generate the video from the written frames.
Args:
quiet: If `True`, suppresses logging messages.
Returns:
video_path: The path to the generated video.
"""
regexp = os.path.join(self.cache_dir, self.regexp)
if self.output_format == "visdom": # works for ppt too
ffmcmd_ = (
"%s -r %d -i %s -vcodec h264 -f mp4 \
-y -crf 18 -b 2000k -pix_fmt yuv420p '%s'"
% (self.ffmpeg_bin, self.fps, regexp, self.out_path)
)
else:
raise ValueError("no such output type %s" % str(self.output_format))
if quiet:
ffmcmd_ += " > /dev/null 2>&1"
else:
print(ffmcmd_)
os.system(ffmcmd_)
return self.out_path
def __del__(self):
if self.tmp_dir is not None:
self.tmp_dir.cleanup()
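
A hedged example of producing a short clip from numpy frames; the output path, frame count and resolution are placeholders, and `ffmpeg` must be available on the PATH.

```python
import numpy as np

writer = VideoWriter(out_path="/tmp/example_video.mp4", fps=10)
for _ in range(20):
    frame = np.random.rand(3, 128, 128).astype(np.float32)  # CHW float frame in [0, 1]
    writer.write_frame(frame)
print(writer.get_video())  # prints /tmp/example_video.mp4 once ffmpeg has finished
```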

View File

@ -0,0 +1,172 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List
import torch
from visdom import Visdom
def get_visdom_env(cfg):
"""
Parse out visdom environment name from the input config.
Args:
cfg: The global config file.
Returns:
visdom_env: The name of the visdom environment.
"""
if len(cfg.visdom_env) == 0:
visdom_env = cfg.exp_dir.split("/")[-1]
else:
visdom_env = cfg.visdom_env
return visdom_env
# TODO: a proper singleton
_viz_singleton = None
def get_visdom_connection(
server: str = "http://localhost",
port: int = 8097,
) -> Visdom:
"""
Obtain a connection to a visdom server.
Args:
server: Server address.
port: Server port.
Returns:
connection: The connection object.
"""
global _viz_singleton
if _viz_singleton is None:
_viz_singleton = Visdom(server=server, port=port)
return _viz_singleton
def visualize_basics(
viz: Visdom,
preds: Dict[str, Any],
visdom_env_imgs: str,
title: str = "",
visualize_preds_keys: List[str] = [
"image_rgb",
"images_render",
"fg_probability",
"masks_render",
"depths_render",
"depth_map",
],
store_history: bool = False,
) -> None:
"""
Visualize basic outputs of a `GenericModel` to visdom.
Args:
viz: The visdom object.
preds: A dictionary containing `GenericModel` outputs.
visdom_env_imgs: Target visdom environment name.
title: The title of produced visdom window.
visualize_preds_keys: The list of keys of `preds` for visualization.
store_history: Store the history buffer in visdom windows.
"""
imout = {}
for k in visualize_preds_keys:
if k not in preds or preds[k] is None:
print(f"cant show {k}")
continue
v = preds[k].cpu().detach().clone()
if k.startswith("depth"):
# divide by 95th percentile
normfac = (
v.view(v.shape[0], -1)
.topk(k=int(0.05 * (v.numel() // v.shape[0])), dim=-1)
.values[:, -1]
)
v = v / normfac[:, None, None, None].clamp(1e-4)
if v.shape[1] == 1:
v = v.repeat(1, 3, 1, 1)
v = torch.nn.functional.interpolate(
v,
# pyre-fixme[6]: Expected `Optional[typing.List[float]]` for 2nd param
# but got `float`.
scale_factor=(
600.0
if (
"_eval" in visdom_env_imgs
and k in ("images_render", "depths_render")
)
else 200.0
)
/ v.shape[2],
mode="bilinear",
)
imout[k] = v
# TODO: handle errors on the outside
try:
imout = {"all": torch.cat(list(imout.values()), dim=2)}
except Exception:
print("can't concatenate the images for visualization!")
for k, v in imout.items():
viz.images(
v.clamp(0.0, 1.0),
win=k,
env=visdom_env_imgs,
opts={"title": title + "_" + k, "store_history": store_history},
)
def make_depth_image(
depths: torch.Tensor,
masks: torch.Tensor,
max_quantile: float = 0.98,
min_quantile: float = 0.02,
min_out_depth: float = 0.1,
max_out_depth: float = 0.9,
) -> torch.Tensor:
"""
Convert a batch of depth maps to a grayscale image.
Args:
depths: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps.
masks: A tensor of shape `(B, 1, H, W)` containing a batch of foreground masks.
max_quantile: The quantile of the input depth values which will
be mapped to `max_out_depth`.
min_quantile: The quantile of the input depth values which will
be mapped to `min_out_depth`.
min_out_depth: The minimal value in each depth map will be assigned this color.
max_out_depth: The maximal value in each depth map will be assigned this color.
Returns:
depth_image: A tensor of shape `(B, 1, H, W)` a batch of grayscale
depth images.
"""
normfacs = []
for d, m in zip(depths, masks):
ok = (d.view(-1) > 1e-6) * (m.view(-1) > 0.5)
if ok.sum() <= 1:
print("empty depth!")
normfacs.append(torch.zeros(2).type_as(depths))
continue
dok = d.view(-1)[ok].view(-1)
_maxk = max(int(round((1 - max_quantile) * (dok.numel()))), 1)
_mink = max(int(round(min_quantile * (dok.numel()))), 1)
normfac_max = dok.topk(k=_maxk, dim=-1).values[-1]
normfac_min = dok.topk(k=_mink, dim=-1, largest=False).values[-1]
normfacs.append(torch.stack([normfac_min, normfac_max]))
normfacs = torch.stack(normfacs)
_min, _max = (normfacs[:, 0].view(-1, 1, 1, 1), normfacs[:, 1].view(-1, 1, 1, 1))
depths = (depths - _min) / (_max - _min).clamp(1e-4)
depths = (
(depths * (max_out_depth - min_out_depth) + min_out_depth) * masks.float()
).clamp(0.0, 1.0)
return depths
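
A minimal sketch of `make_depth_image` on a synthetic depth/mask pair; the shapes and value ranges are arbitrary.

```python
import torch

depths = torch.rand(2, 1, 64, 64) * 5.0 + 1.0     # depths roughly in [1, 6]
masks = (torch.rand(2, 1, 64, 64) > 0.3).float()  # random foreground mask

depth_image = make_depth_image(depths, masks)     # grayscale, background set to 0
print(depth_image.shape)                          # (2, 1, 64, 64), values in [0, 1]
```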

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import os
import tempfile
import unittest
from pathlib import Path
from typing import Generator, Tuple
from zipfile import ZipFile
from iopath.common.file_io import PathManager
@contextlib.contextmanager
def get_skateboard_data(
avoid_manifold: bool = False, silence_logs: bool = False
) -> Generator[Tuple[str, PathManager], None, None]:
"""
Context manager for tests to access the Co3D dataset, at least
the first 5 skateboard sequences. Internally, we want this to exercise the
normal way to access the data directly from manifold, but on an RE
worker this is impossible, so we use a workaround.
Args:
avoid_manifold: Use the method used by RE workers even locally.
silence_logs: Whether to reduce log output from iopath library.
Yields:
dataset_root: (str) path to dataset root.
path_manager: path_manager to access it with.
"""
path_manager = PathManager()
if silence_logs:
logging.getLogger("iopath.fb.manifold").setLevel(logging.CRITICAL)
logging.getLogger("iopath.common.file_io").setLevel(logging.CRITICAL)
if not os.environ.get("FB_TEST", False):
if os.getenv("FAIR_ENV_CLUSTER", "") == "":
raise unittest.SkipTest("Unknown environment. Data not available.")
yield "/checkpoint/dnovotny/datasets/co3d/download_aws_22_02_18", path_manager
elif avoid_manifold or os.environ.get("INSIDE_RE_WORKER", False):
from libfb.py.parutil import get_file_path
par_path = "skateboard_first_5"
source = get_file_path(par_path)
assert Path(source).is_file()
with tempfile.TemporaryDirectory() as dest:
with ZipFile(source) as f:
f.extractall(dest)
yield os.path.join(dest, "extracted"), path_manager
else:
from iopath.fb.manifold import ManifoldPathHandler
path_manager.register_handler(ManifoldPathHandler())
yield "manifold://co3d/tree/extracted", path_manager
def provide_lpips_vgg():
"""
Ensure the weights files are available for lpips.LPIPS(net="vgg")
to be called. Specifically, torchvision's vgg16 weights.
"""
# In OSS, torchvision looks for vgg16 weights in
# https://download.pytorch.org/models/vgg16-397923af.pth
# Inside fbcode, this is replaced by asking iopath for
# manifold://torchvision/tree/models/vgg16-397923af.pth
# (the code for this replacement is in
# fbcode/pytorch/vision/fb/_internally_replaced_utils.py )
#
# iopath does this by looking for the file at the cache location
# and if it is not there getting it from manifold.
# (the code for this is in
# fbcode/fair_infra/data/iopath/iopath/fb/manifold.py )
#
# On the remote execution worker, manifold is inaccessible.
# We solve this by making the cached file available before iopath
# looks.
#
# By default the cache location is
# ~/.torch/iopath_cache/manifold_cache/tree/models/vgg16-397923af.pth
# But we can't write to the home directory on the RE worker.
# We define FVCORE_CACHE to change the cache location to
# iopath_cache/manifold_cache/tree/models/vgg16-397923af.pth
# (Without it, manifold caches in unstable temporary locations on RE.)
#
# The file we want has been copied from
# tree/models/vgg16-397923af.pth in the torchvision bucket
# to
# tree/testing/vgg16-397923af.pth in the co3d bucket
# and the TARGETS file copies it somewhere in the PAR which we
# recover with get_file_path.
# (It can't copy straight to a nested location, see
# https://fb.workplace.com/groups/askbuck/posts/2644615728920359/)
# Here we symlink it to the new cache location.
if os.environ.get("INSIDE_RE_WORKER") is not None:
from libfb.py.parutil import get_file_path
os.environ["FVCORE_CACHE"] = "iopath_cache"
par_path = "vgg_weights_for_lpips"
source = Path(get_file_path(par_path))
assert source.is_file()
dest = Path("iopath_cache/manifold_cache/tree/models")
if not dest.exists():
dest.mkdir(parents=True)
(dest / "vgg16-397923af.pth").symlink_to(source)

View File

@ -0,0 +1,122 @@
mask_images: true
mask_depths: true
render_image_width: 400
render_image_height: 400
mask_threshold: 0.5
output_rasterized_mc: false
bg_color:
- 0.0
- 0.0
- 0.0
view_pool: false
num_passes: 1
chunk_size_grid: 4096
render_features_dimensions: 3
tqdm_trigger_threshold: 16
n_train_target_views: 1
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
renderer_class_type: LSTMRenderer
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
implicit_function_class_type: IdrFeatureField
loss_weights:
loss_rgb_mse: 1.0
loss_prev_stage_rgb_mse: 1.0
loss_mask_bce: 0.0
loss_prev_stage_mask_bce: 0.0
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
sequence_autodecoder_args:
encoding_dim: 0
n_instances: 0
init_scale: 1.0
ignore_input: false
raysampler_args:
image_width: 400
image_height: 400
scene_center:
- 0.0
- 0.0
- 0.0
scene_extent: 0.0
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
min_depth: 0.1
max_depth: 8.0
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
renderer_LSTMRenderer_args:
num_raymarch_steps: 10
init_depth: 17.0
init_depth_noise_std: 0.0005
hidden_size: 16
n_feature_channels: 256
verbose: false
image_feature_extractor_args:
name: resnet34
pretrained: true
stages:
- 1
- 2
- 3
- 4
normalize_image: true
image_rescale: 0.16
first_max_pool: true
proj_dim: 32
l2_norm: true
add_masks: true
add_images: true
global_average_pool: false
feature_rescale: 1.0
view_sampler_args:
masked_sampling: false
sampling_mode: bilinear
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
implicit_function_IdrFeatureField_args:
feature_vector_size: 3
d_in: 3
d_out: 1
dims:
- 512
- 512
- 512
- 512
- 512
- 512
- 512
- 512
geometric_init: true
bias: 1.0
skip_in: []
weight_norm: true
n_harmonic_functions_xyz: 0
pooled_feature_dim: 0
encoding_dim: 0
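
The block above is plain YAML that OmegaConf can consume directly. A hedged sketch of loading it and overriding a couple of fields programmatically; the file name is a placeholder, and in practice the trainer composes such configs through its configuration system:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("generic_model_lstm.yaml")  # placeholder file name
cfg.render_image_width = 200
cfg.raysampler_args.n_rays_per_image_sampled_from_mask = 2048
print(OmegaConf.to_yaml(cfg))
```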

View File

@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from collections import defaultdict
from dataclasses import dataclass
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@dataclass
class MockFrameAnnotation:
frame_number: int
frame_timestamp: float = 0.0
class MockDataset:
def __init__(self, num_seq, max_frame_gap=1):
"""
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
"""
self.seq_annots = {f"seq_{i}": None for i in range(num_seq)}
self.seq_to_idx = {
f"seq_{i}": list(range(i * 10, i * 10 + 10)) for i in range(num_seq)
}
# frame numbers within sequence: [0, ..., 4, n, ..., n+4]
# where n - 4 == max_frame_gap
frame_nos = list(range(5)) + list(range(4 + max_frame_gap, 9 + max_frame_gap))
self.frame_annots = [
{"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
]
def get_frame_numbers_and_timestamps(self, idxs):
out = []
for idx in idxs:
frame_annotation = self.frame_annots[idx]["frame_annotation"]
out.append(
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
)
return out
class TestSceneBatchSampler(unittest.TestCase):
def setUp(self):
self.dataset_overfit = MockDataset(1)
def test_overfit(self):
num_batches = 3
batch_size = 10
sampler = SceneBatchSampler(
self.dataset_overfit,
batch_size=batch_size,
num_batches=num_batches,
images_per_seq_options=[10], # will try to sample batch_size anyway
)
self.assertEqual(len(sampler), num_batches)
it = iter(sampler)
for _ in range(num_batches):
batch = next(it)
self.assertIsNotNone(batch)
self.assertEqual(len(batch), batch_size) # true for our examples
self.assertTrue(all(idx // 10 == 0 for idx in batch))
with self.assertRaises(StopIteration):
batch = next(it)
def test_multiseq(self):
for ips_options in [[10], [2], [3], [2, 3, 4]]:
for sample_consecutive_frames in [True, False]:
for consecutive_frames_max_gap in [0, 1, 3]:
self._test_multiseq_flavour(
ips_options,
sample_consecutive_frames,
consecutive_frames_max_gap,
)
def test_multiseq_gaps(self):
num_batches = 16
batch_size = 10
dataset_multiseq = MockDataset(5, max_frame_gap=3)
for ips_options in [[10], [2], [3], [2, 3, 4]]:
debug_info = f" Images per sequence: {ips_options}."
sampler = SceneBatchSampler(
dataset_multiseq,
batch_size=batch_size,
num_batches=num_batches,
images_per_seq_options=ips_options,
sample_consecutive_frames=True,
consecutive_frames_max_gap=1,
)
self.assertEqual(len(sampler), num_batches, msg=debug_info)
it = iter(sampler)
for _ in range(num_batches):
batch = next(it)
self.assertIsNotNone(batch, "batch is None in" + debug_info)
if max(ips_options) > 5:
# true for our examples
self.assertEqual(len(batch), 5, msg=debug_info)
else:
# true for our examples
self.assertEqual(len(batch), batch_size, msg=debug_info)
self._check_frames_are_consecutive(
batch, dataset_multiseq.frame_annots, debug_info
)
def _test_multiseq_flavour(
self,
ips_options,
sample_consecutive_frames,
consecutive_frames_max_gap,
num_batches=16,
batch_size=10,
):
debug_info = (
f" Images per sequence: {ips_options}, "
f"sample_consecutive_frames: {sample_consecutive_frames}, "
f"consecutive_frames_max_gap: {consecutive_frames_max_gap}, "
)
# in this test, either consecutive_frames_max_gap == max_frame_gap,
# or consecutive_frames_max_gap == 0, so segments consist of full sequences
frame_gap = consecutive_frames_max_gap if consecutive_frames_max_gap > 0 else 3
dataset_multiseq = MockDataset(5, max_frame_gap=frame_gap)
sampler = SceneBatchSampler(
dataset_multiseq,
batch_size=batch_size,
num_batches=num_batches,
images_per_seq_options=ips_options,
sample_consecutive_frames=sample_consecutive_frames,
consecutive_frames_max_gap=consecutive_frames_max_gap,
)
self.assertEqual(len(sampler), num_batches, msg=debug_info)
it = iter(sampler)
typical_counts = set()
for _ in range(num_batches):
batch = next(it)
self.assertIsNotNone(batch, "batch is None in" + debug_info)
# true for our examples
self.assertEqual(len(batch), batch_size, msg=debug_info)
# find distribution over sequences
counts = _count_by_quotient(batch, 10)
freqs = _count_by_quotient(counts.values(), 1)
self.assertLessEqual(
len(freqs),
2,
msg="We should have maximum of 2 different "
"frequences of sequences in the batch." + debug_info,
)
if len(freqs) == 2:
most_seq_count = max(*freqs.keys())
last_seq = min(*freqs.keys())
self.assertEqual(
freqs[last_seq],
1,
msg="Only one odd sequence allowed." + debug_info,
)
else:
self.assertEqual(len(freqs), 1)
most_seq_count = next(iter(freqs))
self.assertIn(most_seq_count, ips_options)
typical_counts.add(most_seq_count)
if sample_consecutive_frames:
self._check_frames_are_consecutive(
batch,
dataset_multiseq.frame_annots,
debug_info,
max_gap=consecutive_frames_max_gap,
)
self.assertTrue(
all(i in typical_counts for i in ips_options),
"Some of the frequency options did not occur among "
f"the {num_batches} batches (could be just bad luck)." + debug_info,
)
with self.assertRaises(StopIteration):
batch = next(it)
def _check_frames_are_consecutive(self, batch, annots, debug_info, max_gap=1):
# make sure that sampled frames are consecutive
for i in range(len(batch) - 1):
curr_idx, next_idx = batch[i : i + 2]
if curr_idx // 10 == next_idx // 10: # same sequence
if max_gap > 0:
curr_idx, next_idx = [
annots[idx]["frame_annotation"].frame_number
for idx in (curr_idx, next_idx)
]
gap = max_gap
else:
gap = 1 # we'll check that raw dataset indices are consecutive
self.assertLessEqual(next_idx - curr_idx, gap, msg=debug_info)
def _count_by_quotient(indices, divisor):
counter = defaultdict(int)
for i in indices:
counter[i // divisor] += 1
return counter

View File

@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from math import pi
import torch
from pytorch3d.implicitron.tools.circle_fitting import (
_signed_area,
fit_circle_in_2d,
fit_circle_in_3d,
)
from pytorch3d.transforms import random_rotation
if os.environ.get("FB_TEST", False):
from common_testing import TestCaseMixin
else:
from tests.common_testing import TestCaseMixin
class TestCircleFitting(TestCaseMixin, unittest.TestCase):
def setUp(self):
torch.manual_seed(42)
def _assertParallel(self, a, b, **kwargs):
"""
Given a and b of shape (..., 3) each containing 3D vectors,
assert that corresponding vectors are parallel. A sign flip is ok.
"""
self.assertClose(torch.cross(a, b, dim=-1), torch.zeros_like(a), **kwargs)
def test_simple_3d(self):
device = torch.device("cuda:0")
for _ in range(7):
radius = 10 * torch.rand(1, device=device)[0]
center = 10 * torch.rand(3, device=device)
rot = random_rotation(device=device)
offset = torch.rand(3, device=device)
up = torch.rand(3, device=device)
self._simple_3d_test(radius, center, rot, offset, up)
def _simple_3d_test(self, radius, center, rot, offset, up):
# angles are increasing so the points move in a well defined direction.
angles = torch.cumsum(torch.rand(17, device=rot.device), dim=0)
many = torch.stack(
[torch.cos(angles), torch.sin(angles), torch.zeros_like(angles)], dim=1
)
source_points = (many * radius) @ rot + center[None]
# case with no generation
result = fit_circle_in_3d(source_points)
self.assertClose(result.radius, radius)
self.assertClose(result.center, center)
self._assertParallel(result.normal, rot[2], atol=1e-5)
self.assertEqual(result.generated_points.shape, (0, 3))
# Generate 5 points around the circle
n_new_points = 5
result2 = fit_circle_in_3d(source_points, n_points=n_new_points)
self.assertClose(result2.radius, radius)
self.assertClose(result2.center, center)
self.assertClose(result2.normal, result.normal)
self.assertEqual(result2.generated_points.shape, (5, 3))
observed_points = result2.generated_points
self.assertClose(observed_points[0], observed_points[4], atol=1e-4)
self.assertClose(observed_points[0], source_points[0], atol=1e-5)
observed_normal = torch.cross(
observed_points[0] - observed_points[2],
observed_points[1] - observed_points[3],
dim=-1,
)
self._assertParallel(observed_normal, result.normal, atol=1e-4)
diameters = observed_points[:2] - observed_points[2:4]
self.assertClose(
torch.norm(diameters, dim=1), diameters.new_full((2,), 2 * radius)
)
# Regenerate the input points
result3 = fit_circle_in_3d(source_points, angles=angles - angles[0])
self.assertClose(result3.radius, radius)
self.assertClose(result3.center, center)
self.assertClose(result3.normal, result.normal)
self.assertClose(result3.generated_points, source_points, atol=1e-5)
# Test with offset
result4 = fit_circle_in_3d(
source_points, angles=angles - angles[0], offset=offset, up=up
)
self.assertClose(result4.radius, radius)
self.assertClose(result4.center, center)
self.assertClose(result4.normal, result.normal)
observed_offsets = result4.generated_points - source_points
# observed_offset is constant
self.assertClose(
observed_offsets.min(0).values, observed_offsets.max(0).values, atol=1e-5
)
# observed_offset has the right length
self.assertClose(observed_offsets[0].norm(), offset.norm())
self.assertClose(result.normal.norm(), torch.ones(()))
# component of observed_offset along normal
component = torch.dot(observed_offsets[0], result.normal)
self.assertClose(component.abs(), offset[2].abs(), atol=1e-5)
agree_normal = torch.dot(result.normal, up) > 0
agree_signs = component * offset[2] > 0
self.assertEqual(agree_normal, agree_signs)
def test_simple_2d(self):
radius = 7.0
center = torch.tensor([9, 2.5])
angles = torch.cumsum(torch.rand(17), dim=0)
many = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
source_points = (many * radius) + center[None]
result = fit_circle_in_2d(source_points)
self.assertClose(result.radius, torch.tensor(radius))
self.assertClose(result.center, center)
self.assertEqual(result.generated_points.shape, (0, 2))
# Generate 5 points around the circle
n_new_points = 5
result2 = fit_circle_in_2d(source_points, n_points=n_new_points)
self.assertClose(result2.radius, torch.tensor(radius))
self.assertClose(result2.center, center)
self.assertEqual(result2.generated_points.shape, (5, 2))
observed_points = result2.generated_points
self.assertClose(observed_points[0], observed_points[4])
self.assertClose(observed_points[0], source_points[0], atol=1e-5)
diameters = observed_points[:2] - observed_points[2:4]
self.assertClose(torch.norm(diameters, dim=1), torch.full((2,), 2 * radius))
# Regenerate the input points
result3 = fit_circle_in_2d(source_points, angles=angles - angles[0])
self.assertClose(result3.radius, torch.tensor(radius))
self.assertClose(result3.center, center)
self.assertClose(result3.generated_points, source_points, atol=1e-5)
def test_minimum_inputs(self):
fit_circle_in_3d(torch.rand(3, 3), n_points=10)
with self.assertRaisesRegex(
ValueError, "2 points are not enough to determine a circle"
):
fit_circle_in_3d(torch.rand(2, 3))
def test_signed_area(self):
n_points = 1001
angles = torch.linspace(0, 2 * pi, n_points)
radius = 0.85
center = torch.rand(2)
circle = center + radius * torch.stack(
[torch.cos(angles), torch.sin(angles)], dim=1
)
circle_area = torch.tensor(pi * radius * radius)
self.assertClose(_signed_area(circle), circle_area)
# clockwise is negative
self.assertClose(_signed_area(circle.flip(0)), -circle_area)
# Semicircles
self.assertClose(_signed_area(circle[: (n_points + 1) // 2]), circle_area / 2)
self.assertClose(_signed_area(circle[n_points // 2 :]), circle_area / 2)
# A straight line bounds no area
self.assertClose(_signed_area(torch.rand(2, 2)), torch.tensor(0.0))
# Letter 'L' written anticlockwise.
L_shape = [[0, 1], [0, 0], [1, 0]]
# Triangle area is 0.5 * b * h.
self.assertClose(_signed_area(torch.tensor(L_shape)), torch.tensor(0.5))
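
A compact usage sketch mirroring the 2D test above: fit a circle to clean points sampled from a known circle and generate 5 evenly spread points on the fit.

```python
import torch
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_2d

angles = torch.linspace(0.0, 5.0, 17)
points = torch.tensor([9.0, 2.5]) + 7.0 * torch.stack(
    [torch.cos(angles), torch.sin(angles)], dim=1
)
result = fit_circle_in_2d(points, n_points=5)
print(result.radius, result.center, result.generated_points.shape)
# ~7.0, ~[9.0, 2.5], torch.Size([5, 2])
```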

View File

@ -0,0 +1,610 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import textwrap
import unittest
from dataclasses import dataclass, field, is_dataclass
from enum import Enum
from typing import List, Optional, Tuple
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
Configurable,
ReplaceableBase,
_is_actually_dataclass,
_Registry,
expand_args_fields,
get_default_args,
get_default_args_field,
registry,
remove_unused_components,
run_auto_creation,
)
@dataclass
class Animal(ReplaceableBase):
pass
class Fruit(ReplaceableBase):
pass
@registry.register
class Banana(Fruit):
pips: int
spots: int
bananame: str
@registry.register
class Pear(Fruit):
n_pips: int = 13
class Pineapple(Fruit):
pass
@registry.register
class Orange(Fruit):
pass
@registry.register
class Kiwi(Fruit):
pass
@registry.register
class LargePear(Pear):
pass
class MainTest(Configurable):
the_fruit: Fruit
n_ids: int
n_reps: int = 8
the_second_fruit: Fruit
def create_the_second_fruit(self):
expand_args_fields(Pineapple)
self.the_second_fruit = Pineapple()
def __post_init__(self):
run_auto_creation(self)
class TestConfig(unittest.TestCase):
def test_is_actually_dataclass(self):
@dataclass
class A:
pass
self.assertTrue(_is_actually_dataclass(A))
self.assertTrue(is_dataclass(A))
class B(A):
a: int
self.assertFalse(_is_actually_dataclass(B))
self.assertTrue(is_dataclass(B))
def test_simple_replacement(self):
struct = get_default_args(MainTest)
struct.n_ids = 9780
struct.the_fruit_Pear_args.n_pips = 3
struct.the_fruit_class_type = "Pear"
struct.the_second_fruit_class_type = "Pear"
main = MainTest(**struct)
self.assertIsInstance(main.the_fruit, Pear)
self.assertEqual(main.n_reps, 8)
self.assertEqual(main.n_ids, 9780)
self.assertEqual(main.the_fruit.n_pips, 3)
self.assertIsInstance(main.the_second_fruit, Pineapple)
struct2 = get_default_args(MainTest)
self.assertEqual(struct2.the_fruit_Pear_args.n_pips, 13)
self.assertEqual(
MainTest._creation_functions,
("create_the_fruit", "create_the_second_fruit"),
)
def test_detect_bases(self):
# testing the _base_class_from_class function
self.assertIsNone(_Registry._base_class_from_class(ReplaceableBase))
self.assertIsNone(_Registry._base_class_from_class(MainTest))
self.assertIs(_Registry._base_class_from_class(Fruit), Fruit)
self.assertIs(_Registry._base_class_from_class(Pear), Fruit)
class PricklyPear(Pear):
pass
self.assertIs(_Registry._base_class_from_class(PricklyPear), Fruit)
def test_registry_entries(self):
self.assertIs(registry.get(Fruit, "Banana"), Banana)
with self.assertRaisesRegex(ValueError, "Banana has not been registered."):
registry.get(Animal, "Banana")
with self.assertRaisesRegex(ValueError, "PricklyPear has not been registered."):
registry.get(Fruit, "PricklyPear")
self.assertIs(registry.get(Pear, "Pear"), Pear)
self.assertIs(registry.get(Pear, "LargePear"), LargePear)
with self.assertRaisesRegex(ValueError, "Banana resolves to"):
registry.get(Pear, "Banana")
all_fruit = set(registry.get_all(Fruit))
self.assertIn(Banana, all_fruit)
self.assertIn(Pear, all_fruit)
self.assertIn(LargePear, all_fruit)
self.assertEqual(set(registry.get_all(Pear)), {LargePear})
@registry.register
class Apple(Fruit):
pass
@registry.register
class CrabApple(Apple):
pass
self.assertEqual(set(registry.get_all(Apple)), {CrabApple})
self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)
with self.assertRaisesRegex(ValueError, "Cannot tell what it is."):
@registry.register
class NotAFruit:
pass
def test_recursion(self):
class Shape(ReplaceableBase):
pass
@registry.register
class Triangle(Shape):
a: float = 5.0
@registry.register
class Square(Shape):
a: float = 3.0
@registry.register
class LargeShape(Shape):
inner: Shape
def __post_init__(self):
run_auto_creation(self)
class ShapeContainer(Configurable):
shape: Shape
container = ShapeContainer(**get_default_args(ShapeContainer))
# This is because ShapeContainer is missing __post_init__
with self.assertRaises(AttributeError):
container.shape
class ShapeContainer2(Configurable):
x: Shape
x_class_type: str = "LargeShape"
def __post_init__(self):
self.x_LargeShape_args.inner_class_type = "Triangle"
run_auto_creation(self)
container2_args = get_default_args(ShapeContainer2)
container2_args.x_LargeShape_args.inner_Triangle_args.a += 10
self.assertIn("inner_Square_args", container2_args.x_LargeShape_args)
# We do not perform expansion that would result in an infinite recursion,
# so this member is not present.
self.assertNotIn("inner_LargeShape_args", container2_args.x_LargeShape_args)
container2_args.x_LargeShape_args.inner_Square_args.a += 100
container2 = ShapeContainer2(**container2_args)
self.assertIsInstance(container2.x, LargeShape)
self.assertIsInstance(container2.x.inner, Triangle)
self.assertEqual(container2.x.inner.a, 15.0)
def test_simpleclass_member(self):
# Members which are not dataclasses are
# tolerated. But it would be nice to be able to
# configure them.
class Foo:
def __init__(self, a=1, b=2):
self.a, self.b = a, b
@dataclass()
class Bar:
aa: int = 9
bb: int = 9
class Container(Configurable):
bar: Bar = Bar()
# TODO make this work?
# foo: Foo = Foo()
fruit: Fruit
fruit_class_type: str = "Orange"
def __post_init__(self):
run_auto_creation(self)
self.assertEqual(get_default_args(Foo), {"a": 1, "b": 2})
container_args = get_default_args(Container)
container = Container(**container_args)
self.assertIsInstance(container.fruit, Orange)
# self.assertIsInstance(container.bar, Bar)
container_defaulted = Container()
container_defaulted.fruit_Pear_args.n_pips += 4
container_args2 = get_default_args(Container)
container = Container(**container_args2)
self.assertEqual(container.fruit_Pear_args.n_pips, 13)
def test_inheritance(self):
class FruitBowl(ReplaceableBase):
main_fruit: Fruit
main_fruit_class_type: str = "Orange"
def __post_init__(self):
raise ValueError("This doesn't get called")
class LargeFruitBowl(FruitBowl):
extra_fruit: Fruit
extra_fruit_class_type: str = "Kiwi"
def __post_init__(self):
run_auto_creation(self)
large_args = get_default_args(LargeFruitBowl)
self.assertNotIn("extra_fruit", large_args)
self.assertNotIn("main_fruit", large_args)
large = LargeFruitBowl(**large_args)
self.assertIsInstance(large.main_fruit, Orange)
self.assertIsInstance(large.extra_fruit, Kiwi)
def test_inheritance2(self):
# This is a case where a class could contain an instance
# of a subclass, which is ignored.
class Parent(ReplaceableBase):
pass
class Main(Configurable):
parent: Parent
# Note - no __post__init__
@registry.register
class Derived(Parent, Main):
pass
args = get_default_args(Main)
# Derived has been ignored in processing Main.
self.assertCountEqual(args.keys(), ["parent_class_type"])
main = Main(**args)
with self.assertRaisesRegex(ValueError, "UNDEFAULTED has not been registered."):
run_auto_creation(main)
main.parent_class_type = "Derived"
# Illustrates that a dict works fine instead of a DictConfig.
main.parent_Derived_args = {}
with self.assertRaises(AttributeError):
main.parent
run_auto_creation(main)
self.assertIsInstance(main.parent, Derived)
def test_redefine(self):
class FruitBowl(ReplaceableBase):
main_fruit: Fruit
main_fruit_class_type: str = "Grape"
def __post_init__(self):
run_auto_creation(self)
@registry.register
@dataclass
class Grape(Fruit):
large: bool = False
def get_color(self):
return "red"
def __post_init__(self):
raise ValueError("This doesn't get called")
bowl_args = get_default_args(FruitBowl)
@registry.register
@dataclass
class Grape(Fruit): # noqa: F811
large: bool = True
def get_color(self):
return "green"
with self.assertWarnsRegex(
UserWarning, "New implementation of Grape is being chosen."
):
bowl = FruitBowl(**bowl_args)
self.assertIsInstance(bowl.main_fruit, Grape)
# Redefining the same class won't change the defaults because they are already encoded in bowl_args.
self.assertEqual(bowl.main_fruit.large, False)
# But the override worked.
self.assertEqual(bowl.main_fruit.get_color(), "green")
# 2. Try redefining without the dataclass decorator.
# This relies on the fact that default creation processes the class
# (otherwise the error messages are incomprehensible).
@registry.register
class Grape(Fruit): # noqa: F811
large: bool = True
with self.assertWarnsRegex(
UserWarning, "New implementation of Grape is being chosen."
):
bowl = FruitBowl(**bowl_args)
# 3. Adding a new class doesn't get picked up, because the first
# get_default_args call has frozen FruitBowl. This is intrinsic to
# the way dataclass and expand_args_fields work in-place but
# expand_args_fields is not pure - it depends on the registry.
@registry.register
class Fig(Fruit):
pass
bowl_args2 = get_default_args(FruitBowl)
self.assertIn("main_fruit_Grape_args", bowl_args2)
self.assertNotIn("main_fruit_Fig_args", bowl_args2)
# TODO Is it possible to make this work?
# bowl_args2["main_fruit_Fig_args"] = get_default_args(Fig)
# bowl_args2.main_fruit_class_type = "Fig"
# bowl2 = FruitBowl(**bowl_args2) <= unexpected argument
# Note that it is possible to use Fig if you can set
# bowl2.main_fruit_Fig_args explicitly (not in bowl_args2)
# before run_auto_creation happens. See test_inheritance2
# for an example.
def test_no_replacement(self):
# Test of Configurables without ReplaceableBase
class A(Configurable):
n: int = 9
class B(Configurable):
a: A
def __post_init__(self):
run_auto_creation(self)
class C(Configurable):
b: B
def __post_init__(self):
run_auto_creation(self)
c_args = get_default_args(C)
c = C(**c_args)
self.assertIsInstance(c.b.a, A)
self.assertEqual(c.b.a.n, 9)
def test_doc(self):
# The case in the docstring.
class A(ReplaceableBase):
k: int = 1
@registry.register
class A1(A):
m: int = 3
@registry.register
class A2(A):
n: str = "2"
class B(Configurable):
a: A
a_class_type: str = "A2"
def __post_init__(self):
run_auto_creation(self)
b_args = get_default_args(B)
self.assertNotIn("a", b_args)
b = B(**b_args)
self.assertEqual(b.a.n, "2")
def test_raw_types(self):
@dataclass
class MyDataclass:
int_field: int = 0
none_field: Optional[int] = None
float_field: float = 9.3
bool_field: bool = True
tuple_field: tuple = (3, True, "j")
class SimpleClass:
def __init__(self, tuple_member_=(3, 4)):
self.tuple_member = tuple_member_
def get_tuple(self):
return self.tuple_member
def f(*, a: int = 3, b: str = "kj"):
self.assertEqual(a, 3)
self.assertEqual(b, "kj")
class C(Configurable):
simple: DictConfig = get_default_args_field(SimpleClass)
# simple2: SimpleClass2 = SimpleClass2()
mydata: DictConfig = get_default_args_field(MyDataclass)
a_tuple: Tuple[float] = (4.0, 3.0)
f_args: DictConfig = get_default_args_field(f)
args = get_default_args(C)
c = C(**args)
self.assertCountEqual(args.keys(), ["simple", "mydata", "a_tuple", "f_args"])
mydata = MyDataclass(**c.mydata)
simple = SimpleClass(**c.simple)
# OmegaConf converts tuples to ListConfigs (which act like lists).
self.assertEqual(simple.get_tuple(), [3, 4])
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
self.assertEqual(c.a_tuple, [4.0, 3.0])
self.assertTrue(isinstance(c.a_tuple, ListConfig))
self.assertEqual(mydata.tuple_field, (3, True, "j"))
self.assertTrue(isinstance(mydata.tuple_field, ListConfig))
f(**c.f_args)
def test_irrelevant_bases(self):
class NotADataclass:
# Like torch.nn.Module, this class contains annotations
# but is not designed to be dataclass'd.
# This test ensures that such classes, when inherited from,
# are not accidentally processed by expand_args_fields.
a: int = 9
b: int
class LeftConfigured(Configurable, NotADataclass):
left: int = 1
class RightConfigured(NotADataclass, Configurable):
right: int = 2
class Outer(Configurable):
left: LeftConfigured
right: RightConfigured
def __post_init__(self):
run_auto_creation(self)
outer = Outer(**get_default_args(Outer))
self.assertEqual(outer.left.left, 1)
self.assertEqual(outer.right.right, 2)
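# dataclass() itself should reject NotADataclass: the annotated field `b`
# has no default but follows `a`, which has one, so the standard
# "non-default argument follows default argument" TypeError is expected.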
with self.assertRaisesRegex(TypeError, "non-default argument"):
dataclass(NotADataclass)
def test_unprocessed(self):
# Test the behavior of Configurable classes which need processing in __new__.
class Unprocessed(Configurable):
a: int = 9
class UnprocessedReplaceable(ReplaceableBase):
a: int = 1
with self.assertWarnsRegex(UserWarning, "must be processed"):
Unprocessed()
with self.assertWarnsRegex(UserWarning, "must be processed"):
UnprocessedReplaceable()
def test_enum(self):
# Test that enum values are kept, i.e. that OmegaConf's runtime checks
# are in use.
class A(Enum):
B1 = "b1"
B2 = "b2"
class C(Configurable):
a: A = A.B1
base = get_default_args(C)
replaced = OmegaConf.merge(base, {"a": "B2"})
self.assertEqual(replaced.a, A.B2)
with self.assertRaises(ValidationError):
# You can't use a value which is not one of the
# choices, even if it is the str representation
# of one of the choices.
OmegaConf.merge(base, {"a": "b2"})
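# Round-tripping through YAML works because OmegaConf is expected to
# serialize enum fields by member name ("B1"), which merge resolves back
# to the enum member, unlike the raw value string above.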
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
self.assertEqual(remerged.a, A.B1)
def test_remove_unused_components(self):
struct = get_default_args(MainTest)
struct.n_ids = 32
struct.the_fruit_class_type = "Pear"
struct.the_second_fruit_class_type = "Banana"
remove_unused_components(struct)
expected_keys = [
"n_ids",
"n_reps",
"the_fruit_Pear_args",
"the_fruit_class_type",
"the_second_fruit_Banana_args",
"the_second_fruit_class_type",
]
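# In the expected YAML below, "???" is OmegaConf's marker for a mandatory
# value that has no default and must be provided by the user.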
expected_yaml = textwrap.dedent(
"""\
n_ids: 32
n_reps: 8
the_fruit_class_type: Pear
the_fruit_Pear_args:
n_pips: 13
the_second_fruit_class_type: Banana
the_second_fruit_Banana_args:
pips: ???
spots: ???
bananame: ???
"""
)
self.assertEqual(sorted(struct.keys()), expected_keys)
# Check that struct is what we expect
expected = OmegaConf.create(expected_yaml)
self.assertEqual(struct, expected)
# Check that we get what we expect when writing to yaml.
self.assertEqual(OmegaConf.to_yaml(struct, sort_keys=False), expected_yaml)
main = MainTest(**struct)
instance_data = OmegaConf.structured(main)
remove_unused_components(instance_data)
self.assertEqual(sorted(instance_data.keys()), expected_keys)
self.assertEqual(instance_data, expected)
@dataclass(eq=False)
class MockDataclass:
field_no_default: int
field_primitive_type: int = 42
field_reference_type: List[int] = field(default_factory=lambda: [])
class MockClassWithInit: # noqa: B903
def __init__(
self,
field_no_default: int,
field_primitive_type: int = 42,
field_reference_type: List[int] = [], # noqa: B006
):
self.field_no_default = field_no_default
self.field_primitive_type = field_primitive_type
self.field_reference_type = field_reference_type
class TestRawClasses(unittest.TestCase):
def test_get_default_args(self):
for cls in [MockDataclass, MockClassWithInit]:
dataclass_defaults = get_default_args(cls)
inst = cls(field_no_default=0)
dataclass_defaults.field_no_default = 0
for name, val in dataclass_defaults.items():
self.assertTrue(hasattr(inst, name))
self.assertEqual(val, getattr(inst, name))
def test_get_default_args_readonly(self):
for cls in [MockDataclass, MockClassWithInit]:
dataclass_defaults = get_default_args(cls)
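# get_default_args should return an independent copy of any
# reference-typed defaults, so mutating the config here must not leak
# into instances constructed later.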
dataclass_defaults["field_reference_type"].append(13)
inst = cls(field_no_default=0)
self.assertEqual(inst.field_reference_type, [])

View File

@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from omegaconf import OmegaConf
from pytorch3d.implicitron.models.autodecoder import Autodecoder
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
IdrFeatureField,
)
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
NeuralRadianceFieldImplicitFunction,
)
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer
from pytorch3d.implicitron.models.renderer.multipass_ea import (
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.view_pooling.feature_aggregation import (
AngleWeightedIdentityFeatureAggregator,
AngleWeightedReductionFeatureAggregator,
)
from pytorch3d.implicitron.tools.config import (
get_default_args,
remove_unused_components,
)
if os.environ.get("FB_TEST", False):
from common_testing import get_tests_dir
else:
from tests.common_testing import get_tests_dir
DATA_DIR = get_tests_dir() / "implicitron/data"
DEBUG: bool = False
# Tests the use of the config system in implicitron
class TestGenericModel(unittest.TestCase):
def setUp(self):
self.maxDiff = None
def test_create_gm(self):
args = get_default_args(GenericModel)
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
self.assertIsInstance(
gm.feature_aggregator, AngleWeightedReductionFeatureAggregator
)
self.assertIsInstance(
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertFalse(hasattr(gm, "implicit_function"))
self.assertFalse(hasattr(gm, "image_feature_extractor"))
def test_create_gm_overrides(self):
args = get_default_args(GenericModel)
args.feature_aggregator_class_type = "AngleWeightedIdentityFeatureAggregator"
args.implicit_function_class_type = "IdrFeatureField"
args.renderer_class_type = "LSTMRenderer"
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, LSTMRenderer)
self.assertIsInstance(
gm.feature_aggregator, AngleWeightedIdentityFeatureAggregator
)
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertFalse(hasattr(gm, "implicit_function"))
instance_args = OmegaConf.structured(gm)
remove_unused_components(instance_args)
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
if DEBUG:
(DATA_DIR / "overrides.yaml_").write_text(yaml)
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())

View File

@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import copy
import os
import unittest
import torch
import torchvision
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
from pytorch3d.vis.plotly_vis import plot_scene
from visdom import Visdom
if os.environ.get("FB_TEST", False):
from .common_resources import get_skateboard_data
else:
from common_resources import get_skateboard_data
class TestDatasetVisualize(unittest.TestCase):
def setUp(self):
if os.environ.get("INSIDE_RE_WORKER") is not None:
raise unittest.SkipTest("Visdom not available")
category = "skateboard"
stack = contextlib.ExitStack()
dataset_root, path_manager = stack.enter_context(get_skateboard_data())
self.addCleanup(stack.close)
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
self.image_size = 256
self.datasets = {
"simple": ImplicitronDataset(
frame_annotations_file=frame_file,
sequence_annotations_file=sequence_file,
dataset_root=dataset_root,
image_height=self.image_size,
image_width=self.image_size,
box_crop=True,
load_point_clouds=True,
path_manager=path_manager,
),
"nonsquare": ImplicitronDataset(
frame_annotations_file=frame_file,
sequence_annotations_file=sequence_file,
dataset_root=dataset_root,
image_height=self.image_size,
image_width=self.image_size // 2,
box_crop=True,
load_point_clouds=True,
path_manager=path_manager,
),
"nocrop": ImplicitronDataset(
frame_annotations_file=frame_file,
sequence_annotations_file=sequence_file,
dataset_root=dataset_root,
image_height=self.image_size,
image_width=self.image_size // 2,
box_crop=False,
load_point_clouds=True,
path_manager=path_manager,
),
}
self.datasets.update(
{
k + "_newndc": _change_annotations_to_new_ndc(dataset)
for k, dataset in self.datasets.items()
}
)
self.visdom = Visdom()
if not self.visdom.check_connection():
print("Visdom server not running! Disabling visdom visualizations.")
self.visdom = None
def _render_one_pointcloud(self, point_cloud, cameras, render_size):
(_image_render, _, _) = render_point_cloud_pytorch3d(
cameras,
point_cloud,
render_size=render_size,
point_radius=1e-2,
topk=10,
bg_color=0.0,
)
return _image_render.clamp(0.0, 1.0)
def test_one(self):
"""Test dataset visualization."""
for max_frames in (16, -1):
for load_dataset_point_cloud in (True, False):
for dataset_key in self.datasets:
self._gen_and_render_pointcloud(
max_frames, load_dataset_point_cloud, dataset_key
)
def _gen_and_render_pointcloud(
self, max_frames, load_dataset_point_cloud, dataset_key
):
dataset = self.datasets[dataset_key]
# load the point cloud of the first sequence
sequence_show = list(dataset.seq_annots.keys())[0]
device = torch.device("cuda:0")
point_cloud, sequence_frame_data = get_implicitron_sequence_pointcloud(
dataset,
sequence_name=sequence_show,
mask_points=True,
max_frames=max_frames,
num_workers=10,
load_dataset_point_cloud=load_dataset_point_cloud,
)
# render on gpu
point_cloud = point_cloud.to(device)
cameras = sequence_frame_data.camera.to(device)
# render the point_cloud from the viewpoint of loaded cameras
images_render = torch.cat(
[
self._render_one_pointcloud(
point_cloud,
cameras[frame_i],
(
dataset.image_height,
dataset.image_width,
),
)
for frame_i in range(len(cameras))
]
).cpu()
images_gt_and_render = torch.cat(
[sequence_frame_data.image_rgb, images_render], dim=3
)
imfile = os.path.join(
os.path.split(os.path.abspath(__file__))[0],
"test_dataset_visualize"
+ f"_max_frames={max_frames}"
+ f"_load_pcl={load_dataset_point_cloud}.png",
)
print(f"Exporting image {imfile}.")
torchvision.utils.save_image(images_gt_and_render, imfile, nrow=2)
if self.visdom is not None:
test_name = f"{max_frames}_{load_dataset_point_cloud}_{dataset_key}"
self.visdom.images(
images_gt_and_render,
env="test_dataset_visualize",
win=f"pcl_renders_{test_name}",
opts={"title": f"pcl_renders_{test_name}"},
)
plotlyplot = plot_scene(
{
"scene_batch": {
"cameras": cameras,
"point_cloud": point_cloud,
}
},
camera_scale=1.0,
pointcloud_max_points=10000,
pointcloud_marker_size=1.0,
)
self.visdom.plotlyplot(
plotlyplot,
env="test_dataset_visualize",
win=f"pcl_{test_name}",
)
def _change_annotations_to_new_ndc(dataset):
dataset = copy.deepcopy(dataset)
for frame in dataset.frame_annots:
vp = frame["frame_annotation"].viewpoint
vp.intrinsics_format = "ndc_isotropic"
# this assumes the focal length is equal on x and y (ok for a test)
max_flength = max(vp.focal_length)
vp.principal_point = (
vp.principal_point[0] * max_flength / vp.focal_length[0],
vp.principal_point[1] * max_flength / vp.focal_length[1],
)
vp.focal_length = (
max_flength,
max_flength,
)
return dataset

View File

@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
import torch
from pytorch3d.implicitron.tools.eval_video_trajectory import (
generate_eval_video_cameras,
)
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
from pytorch3d.transforms import axis_angle_to_matrix
if os.environ.get("FB_TEST", False):
from common_testing import TestCaseMixin
else:
from tests.common_testing import TestCaseMixin
class TestEvalCameras(TestCaseMixin, unittest.TestCase):
def setUp(self):
torch.manual_seed(42)
def test_circular(self):
n_train_cameras = 10
n_test_cameras = 100
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
amplitude = 0.01
R_jiggled = torch.bmm(
R, axis_angle_to_matrix(torch.rand(n_train_cameras, 3) * amplitude)
)
cameras_train = PerspectiveCameras(R=R_jiggled, T=T)
cameras_test = generate_eval_video_cameras(
cameras_train, trajectory_type="circular_lsq_fit", trajectory_scale=1.0
)
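# The training cameras lie (up to a small jiggle) on a unit circle around
# the origin, so the least-squares circular fit should produce test
# cameras whose centers average to ~0 and sit at radius ~1.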
positions_test = cameras_test.get_camera_center()
center = positions_test.mean(0)
self.assertClose(center, torch.zeros(3), atol=0.1)
self.assertClose(
(positions_test - center).norm(dim=[1]),
torch.ones(n_test_cameras),
atol=0.1,
)

View File

@ -0,0 +1,290 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import copy
import dataclasses
import math
import os
import unittest
import lpips
import torch
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
)
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
from pytorch3d.implicitron.models.model_dbir import ModelDBIR
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
if os.environ.get("FB_TEST", False):
from .common_resources import get_skateboard_data, provide_lpips_vgg
else:
from common_resources import get_skateboard_data, provide_lpips_vgg
class TestEvaluation(unittest.TestCase):
def setUp(self):
# initialize evaluation dataset/dataloader
torch.manual_seed(42)
stack = contextlib.ExitStack()
dataset_root, path_manager = stack.enter_context(get_skateboard_data())
self.addCleanup(stack.close)
category = "skateboard"
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
self.image_size = 256
self.dataset = ImplicitronDataset(
frame_annotations_file=frame_file,
sequence_annotations_file=sequence_file,
dataset_root=dataset_root,
image_height=self.image_size,
image_width=self.image_size,
box_crop=True,
path_manager=path_manager,
)
self.bg_color = 0.0
# init the lpips model for eval
provide_lpips_vgg()
self.lpips_model = lpips.LPIPS(net="vgg")
def test_eval_depth(self):
"""
Check that eval_depth correctly masks errors and that, for get_best_scale=True,
the error with scaled prediction equals the error without scaling the
predicted depth. Finally, test that the error values are as expected
for prediction and gt differing by a constant offset.
"""
gt = (torch.randn(10, 1, 300, 400, device="cuda") * 5.0).clamp(0.0)
mask = (torch.rand_like(gt) > 0.5).type_as(gt)
for diff in 10 ** torch.linspace(-5, 0, 6):
for crop in (0, 5):
pred = gt + (torch.rand_like(gt) - 0.5) * 2 * diff
# scaled prediction test
mse_depth, abs_depth = eval_depth(
pred,
gt,
crop=crop,
mask=mask,
get_best_scale=True,
)
mse_depth_scale, abs_depth_scale = eval_depth(
pred * 10.0,
gt,
crop=crop,
mask=mask,
get_best_scale=True,
)
self.assertAlmostEqual(
float(mse_depth.sum()), float(mse_depth_scale.sum()), delta=1e-4
)
self.assertAlmostEqual(
float(abs_depth.sum()), float(abs_depth_scale.sum()), delta=1e-4
)
# error masking test
pred_masked_err = gt + (torch.rand_like(gt) + diff) * (1 - mask)
mse_depth_masked, abs_depth_masked = eval_depth(
pred_masked_err,
gt,
crop=crop,
mask=mask,
get_best_scale=True,
)
self.assertAlmostEqual(
float(mse_depth_masked.sum()), float(0.0), delta=1e-4
)
self.assertAlmostEqual(
float(abs_depth_masked.sum()), float(0.0), delta=1e-4
)
mse_depth_unmasked, abs_depth_unmasked = eval_depth(
pred_masked_err,
gt,
crop=crop,
mask=1 - mask,
get_best_scale=True,
)
self.assertGreater(
float(mse_depth_unmasked.sum()),
float(diff ** 2),
)
self.assertGreater(
float(abs_depth_unmasked.sum()),
float(diff),
)
# tests with constant error
pred_fix_diff = gt + diff * mask
for _mask_gt in (mask, None):
mse_depth_fix_diff, abs_depth_fix_diff = eval_depth(
pred_fix_diff,
gt,
crop=crop,
mask=_mask_gt,
get_best_scale=False,
)
if _mask_gt is not None:
expected_err_abs = diff
expected_err_mse = diff ** 2
else:
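# With no mask passed, eval_depth averages over all pixels with positive
# ground-truth depth, so the constant offset (applied only under `mask`)
# gets diluted by the fraction of valid pixels that carry it; the
# expected errors below reproduce that average.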
err_mask = (gt > 0.0).float() * mask
if crop > 0:
err_mask = err_mask[:, :, crop:-crop, crop:-crop]
gt_cropped = gt[:, :, crop:-crop, crop:-crop]
else:
gt_cropped = gt
gt_mass = (gt_cropped > 0.0).float().sum(dim=(1, 2, 3))
expected_err_abs = (
diff * err_mask.sum(dim=(1, 2, 3)) / (gt_mass)
)
expected_err_mse = diff * expected_err_abs
self.assertTrue(
torch.allclose(
abs_depth_fix_diff,
expected_err_abs * torch.ones_like(abs_depth_fix_diff),
atol=1e-4,
)
)
self.assertTrue(
torch.allclose(
mse_depth_fix_diff,
expected_err_mse * torch.ones_like(mse_depth_fix_diff),
atol=1e-4,
)
)
def test_psnr(self):
"""
Compare against opencv and check that the psnr is above
the minimum possible value.
"""
import cv2
im1 = torch.rand(100, 3, 256, 256).cuda()
im1_uint8 = (im1 * 255).to(torch.uint8)
im1_rounded = im1_uint8.float() / 255
for max_diff in 10 ** torch.linspace(-5, 0, 6):
im2 = im1 + (torch.rand_like(im1) - 0.5) * 2 * max_diff
im2 = im2.clamp(0.0, 1.0)
im2_uint8 = (im2 * 255).to(torch.uint8)
im2_rounded = im2_uint8.float() / 255
# check that our psnr matches the output of opencv
psnr = calc_psnr(im1_rounded, im2_rounded)
# some versions of cv2 can only take uint8 input
psnr_cv2 = cv2.PSNR(
im1_uint8.cpu().numpy(),
im2_uint8.cpu().numpy(),
)
self.assertAlmostEqual(float(psnr), float(psnr_cv2), delta=1e-4)
# check that all PSNRs are bigger than the minimum possible PSNR
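# (PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1 here; since every pixel
# differs by at most max_diff, MSE <= max_diff ** 2, which gives the
# lower bound computed below.)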
max_mse = max_diff ** 2
min_psnr = 10 * math.log10(1.0 / max_mse)
for _im1, _im2 in zip(im1, im2):
_psnr = calc_psnr(_im1, _im2)
self.assertGreaterEqual(float(_psnr) + 1e-6, min_psnr)
def _one_sequence_test(
self,
seq_dataset,
n_batches=2,
min_batch_size=5,
max_batch_size=10,
):
# form a list of random batches
batch_indices = []
for _ in range(n_batches):
batch_size = torch.randint(
low=min_batch_size, high=max_batch_size, size=(1,)
)
batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
loader = torch.utils.data.DataLoader(
seq_dataset,
# batch_size=1,
shuffle=False,
batch_sampler=batch_indices,
collate_fn=FrameData.collate,
)
model = ModelDBIR(image_size=self.image_size, bg_color=self.bg_color)
model.cuda()
self.lpips_model.cuda()
for frame_data in loader:
self.assertIsNone(frame_data.frame_type)
self.assertIsNotNone(frame_data.image_rgb)
# override the frame_type
frame_data.frame_type = [
"train_unseen",
*(["train_known"] * (len(frame_data.image_rgb) - 1)),
]
# move frame_data to gpu
frame_data = dataclass_to_cuda_(frame_data)
preds = model(**dataclasses.asdict(frame_data))
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
eval_result = eval_batch(
frame_data,
nvs_prediction,
bg_color=self.bg_color,
lpips_model=self.lpips_model,
)
# Make a terribly bad NVS prediction and check that this is worse
# than the DBIR prediction.
nvs_prediction_bad = copy.deepcopy(preds["nvs_prediction"])
nvs_prediction_bad.depth_render += (
torch.randn_like(nvs_prediction.depth_render) * 100.0
)
nvs_prediction_bad.image_render += (
torch.randn_like(nvs_prediction.image_render) * 100.0
)
nvs_prediction_bad.mask_render = (
torch.randn_like(nvs_prediction.mask_render) > 0.0
).float()
eval_result_bad = eval_batch(
frame_data,
nvs_prediction_bad,
bg_color=self.bg_color,
lpips_model=self.lpips_model,
)
lower_better = {
"psnr": False,
"psnr_fg": False,
"depth_abs_fg": True,
"iou": False,
"rgb_l1": True,
"rgb_l1_fg": True,
}
for metric in lower_better.keys():
m_better = eval_result[metric]
m_worse = eval_result_bad[metric]
if m_better != m_better or m_worse != m_worse:
continue # metric is missing, i.e. NaN
_assert = (
self.assertLessEqual
if lower_better[metric]
else self.assertGreaterEqual
)
_assert(m_better, m_worse)
def test_full_eval(self, n_sequences=5):
"""Test evaluation."""
for _, idx in list(self.dataset.seq_to_idx.items())[:n_sequences]:
seq_dataset = torch.utils.data.Subset(self.dataset, idx)
self._one_sequence_test(seq_dataset)

View File

@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
class TestGenericModel(unittest.TestCase):
def test_gm(self):
# Simple test of a forward pass of the default GenericModel.
device = torch.device("cuda:1")
expand_args_fields(GenericModel)
model = GenericModel()
model.to(device)
n_train_cameras = 2
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
cameras = PerspectiveCameras(R=R, T=T, device=device)
# TODO: make these default to None?
defaulted_args = {
"fg_probability": None,
"depth_map": None,
"mask_crop": None,
"sequence_name": None,
}
with self.assertWarnsRegex(UserWarning, "No main objective found"):
model(
camera=cameras,
evaluation_mode=EvaluationMode.TRAINING,
**defaulted_args,
image_rgb=None,
)
target_image_rgb = torch.rand(
(n_train_cameras, 3, model.render_image_height, model.render_image_width),
device=device,
)
train_preds = model(
camera=cameras,
evaluation_mode=EvaluationMode.TRAINING,
image_rgb=target_image_rgb,
**defaulted_args,
)
self.assertGreater(train_preds["objective"].item(), 0)
model.eval()
with torch.no_grad():
# TODO: perhaps this warning should be skipped in eval mode?
with self.assertWarnsRegex(UserWarning, "No main objective found"):
eval_preds = model(
camera=cameras[0],
**defaulted_args,
image_rgb=None,
)
self.assertEqual(
eval_preds["images_render"].shape,
(1, 3, model.render_image_height, model.render_image_width),
)

View File

@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
import torch
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
from pytorch3d.renderer import RayBundle
if os.environ.get("FB_TEST", False):
from common_testing import TestCaseMixin
else:
from tests.common_testing import TestCaseMixin
class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
def test_simple(self):
length = 15
n_pts_per_ray = 10
for add_input_samples in [False, True]:
ray_point_refiner = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=False,
add_input_samples=add_input_samples,
)
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
bundle = RayBundle(lengths=lengths, origins=None, directions=None, xys=None)
weights = torch.ones(3, 25, length)
refined = ray_point_refiner(bundle, weights)
self.assertIsNone(refined.directions)
self.assertIsNone(refined.origins)
self.assertIsNone(refined.xys)
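# With uniform weights, the importance-sampling bins are bounded by the
# midpoints of the input depths (0.5 .. length - 1.5), and deterministic
# sampling inverts a linear CDF, so the new samples should be evenly
# spaced over that interval.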
expected = torch.linspace(0.5, length - 1.5, n_pts_per_ray)
expected = expected.expand(3, 25, n_pts_per_ray)
if add_input_samples:
full_expected = torch.cat((lengths, expected), dim=-1).sort()[0]
else:
full_expected = expected
self.assertClose(refined.lengths, full_expected)
ray_point_refiner_random = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=True,
add_input_samples=add_input_samples,
)
refined_random = ray_point_refiner_random(bundle, weights)
lengths_random = refined_random.lengths
self.assertEqual(lengths_random.shape, full_expected.shape)
if not add_input_samples:
self.assertGreater(lengths_random.min().item(), 0.5)
self.assertLess(lengths_random.max().item(), length - 1.5)
# Check sorted
self.assertTrue(
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
)

View File

@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
import torch
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (
SRNHyperNetImplicitFunction,
SRNImplicitFunction,
SRNPixelGenerator,
)
from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer import RayBundle
if os.environ.get("FB_TEST", False):
from common_testing import TestCaseMixin
else:
from tests.common_testing import TestCaseMixin
_BATCH_SIZE: int = 3
_N_RAYS: int = 100
_N_POINTS_ON_RAY: int = 10
class TestSRN(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(42)
get_default_args(SRNHyperNetImplicitFunction)
get_default_args(SRNImplicitFunction)
def test_pixel_generator(self):
SRNPixelGenerator()
def _get_bundle(self, *, device) -> RayBundle:
origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
bundle = RayBundle(
lengths=lengths, origins=origins, directions=directions, xys=None
)
return bundle
def test_srn_implicit_function(self):
implicit_function = SRNImplicitFunction()
device = torch.device("cpu")
bundle = self._get_bundle(device=device)
rays_densities, rays_colors = implicit_function(bundle)
out_features = implicit_function.raymarch_function.out_features
self.assertEqual(
rays_densities.shape,
(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
)
self.assertIsNone(rays_colors)
def test_srn_hypernet_implicit_function(self):
# TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
latent_dim_hypernet = 39
hypernet_args = {"latent_dim_hypernet": latent_dim_hypernet}
device = torch.device("cuda:0")
implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hypernet_args)
implicit_function.to(device)
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
bundle = self._get_bundle(device=device)
rays_densities, rays_colors = implicit_function(bundle, global_code=global_code)
out_features = implicit_function.hypernet.out_features
self.assertEqual(
rays_densities.shape,
(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
)
self.assertIsNone(rays_colors)
def test_srn_hypernet_implicit_function_optim(self):
# Test optimization loop, requiring that the cache is properly
# cleared in new_args_bound
latent_dim_hypernet = 39
hyper_args = {"latent_dim_hypernet": latent_dim_hypernet}
device = torch.device("cuda:0")
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
bundle = self._get_bundle(device=device)
implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hyper_args)
implicit_function2 = SRNHyperNetImplicitFunction(hypernet_args=hyper_args)
implicit_function.to(device)
implicit_function2.to(device)
wrapper = ImplicitFunctionWrapper(implicit_function)
optimizer = torch.optim.Adam(implicit_function.parameters())
for _step in range(3):
optimizer.zero_grad()
wrapper.bind_args(global_code=global_code)
rays_densities, _rays_colors = wrapper(bundle)
wrapper.unbind_args()
loss = rays_densities.sum()
loss.backward()
optimizer.step()
wrapper2 = ImplicitFunctionWrapper(implicit_function)
optimizer2 = torch.optim.Adam(implicit_function2.parameters())
implicit_function2.load_state_dict(implicit_function.state_dict())
optimizer2.load_state_dict(optimizer.state_dict())
for _step in range(3):
optimizer2.zero_grad()
wrapper2.bind_args(global_code=global_code)
rays_densities, _rays_colors = wrapper2(bundle)
wrapper2.unbind_args()
loss = rays_densities.sum()
loss.backward()
optimizer2.step()

View File

@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dataclasses
import unittest
from typing import Dict, List, NamedTuple, Tuple
from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.types import FrameAnnotation
class _NT(NamedTuple):
annot: FrameAnnotation
class TestDatasetTypes(unittest.TestCase):
def setUp(self):
self.entry = FrameAnnotation(
frame_number=23,
sequence_name="1",
frame_timestamp=1.2,
image=types.ImageAnnotation(path="/tmp/1.jpg", size=(224, 224)),
mask=types.MaskAnnotation(path="/tmp/1.png", mass=42.0),
viewpoint=types.ViewpointAnnotation(
R=(
(1, 0, 0),
(1, 0, 0),
(1, 0, 0),
),
T=(0, 0, 0),
principal_point=(100, 100),
focal_length=(200, 200),
),
)
def test_asdict_rec(self):
first = [dataclasses.asdict(self.entry)]
second = types._asdict_rec([self.entry])
self.assertEqual(first, second)
def test_parsing(self):
"""Test that we handle collections enclosing dataclasses."""
dct = dataclasses.asdict(self.entry)
parsed = types._dataclass_from_dict(dct, FrameAnnotation)
self.assertEqual(parsed, self.entry)
# namedtuple
parsed = types._dataclass_from_dict(_NT(dct), _NT)
self.assertEqual(parsed.annot, self.entry)
# tuple
parsed = types._dataclass_from_dict((dct,), Tuple[FrameAnnotation])
self.assertEqual(parsed, (self.entry,))
# list
parsed = types._dataclass_from_dict(
[
dct,
],
List[FrameAnnotation],
)
self.assertEqual(
parsed,
[
self.entry,
],
)
# dict
parsed = types._dataclass_from_dict({"k": dct}, Dict[str, FrameAnnotation])
self.assertEqual(parsed, {"k": self.entry})
def test_parsing_vectorized(self):
dct = dataclasses.asdict(self.entry)
self._compare_with_scalar(dct, FrameAnnotation)
self._compare_with_scalar(_NT(dct), _NT)
self._compare_with_scalar((dct,), Tuple[FrameAnnotation])
self._compare_with_scalar([dct], List[FrameAnnotation])
self._compare_with_scalar({"k": dct}, Dict[str, FrameAnnotation])
def _compare_with_scalar(self, obj, typeannot, repeat=3):
input = [obj] * 3
vect_output = types._dataclass_list_from_dict_list(input, typeannot)
self.assertEqual(len(input), repeat)
gt = types._dataclass_from_dict(obj, typeannot)
self.assertTrue(all(res == gt for res in vect_output))

View File

@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import pytorch3d as pt3d
import torch
from pytorch3d.implicitron.models.view_pooling.view_sampling import ViewSampler
from pytorch3d.implicitron.tools.config import expand_args_fields
class TestViewsampling(unittest.TestCase):
def setUp(self):
torch.manual_seed(42)
expand_args_fields(ViewSampler)
def _init_view_sampler_problem(self, random_masks):
"""
Generates a view-sampling problem:
- 4 source views, 1st/2nd from the first sequence 'seq1', the rest from 'seq2'
- 3 sets of 3D points from sequences 'seq1', 'seq2', 'seq2' respectively.
- first 50 points in each batch correctly project to the source views,
while the remaining 50 do not land in any projection plane.
- each source view is labeled with image feature tensors of shape 7x100x50,
where all elements of the n-th tensor are set to `n+1`.
- the elements of the source view masks are either set to random binary numbers
(if `random_masks==True`), or all set to 1 (`random_masks==False`).
- the source view cameras are uniformly distributed on a unit circle
in the x-z plane and look at (0,0,0).
"""
seq_id_camera = ["seq1", "seq1", "seq2", "seq2"]
seq_id_pts = ["seq1", "seq2", "seq2"]
pts_batch = 3
n_pts = 100
n_views = 4
fdim = 7
H = 100
W = 50
# points that land in the projection planes of all cameras
pts_inside = (
torch.nn.functional.normalize(
torch.randn(pts_batch, n_pts // 2, 3, device="cuda"),
dim=-1,
)
* 0.1
)
# move the outside points far above the scene
pts_outside = pts_inside.clone()
pts_outside[:, :, 1] += 1e8
pts = torch.cat([pts_inside, pts_outside], dim=1)
R, T = pt3d.renderer.look_at_view_transform(
dist=1.0,
elev=0.0,
azim=torch.linspace(0, 360, n_views + 1)[:n_views],
degrees=True,
device=pts.device,
)
focal_length = R.new_ones(n_views, 2)
principal_point = R.new_zeros(n_views, 2)
camera = pt3d.renderer.PerspectiveCameras(
R=R,
T=T,
focal_length=focal_length,
principal_point=principal_point,
device=pts.device,
)
feats_map = torch.arange(n_views, device=pts.device, dtype=pts.dtype) + 1
feats = {"feats": feats_map[:, None, None, None].repeat(1, fdim, H, W)}
masks = (
torch.rand(n_views, 1, H, W, device=pts.device, dtype=pts.dtype) > 0.5
).type_as(R)
if not random_masks:
masks[:] = 1.0
return pts, camera, feats, masks, seq_id_camera, seq_id_pts
def test_compare_with_naive(self):
"""
Compares the outputs of the efficient ViewSampler module with a
naive implementation.
"""
(
pts,
camera,
feats,
masks,
seq_id_camera,
seq_id_pts,
) = self._init_view_sampler_problem(True)
for masked_sampling in (True, False):
feats_sampled_n, masks_sampled_n = _view_sample_naive(
pts,
seq_id_pts,
camera,
seq_id_camera,
feats,
masks,
masked_sampling,
)
# make sure we generate the constructor for ViewSampler
expand_args_fields(ViewSampler)
view_sampler = ViewSampler(masked_sampling=masked_sampling)
feats_sampled, masks_sampled = view_sampler(
pts=pts,
seq_id_pts=seq_id_pts,
camera=camera,
seq_id_camera=seq_id_camera,
feats=feats,
masks=masks,
)
for k in feats_sampled.keys():
self.assertTrue(torch.allclose(feats_sampled[k], feats_sampled_n[k]))
self.assertTrue(torch.allclose(masks_sampled, masks_sampled_n))
def test_viewsampling(self):
"""
Generates a viewsampling problem with predictable outcome, and compares
the ViewSampler's output to the expected result.
"""
(
pts,
camera,
feats,
masks,
seq_id_camera,
seq_id_pts,
) = self._init_view_sampler_problem(False)
expand_args_fields(ViewSampler)
for masked_sampling in (True, False):
view_sampler = ViewSampler(masked_sampling=masked_sampling)
feats_sampled, masks_sampled = view_sampler(
pts=pts,
seq_id_pts=seq_id_pts,
camera=camera,
seq_id_camera=seq_id_camera,
feats=feats,
masks=masks,
)
n_views = camera.R.shape[0]
n_pts = pts.shape[1]
feat_dim = feats["feats"].shape[1]
pts_batch = pts.shape[0]
n_pts_away = n_pts // 2
for pts_i in range(pts_batch):
for view_i in range(n_views):
if seq_id_pts[pts_i] != seq_id_camera[view_i]:
# points / cameras come from different sequences
gt_masks = pts.new_zeros(n_pts, 1)
gt_feats = pts.new_zeros(n_pts, feat_dim)
else:
gt_masks = pts.new_ones(n_pts, 1)
gt_feats = pts.new_ones(n_pts, feat_dim) * (view_i + 1)
gt_feats[n_pts_away:] = 0.0
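# the second half of the points projects outside the image, so grid
# sampling (with zero padding) is expected to return zero features for
# them regardless of masked_sampling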
if masked_sampling:
gt_masks[n_pts_away:] = 0.0
for k in feats_sampled:
self.assertTrue(
torch.allclose(
feats_sampled[k][pts_i, view_i],
gt_feats,
)
)
self.assertTrue(
torch.allclose(
masks_sampled[pts_i, view_i],
gt_masks,
)
)
def _view_sample_naive(
pts,
seq_id_pts,
camera,
seq_id_camera,
feats,
masks,
masked_sampling,
):
"""
A naive implementation of the forward pass of ViewSampler.
Refer to ViewSampler's docstring for a description of the arguments.
"""
pts_batch = pts.shape[0]
n_views = camera.R.shape[0]
n_pts = pts.shape[1]
feats_sampled = [[[] for _ in range(n_views)] for _ in range(pts_batch)]
masks_sampled = [[[] for _ in range(n_views)] for _ in range(pts_batch)]
for pts_i in range(pts_batch):
for view_i in range(n_views):
if seq_id_pts[pts_i] != seq_id_camera[view_i]:
# points/cameras come from different sequences
feats_sampled_ = {
k: f.new_zeros(n_pts, f.shape[1]) for k, f in feats.items()
}
masks_sampled_ = masks.new_zeros(n_pts, 1)
else:
# same sequence of pts and cameras -> sample
feats_sampled_, masks_sampled_ = _sample_one_view_naive(
camera[view_i],
pts[pts_i],
{k: f[view_i] for k, f in feats.items()},
masks[view_i],
masked_sampling,
sampling_mode="bilinear",
)
feats_sampled[pts_i][view_i] = feats_sampled_
masks_sampled[pts_i][view_i] = masks_sampled_
masks_sampled_cat = torch.stack([torch.stack(m) for m in masks_sampled])
feats_sampled_cat = {}
for k in feats_sampled[0][0].keys():
feats_sampled_cat[k] = torch.stack(
[torch.stack([f_[k] for f_ in f]) for f in feats_sampled]
)
return feats_sampled_cat, masks_sampled_cat
def _sample_one_view_naive(
camera,
pts,
feats,
masks,
masked_sampling,
sampling_mode="bilinear",
):
"""
Sample a single source view.
"""
proj_ndc = camera.transform_points(pts[None])[None, ..., :-1] # 1 x 1 x n_pts x 2
feats_sampled = {
k: pt3d.renderer.ndc_grid_sample(f[None], proj_ndc, mode=sampling_mode).permute(
0, 3, 1, 2
)[0, :, :, 0]
for k, f in feats.items()
} # n_pts x dim
if not masked_sampling:
n_pts = pts.shape[0]
masks_sampled = proj_ndc.new_ones(n_pts, 1)
else:
masks_sampled = pt3d.renderer.ndc_grid_sample(
masks[None],
proj_ndc,
mode=sampling_mode,
align_corners=False,
)[0, 0, 0, :][:, None]
return feats_sampled, masks_sampled