From cdd2142dd531b4f983468bd9b158f4085ee57dd8 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein <669761+bottler@users.noreply.github.com> Date: Mon, 21 Mar 2022 20:20:10 +0000 Subject: [PATCH] implicitron v0 (#1133) Co-authored-by: Jeremy Francis Reizenstein --- projects/implicitron_trainer/README.md | 276 +++++ .../configs/repro_base.yaml | 83 ++ .../configs/repro_feat_extractor_normed.yaml | 16 + .../repro_feat_extractor_transformer.yaml | 16 + .../repro_feat_extractor_unnormed.yaml | 16 + .../configs/repro_multiseq_base.yaml | 31 + .../configs/repro_multiseq_idr_ad.yaml | 64 ++ .../configs/repro_multiseq_nerf_ad.yaml | 9 + .../configs/repro_multiseq_nerf_wce.yaml | 10 + .../configs/repro_multiseq_nerformer.yaml | 16 + .../repro_multiseq_nerformer_angle_w.yaml | 16 + .../repro_multiseq_srn_ad_hypernet.yaml | 32 + ...repro_multiseq_srn_ad_hypernet_noharm.yaml | 10 + .../configs/repro_multiseq_srn_wce.yaml | 30 + .../repro_multiseq_srn_wce_noharm.yaml | 10 + .../configs/repro_singleseq_base.yaml | 41 + .../configs/repro_singleseq_idr.yaml | 57 + .../configs/repro_singleseq_nerf.yaml | 4 + .../configs/repro_singleseq_nerf_wce.yaml | 9 + .../configs/repro_singleseq_nerformer.yaml | 16 + .../configs/repro_singleseq_srn.yaml | 28 + .../configs/repro_singleseq_srn_noharm.yaml | 10 + .../configs/repro_singleseq_srn_wce.yaml | 29 + .../repro_singleseq_srn_wce_noharm.yaml | 10 + .../configs/repro_singleseq_wce_base.yaml | 18 + projects/implicitron_trainer/experiment.py | 714 +++++++++++++ .../visualize_reconstruction.py | 382 +++++++ .../implicitron/dataset/dataloader_zoo.py | 97 ++ pytorch3d/implicitron/dataset/dataset_zoo.py | 260 +++++ .../dataset/implicitron_dataset.py | 988 ++++++++++++++++++ .../dataset/scene_batch_sampler.py | 203 ++++ pytorch3d/implicitron/dataset/types.py | 331 ++++++ pytorch3d/implicitron/dataset/utils.py | 44 + pytorch3d/implicitron/dataset/visualize.py | 95 ++ pytorch3d/implicitron/eval_demo.py | 216 ++++ .../evaluation/evaluate_new_view_synthesis.py | 649 ++++++++++++ pytorch3d/implicitron/models/autodecoder.py | 172 +++ pytorch3d/implicitron/models/base.py | 883 ++++++++++++++++ .../models/implicit_function/__init__.py | 5 + .../models/implicit_function/base.py | 50 + .../implicit_function/idr_feature_field.py | 142 +++ .../neural_radiance_field.py | 542 ++++++++++ .../scene_representation_networks.py | 411 ++++++++ .../models/implicit_function/utils.py | 90 ++ pytorch3d/implicitron/models/metrics.py | 230 ++++ pytorch3d/implicitron/models/model_dbir.py | 139 +++ pytorch3d/implicitron/models/renderer/base.py | 118 +++ .../models/renderer/lstm_renderer.py | 179 ++++ .../models/renderer/multipass_ea.py | 171 +++ .../models/renderer/ray_point_refiner.py | 87 ++ .../models/renderer/ray_sampler.py | 190 ++++ .../models/renderer/ray_tracing.py | 573 ++++++++++ .../implicitron/models/renderer/raymarcher.py | 143 +++ .../implicitron/models/renderer/rgb_net.py | 101 ++ .../models/renderer/sdf_renderer.py | 253 +++++ .../models/resnet_feature_extractor.py | 218 ++++ .../view_pooling/feature_aggregation.py | 666 ++++++++++++ .../models/view_pooling/view_sampling.py | 291 ++++++ .../implicitron/third_party/hyperlayers.py | 254 +++++ .../third_party/pytorch_prototyping.py | 772 ++++++++++++++ pytorch3d/implicitron/tools/__init__.py | 0 pytorch3d/implicitron/tools/camera_utils.py | 142 +++ pytorch3d/implicitron/tools/circle_fitting.py | 231 ++++ pytorch3d/implicitron/tools/config.py | 714 +++++++++++++ pytorch3d/implicitron/tools/depth_cleanup.py | 113 ++ .../tools/eval_video_trajectory.py | 226 ++++ pytorch3d/implicitron/tools/image_utils.py | 53 + pytorch3d/implicitron/tools/metric_utils.py | 231 ++++ pytorch3d/implicitron/tools/model_io.py | 163 +++ .../implicitron/tools/point_cloud_utils.py | 168 +++ pytorch3d/implicitron/tools/rasterize_mc.py | 63 ++ pytorch3d/implicitron/tools/stats.py | 491 +++++++++ pytorch3d/implicitron/tools/utils.py | 183 ++++ pytorch3d/implicitron/tools/video_writer.py | 149 +++ pytorch3d/implicitron/tools/vis_utils.py | 172 +++ tests/implicitron/__init__.py | 5 + tests/implicitron/common_resources.py | 114 ++ tests/implicitron/data/overrides.yaml | 122 +++ tests/implicitron/test_batch_sampler.py | 215 ++++ tests/implicitron/test_circle_fitting.py | 177 ++++ tests/implicitron/test_config.py | 610 +++++++++++ tests/implicitron/test_config_use.py | 81 ++ tests/implicitron/test_dataset_visualize.py | 191 ++++ tests/implicitron/test_eval_cameras.py | 48 + tests/implicitron/test_evaluation.py | 290 +++++ tests/implicitron/test_forward_pass.py | 67 ++ tests/implicitron/test_ray_point_refiner.py | 63 ++ tests/implicitron/test_srn.py | 114 ++ tests/implicitron/test_types.py | 93 ++ tests/implicitron/test_viewsampling.py | 270 +++++ 90 files changed, 17075 insertions(+) create mode 100644 projects/implicitron_trainer/README.md create mode 100644 projects/implicitron_trainer/configs/repro_base.yaml create mode 100644 projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml create mode 100644 projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml create mode 100644 projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml create mode 100644 projects/implicitron_trainer/configs/repro_multiseq_base.yaml create mode 100644 projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml create mode 100644 projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml create mode 100644 projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml create mode 100644 projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml create mode 100644 projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml create mode 100644 projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml create mode 100644 projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet_noharm.yaml create mode 100644 projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml create mode 100644 projects/implicitron_trainer/configs/repro_multiseq_srn_wce_noharm.yaml create mode 100644 projects/implicitron_trainer/configs/repro_singleseq_base.yaml create mode 100644 projects/implicitron_trainer/configs/repro_singleseq_idr.yaml create mode 100644 projects/implicitron_trainer/configs/repro_singleseq_nerf.yaml create mode 100644 projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml create mode 100644 projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml create mode 100644 projects/implicitron_trainer/configs/repro_singleseq_srn.yaml create mode 100644 projects/implicitron_trainer/configs/repro_singleseq_srn_noharm.yaml create mode 100644 projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml create mode 100644 projects/implicitron_trainer/configs/repro_singleseq_srn_wce_noharm.yaml create mode 100644 projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml create mode 100755 projects/implicitron_trainer/experiment.py create mode 100644 projects/implicitron_trainer/visualize_reconstruction.py create mode 100644 pytorch3d/implicitron/dataset/dataloader_zoo.py create mode 100644 pytorch3d/implicitron/dataset/dataset_zoo.py create mode 100644 pytorch3d/implicitron/dataset/implicitron_dataset.py create mode 100644 pytorch3d/implicitron/dataset/scene_batch_sampler.py create mode 100644 pytorch3d/implicitron/dataset/types.py create mode 100644 pytorch3d/implicitron/dataset/utils.py create mode 100644 pytorch3d/implicitron/dataset/visualize.py create mode 100644 pytorch3d/implicitron/eval_demo.py create mode 100644 pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py create mode 100644 pytorch3d/implicitron/models/autodecoder.py create mode 100644 pytorch3d/implicitron/models/base.py create mode 100644 pytorch3d/implicitron/models/implicit_function/__init__.py create mode 100644 pytorch3d/implicitron/models/implicit_function/base.py create mode 100644 pytorch3d/implicitron/models/implicit_function/idr_feature_field.py create mode 100644 pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py create mode 100644 pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py create mode 100644 pytorch3d/implicitron/models/implicit_function/utils.py create mode 100644 pytorch3d/implicitron/models/metrics.py create mode 100644 pytorch3d/implicitron/models/model_dbir.py create mode 100644 pytorch3d/implicitron/models/renderer/base.py create mode 100644 pytorch3d/implicitron/models/renderer/lstm_renderer.py create mode 100644 pytorch3d/implicitron/models/renderer/multipass_ea.py create mode 100644 pytorch3d/implicitron/models/renderer/ray_point_refiner.py create mode 100644 pytorch3d/implicitron/models/renderer/ray_sampler.py create mode 100644 pytorch3d/implicitron/models/renderer/ray_tracing.py create mode 100644 pytorch3d/implicitron/models/renderer/raymarcher.py create mode 100644 pytorch3d/implicitron/models/renderer/rgb_net.py create mode 100644 pytorch3d/implicitron/models/renderer/sdf_renderer.py create mode 100644 pytorch3d/implicitron/models/resnet_feature_extractor.py create mode 100644 pytorch3d/implicitron/models/view_pooling/feature_aggregation.py create mode 100644 pytorch3d/implicitron/models/view_pooling/view_sampling.py create mode 100644 pytorch3d/implicitron/third_party/hyperlayers.py create mode 100644 pytorch3d/implicitron/third_party/pytorch_prototyping.py create mode 100644 pytorch3d/implicitron/tools/__init__.py create mode 100644 pytorch3d/implicitron/tools/camera_utils.py create mode 100644 pytorch3d/implicitron/tools/circle_fitting.py create mode 100644 pytorch3d/implicitron/tools/config.py create mode 100644 pytorch3d/implicitron/tools/depth_cleanup.py create mode 100644 pytorch3d/implicitron/tools/eval_video_trajectory.py create mode 100644 pytorch3d/implicitron/tools/image_utils.py create mode 100644 pytorch3d/implicitron/tools/metric_utils.py create mode 100644 pytorch3d/implicitron/tools/model_io.py create mode 100644 pytorch3d/implicitron/tools/point_cloud_utils.py create mode 100644 pytorch3d/implicitron/tools/rasterize_mc.py create mode 100644 pytorch3d/implicitron/tools/stats.py create mode 100644 pytorch3d/implicitron/tools/utils.py create mode 100644 pytorch3d/implicitron/tools/video_writer.py create mode 100644 pytorch3d/implicitron/tools/vis_utils.py create mode 100644 tests/implicitron/__init__.py create mode 100644 tests/implicitron/common_resources.py create mode 100644 tests/implicitron/data/overrides.yaml create mode 100644 tests/implicitron/test_batch_sampler.py create mode 100644 tests/implicitron/test_circle_fitting.py create mode 100644 tests/implicitron/test_config.py create mode 100644 tests/implicitron/test_config_use.py create mode 100644 tests/implicitron/test_dataset_visualize.py create mode 100644 tests/implicitron/test_eval_cameras.py create mode 100644 tests/implicitron/test_evaluation.py create mode 100644 tests/implicitron/test_forward_pass.py create mode 100644 tests/implicitron/test_ray_point_refiner.py create mode 100644 tests/implicitron/test_srn.py create mode 100644 tests/implicitron/test_types.py create mode 100644 tests/implicitron/test_viewsampling.py diff --git a/projects/implicitron_trainer/README.md b/projects/implicitron_trainer/README.md new file mode 100644 index 00000000..be12fc58 --- /dev/null +++ b/projects/implicitron_trainer/README.md @@ -0,0 +1,276 @@ +# Introduction + +Implicitron is a PyTorch3D-based framework for new-view synthesis via modeling the neural-network based representations. + +# License + +Implicitron is distributed as part of PyTorch3D under the [BSD license](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE). +It includes code from [SRN](http://github.com/vsitzmann/scene-representation-networks) and [IDR](http://github.com/lioryariv/idr) repos. +See [LICENSE-3RD-PARTY](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE-3RD-PARTY) for their licenses. + + +# Installation + +There are three ways to set up Implicitron, depending on the flexibility level required. +If you only want to train or evaluate models as they are implemented changing only the parameters, you can just install the package. +Implicitron also provides a flexible API that supports user-defined plug-ins; +if you want to re-implement some of the components without changing the high-level pipeline, you need to create a custom launcher script. +The most flexible option, though, is cloning PyTorch3D repo and building it from sources, which allows changing the code in arbitrary ways. +Below, we descibe all three options in more details. + + +## [Option 1] Running an executable from the package + +This option allows you to use the code as is without changing the implementations. +Only configuration can be changed (see [Configuration system](#configuration-system)). + +For this setup, install the dependencies and PyTorch3D from conda following [the guide](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md#1-install-with-cuda-support-from-anaconda-cloud-on-linux-only). Then, install implicitron-specific dependencies: + +```shell +pip install "hydra-core>=1.1" visdom lpips matplotlib +``` + +Runner executable is available as `pytorch3d_implicitron_runner` shell command. +See [Running](#running) section below for examples of training and evaluation commands. + +## [Option 2] Supporting custom implementations + +To plug in custom implementations, for example, of renderer or implicit-function protocols, you need to create your own runner script and import the plug-in implementations there. +First, install PyTorch3D and Implicitron dependencies as described in the previous section. +Then, implement the custom script; copying `pytorch3d/projects/implicitron_trainer/experiment.py` is a good place to start. +See [Custom plugins](#custom-plugins) for more information on how to import implementations and enable them in the configs. + + +## [Option 3] Cloning PyTorch3D repo + +This is the most flexible way to set up Implicitron as it allows changing the code directly. +It allows modifying the high-level rendering pipeline or implementing yet-unsupported loss functions. +Please follow the instructions to [install PyTorch3D from a local clone](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md#2-install-from-a-local-clone). +Then, install Implicitron-specific dependencies: + +```shell +pip install "hydra-core>=1.1" visdom lpips matplotlib +``` + +You are still encouraged to implement custom plugins as above where possible as it makes reusing the code easier. +The executable is located in `pytorch3d/projects/implicitron_trainer`. + + +# Running + +This section assumes that you use the executable provided by the installed package. +If you have a custom `experiment.py` script (as in the Option 2 above), replace the executable with the path to your script. + +## Training + +To run training, pass a yaml config file, followed by a list of overridden arguments. +For example, to train NeRF on the first skateboard sequence from CO3D dataset, you can run: +```shell +pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root= dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 test_when_finished=True exp_dir= +``` + +Here, `--config-path` points to the config path relative to `pytorch3d_implicitron_runner` location; +`--config-name` picks the config (in this case, `repro_singleseq_nerf.yaml`); +`test_when_finished` will launch evaluation script once training is finished. +Replace `` with the location where the dataset in Implicitron format is stored +and `` with a directory where checkpoints will be dumped during training. +Other configuration parameters can be overridden in the same way. +See [Configuration system](#configuration-system) section for more information on this. + + +## Evaluation + +To run evaluation on the latest checkpoint after (or during) training, simply add `eval_only=True` to your training command. + +E.g. for executing the evaluation on the NeRF skateboard sequence, you can run: +```shell +pytorch3d_implicitron_runner --config-path ./configs/ --config-name repro_singleseq_nerf dataset_args.dataset_root= dataset_args.category='skateboard' dataset_args.test_restrict_sequence_id=0 exp_dir= eval_only=True +``` +Evaluation prints the metrics to `stdout` and dumps them to a json file in `exp_dir`. + +## Visualisation + +The script produces a video of renders by a trained model assuming a pre-defined camera trajectory. +In order for it to work, `ffmpeg` needs to be installed: + +```shell +conda install ffmpeg +``` + +Here is an example of calling the script: +```shell +projects/implicitron_trainer/visualize_reconstruction.py exp_dir= visdom_show_preds=True n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]" +``` + +The argument `n_eval_cameras` sets the number of renderring viewpoints sampled on a trajectory, which defaults to a circular fly-around; +`render_size` sets the size of a render passed to the model, which can be resized to `video_size` before writing. + +Rendered videos of images, masks, and depth maps will be saved to `/vis`. + + +# Configuration system + +We use hydra and OmegaConf to parse the configs. +The config schema and default values are defined by the dataclasses implementing the modules. +More specifically, if a class derives from `Configurable`, its fields can be set in config yaml files or overridden in CLI. +For example, `GenericModel` has a field `render_image_width` with the default value 400. +If it is specified in the yaml config file or in CLI command, the new value will be used. + +Configurables can form hierarchies. +For example, `GenericModel` has a field `raysampler: RaySampler`, which is also Configurable. +In the config, inner parameters can be propagated using `_args` postfix, e.g. to change `raysampler.n_pts_per_ray_training` (the number of sampled points per ray), the node `raysampler_args.n_pts_per_ray_training` should be specified. + +The root of the hierarchy is defined by `ExperimentConfig` dataclass. +It has top-level fields like `eval_only` which was used above for running evaluation by adding a CLI override. +Additionally, it has non-leaf nodes like `generic_model_args`, which dispatches the config parameters to `GenericModel`. Thus, changing the model parameters may be achieved in two ways: either by editing the config file, e.g. +```yaml +generic_model_args: + render_image_width: 800 + raysampler_args: + n_pts_per_ray_training: 128 +``` + +or, equivalently, by adding the following to `pytorch3d_implicitron_runner` arguments: + +```shell +generic_model_args.render_image_width=800 generic_model_args.raysampler_args.n_pts_per_ray_training=128 +``` + +See the documentation in `pytorch3d/implicitron/tools/config.py` for more details. + +## Replaceable implementations + +Sometimes changing the model parameters does not provide enough flexibility, and you want to provide a new implementation for a building block. +The configuration system also supports it! +Abstract classes like `BaseRenderer` derive from `ReplaceableBase` instead of `Configurable`. +This means that other Configurables can refer to them using the base type, while the specific implementation is chosen in the config using `_class_type`-postfixed node. +In that case, `_args` node name has to include the implementation type. +More specifically, to change renderer settings, the config will look like this: +```yaml +generic_model_args: + renderer_class_type: LSTMRenderer + renderer_LSTMRenderer_args: + num_raymarch_steps: 10 + hidden_size: 16 +``` + +See the documentation in `pytorch3d/implicitron/tools/config.py` for more details on the configuration system. + +## Custom plugins + +If you have an idea for another implementation of a replaceable component, it can be plugged in without changing the core code. +For that, you need to set up Implicitron through option 2 or 3 above. +Let's say you want to implement a renderer that accumulates opacities similar to an X-ray machine. +First, create a module `x_ray_renderer.py` with a class deriving from `BaseRenderer`: + +```python +from pytorch3d.implicitron.tools.config import registry + +@registry.register +class XRayRenderer(BaseRenderer, torch.nn.Module): + n_pts_per_ray: int = 64 + + # if there are other base classes, make sure to call `super().__init__()` explicitly + def __post_init__(self): + super().__init__() + # custom initialization + + def forward( + self, + ray_bundle, + implicit_functions=[], + evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, + **kwargs, + ) -> RendererOutput: + ... +``` + +Please note `@registry.register` decorator that registers the plug-in as an implementation of `Renderer`. +IMPORTANT: In order for it to run, the class (or its enclosing module) has to be imported in your launch script. Additionally, this has to be done before parsing the root configuration class `ExperimentConfig`. +Simply add `import .x_ray_renderer` in the beginning of `experiment.py`. + +After that, you should be able to change the config with: +```yaml +generic_model_args: + renderer_class_type: XRayRenderer + renderer_XRayRenderer_args: + n_pts_per_ray: 128 +``` + +to replace the implementation and potentially override the parameters. + +# Code and config structure + +As per above, the config structure is parsed automatically from the module hierarchy. +In particular, model parameters are contained in `generic_model_args` node, and dataset parameters in `dataset_args` node. + +Here is the class structure (single-line edges show aggregation, while double lines show available implementations): +``` +generic_model_args: GenericModel +└-- sequence_autodecoder_args: Autodecoder +└-- raysampler_args: RaySampler +└-- renderer_*_args: BaseRenderer + ╘== MultiPassEmissionAbsorptionRenderer + ╘== LSTMRenderer + ╘== SignedDistanceFunctionRenderer + └-- ray_tracer_args: RayTracing + └-- ray_normal_coloring_network_args: RayNormalColoringNetwork +└-- implicit_function_*_args: ImplicitFunctionBase + ╘== NeuralRadianceFieldImplicitFunction + ╘== SRNImplicitFunction + └-- raymarch_function_args: SRNRaymarchFunction + └-- pixel_generator_args: SRNPixelGenerator + ╘== SRNHyperNetImplicitFunction + └-- hypernet_args: SRNRaymarchHyperNet + └-- pixel_generator_args: SRNPixelGenerator + ╘== IdrFeatureField +└-- image_feature_extractor_args: ResNetFeatureExtractor +└-- view_sampler_args: ViewSampler +└-- feature_aggregator_*_args: FeatureAggregatorBase + ╘== IdentityFeatureAggregator + ╘== AngleWeightedIdentityFeatureAggregator + ╘== AngleWeightedReductionFeatureAggregator + ╘== ReductionFeatureAggregator +solver_args: init_optimizer +dataset_args: dataset_zoo +dataloader_args: dataloader_zoo +``` + +Please look at the annotations of the respective classes or functions for the lists of hyperparameters. + +# Reproducing CO3D experiments + +Common Objects in 3D (CO3D) is a large-scale dataset of videos of rigid objects grouped into 50 common categories. +Implicitron provides implementations and config files to reproduce the results from [the paper](https://arxiv.org/abs/2109.00512). +Please follow [the link](https://github.com/facebookresearch/co3d#automatic-batch-download) for the instructions to download the dataset. +In training and evaluation scripts, use the download location as ``. +It is also possible to define environment variable `CO3D_DATASET_ROOT` instead of specifying it. +To reproduce the experiments from the paper, use the following configs. For single-sequence experiments: + +| Method | config file | +|-----------------|-------------------------------------| +| NeRF | repro_singleseq_nerf.yaml | +| NeRF + WCE | repro_singleseq_nerf_wce.yaml | +| NerFormer | repro_singleseq_nerformer.yaml | +| IDR | repro_singleseq_idr.yaml | +| SRN | repro_singleseq_srn_noharm.yaml | +| SRN + γ | repro_singleseq_srn.yaml | +| SRN + WCE | repro_singleseq_srn_wce_noharm.yaml | +| SRN + WCE + γ | repro_singleseq_srn_wce_noharm.yaml | + +For multi-sequence experiments (without generalisation to new sequences): + +| Method | config file | +|-----------------|--------------------------------------------| +| NeRF + AD | repro_multiseq_nerf_ad.yaml | +| SRN + AD | repro_multiseq_srn_ad_hypernet_noharm.yaml | +| SRN + γ + AD | repro_multiseq_srn_ad_hypernet.yaml | + +For multi-sequence experiments (with generalisation to new sequences): + +| Method | config file | +|-----------------|--------------------------------------| +| NeRF + WCE | repro_multiseq_nerf_wce.yaml | +| NerFormer | repro_multiseq_nerformer.yaml | +| SRN + WCE | repro_multiseq_srn_wce_noharm.yaml | +| SRN + WCE + γ | repro_multiseq_srn_wce.yaml | diff --git a/projects/implicitron_trainer/configs/repro_base.yaml b/projects/implicitron_trainer/configs/repro_base.yaml new file mode 100644 index 00000000..2f0e6e3c --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_base.yaml @@ -0,0 +1,83 @@ +defaults: +- default_config +- _self_ +exp_dir: ./data/exps/base/ +architecture: generic +visualize_interval: 0 +visdom_port: 8097 +dataloader_args: + batch_size: 10 + dataset_len: 1000 + dataset_len_val: 1 + num_workers: 8 + images_per_seq_options: + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 10 +dataset_args: + dataset_root: ${oc.env:CO3D_DATASET_ROOT}" + load_point_clouds: false + mask_depths: false + mask_images: false + n_frames_per_sequence: -1 + test_on_train: true + test_restrict_sequence_id: 0 +generic_model_args: + loss_weights: + loss_mask_bce: 1.0 + loss_prev_stage_mask_bce: 1.0 + loss_autodecoder_norm: 0.01 + loss_rgb_mse: 1.0 + loss_prev_stage_rgb_mse: 1.0 + output_rasterized_mc: false + chunk_size_grid: 102400 + render_image_height: 400 + render_image_width: 400 + num_passes: 2 + implicit_function_NeuralRadianceFieldImplicitFunction_args: + n_harmonic_functions_xyz: 10 + n_harmonic_functions_dir: 4 + n_hidden_neurons_xyz: 256 + n_hidden_neurons_dir: 128 + n_layers_xyz: 8 + append_xyz: + - 5 + latent_dim: 0 + raysampler_args: + n_rays_per_image_sampled_from_mask: 1024 + min_depth: 0.0 + max_depth: 0.0 + scene_extent: 8.0 + n_pts_per_ray_training: 64 + n_pts_per_ray_evaluation: 64 + stratified_point_sampling_training: true + stratified_point_sampling_evaluation: false + renderer_MultiPassEmissionAbsorptionRenderer_args: + n_pts_per_ray_fine_training: 64 + n_pts_per_ray_fine_evaluation: 64 + append_coarse_samples_to_fine: true + density_noise_std_train: 1.0 + view_sampler_args: + masked_sampling: false + image_feature_extractor_args: + stages: + - 1 + - 2 + - 3 + - 4 + proj_dim: 16 + image_rescale: 0.32 + first_max_pool: false +solver_args: + breed: adam + lr: 0.0005 + lr_policy: multistep + max_epochs: 2000 + momentum: 0.9 + weight_decay: 0.0 diff --git a/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml b/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml new file mode 100644 index 00000000..9e00bb12 --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_feat_extractor_normed.yaml @@ -0,0 +1,16 @@ +generic_model_args: + image_feature_extractor_args: + add_images: true + add_masks: true + first_max_pool: true + image_rescale: 0.375 + l2_norm: true + name: resnet34 + normalize_image: true + pretrained: true + stages: + - 1 + - 2 + - 3 + - 4 + proj_dim: 32 diff --git a/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml b/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml new file mode 100644 index 00000000..017be45e --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_feat_extractor_transformer.yaml @@ -0,0 +1,16 @@ +generic_model_args: + image_feature_extractor_args: + add_images: true + add_masks: true + first_max_pool: false + image_rescale: 0.375 + l2_norm: true + name: resnet34 + normalize_image: true + pretrained: true + stages: + - 1 + - 2 + - 3 + - 4 + proj_dim: 16 diff --git a/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml b/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml new file mode 100644 index 00000000..d1c43458 --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_feat_extractor_unnormed.yaml @@ -0,0 +1,16 @@ +generic_model_args: + image_feature_extractor_args: + stages: + - 1 + - 2 + - 3 + first_max_pool: false + proj_dim: -1 + l2_norm: false + image_rescale: 0.375 + name: resnet34 + normalize_image: true + pretrained: true + feature_aggregator_AngleWeightedReductionFeatureAggregator_args: + reduction_functions: + - AVG diff --git a/projects/implicitron_trainer/configs/repro_multiseq_base.yaml b/projects/implicitron_trainer/configs/repro_multiseq_base.yaml new file mode 100644 index 00000000..12abe1ae --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_multiseq_base.yaml @@ -0,0 +1,31 @@ +defaults: +- repro_base.yaml +- _self_ +dataloader_args: + batch_size: 10 + dataset_len: 1000 + dataset_len_val: 1 + num_workers: 8 + images_per_seq_options: + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 10 +dataset_args: + assert_single_seq: false + dataset_name: co3d_multisequence + load_point_clouds: false + mask_depths: false + mask_images: false + n_frames_per_sequence: -1 + test_on_train: true + test_restrict_sequence_id: 0 +solver_args: + max_epochs: 3000 + milestones: + - 1000 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml b/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml new file mode 100644 index 00000000..0f6c5933 --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_multiseq_idr_ad.yaml @@ -0,0 +1,64 @@ +defaults: +- repro_multiseq_base.yaml +- _self_ +generic_model_args: + loss_weights: + loss_mask_bce: 100.0 + loss_kl: 0.0 + loss_rgb_mse: 1.0 + loss_eikonal: 0.1 + chunk_size_grid: 65536 + num_passes: 1 + output_rasterized_mc: true + sampling_mode_training: mask_sample + view_pool: false + sequence_autodecoder_args: + n_instances: 20000 + init_scale: 1.0 + encoding_dim: 256 + implicit_function_IdrFeatureField_args: + n_harmonic_functions_xyz: 6 + bias: 0.6 + d_in: 3 + d_out: 1 + dims: + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + geometric_init: true + pooled_feature_dim: 0 + skip_in: + - 6 + weight_norm: true + renderer_SignedDistanceFunctionRenderer_args: + ray_tracer_args: + line_search_step: 0.5 + line_step_iters: 3 + n_secant_steps: 8 + n_steps: 100 + object_bounding_sphere: 8.0 + sdf_threshold: 5.0e-05 + ray_normal_coloring_network_args: + d_in: 9 + d_out: 3 + dims: + - 512 + - 512 + - 512 + - 512 + mode: idr + n_harmonic_functions_dir: 4 + pooled_feature_dim: 0 + weight_norm: true + raysampler_args: + n_rays_per_image_sampled_from_mask: 1024 + n_pts_per_ray_training: 0 + n_pts_per_ray_evaluation: 0 + scene_extent: 8.0 + renderer_class_type: SignedDistanceFunctionRenderer + implicit_function_class_type: IdrFeatureField diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml new file mode 100644 index 00000000..f5ce474a --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerf_ad.yaml @@ -0,0 +1,9 @@ +defaults: +- repro_multiseq_base.yaml +- _self_ +generic_model_args: + chunk_size_grid: 16000 + view_pool: false + sequence_autodecoder_args: + n_instances: 20000 + encoding_dim: 256 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml new file mode 100644 index 00000000..19e5ab0b --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerf_wce.yaml @@ -0,0 +1,10 @@ +defaults: +- repro_multiseq_base.yaml +- repro_feat_extractor_unnormed.yaml +- _self_ +clip_grad: 1.0 +generic_model_args: + chunk_size_grid: 16000 + view_pool: true + raysampler_args: + n_rays_per_image_sampled_from_mask: 850 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml new file mode 100644 index 00000000..e2a57e96 --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerformer.yaml @@ -0,0 +1,16 @@ +defaults: +- repro_multiseq_base.yaml +- repro_feat_extractor_transformer.yaml +- _self_ +generic_model_args: + chunk_size_grid: 16000 + view_pool: true + raysampler_args: + n_rays_per_image_sampled_from_mask: 800 + n_pts_per_ray_training: 32 + n_pts_per_ray_evaluation: 32 + renderer_MultiPassEmissionAbsorptionRenderer_args: + n_pts_per_ray_fine_training: 16 + n_pts_per_ray_fine_evaluation: 16 + implicit_function_class_type: NeRFormerImplicitFunction + feature_aggregator_class_type: IdentityFeatureAggregator diff --git a/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml b/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml new file mode 100644 index 00000000..f28b5961 --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_multiseq_nerformer_angle_w.yaml @@ -0,0 +1,16 @@ +defaults: +- repro_multiseq_base.yaml +- repro_feat_extractor_transformer.yaml +- _self_ +generic_model_args: + chunk_size_grid: 16000 + view_pool: true + raysampler_args: + n_rays_per_image_sampled_from_mask: 800 + n_pts_per_ray_training: 32 + n_pts_per_ray_evaluation: 32 + renderer_MultiPassEmissionAbsorptionRenderer_args: + n_pts_per_ray_fine_training: 16 + n_pts_per_ray_fine_evaluation: 16 + implicit_function_class_type: NeRFormerImplicitFunction + feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator diff --git a/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml new file mode 100644 index 00000000..a4ff2030 --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet.yaml @@ -0,0 +1,32 @@ +defaults: +- repro_multiseq_base.yaml +- _self_ +generic_model_args: + chunk_size_grid: 16000 + view_pool: false + n_train_target_views: -1 + num_passes: 1 + loss_weights: + loss_rgb_mse: 200.0 + loss_prev_stage_rgb_mse: 0.0 + loss_mask_bce: 1.0 + loss_prev_stage_mask_bce: 0.0 + loss_autodecoder_norm: 0.001 + depth_neg_penalty: 10000.0 + sequence_autodecoder_args: + encoding_dim: 256 + n_instances: 20000 + raysampler_args: + n_rays_per_image_sampled_from_mask: 2048 + min_depth: 0.05 + max_depth: 0.05 + scene_extent: 0.0 + n_pts_per_ray_training: 1 + n_pts_per_ray_evaluation: 1 + stratified_point_sampling_training: false + stratified_point_sampling_evaluation: false + renderer_class_type: LSTMRenderer + implicit_function_class_type: SRNHyperNetImplicitFunction +solver_args: + breed: adam + lr: 5.0e-05 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet_noharm.yaml b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet_noharm.yaml new file mode 100644 index 00000000..42355955 --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_multiseq_srn_ad_hypernet_noharm.yaml @@ -0,0 +1,10 @@ +defaults: +- repro_multiseq_srn_ad_hypernet.yaml +- _self_ +generic_model_args: + num_passes: 1 + implicit_function_SRNHyperNetImplicitFunction_args: + pixel_generator_args: + n_harmonic_functions: 0 + hypernet_args: + n_harmonic_functions: 0 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml b/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml new file mode 100644 index 00000000..f59662ea --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_multiseq_srn_wce.yaml @@ -0,0 +1,30 @@ +defaults: +- repro_multiseq_base.yaml +- repro_feat_extractor_normed.yaml +- _self_ +generic_model_args: + chunk_size_grid: 32000 + view_pool: true + num_passes: 1 + n_train_target_views: -1 + loss_weights: + loss_rgb_mse: 200.0 + loss_prev_stage_rgb_mse: 0.0 + loss_mask_bce: 1.0 + loss_prev_stage_mask_bce: 0.0 + loss_autodecoder_norm: 0.0 + depth_neg_penalty: 10000.0 + raysampler_args: + n_rays_per_image_sampled_from_mask: 2048 + min_depth: 0.05 + max_depth: 0.05 + scene_extent: 0.0 + n_pts_per_ray_training: 1 + n_pts_per_ray_evaluation: 1 + stratified_point_sampling_training: false + stratified_point_sampling_evaluation: false + renderer_class_type: LSTMRenderer + implicit_function_class_type: SRNImplicitFunction +solver_args: + breed: adam + lr: 5.0e-05 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_srn_wce_noharm.yaml b/projects/implicitron_trainer/configs/repro_multiseq_srn_wce_noharm.yaml new file mode 100644 index 00000000..e80d1cb9 --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_multiseq_srn_wce_noharm.yaml @@ -0,0 +1,10 @@ +defaults: +- repro_multiseq_srn_wce.yaml +- _self_ +generic_model_args: + num_passes: 1 + implicit_function_SRNImplicitFunction_args: + pixel_generator_args: + n_harmonic_functions: 0 + raymarch_function_args: + n_harmonic_functions: 0 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_base.yaml b/projects/implicitron_trainer/configs/repro_singleseq_base.yaml new file mode 100644 index 00000000..bbec0f4c --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_singleseq_base.yaml @@ -0,0 +1,41 @@ +defaults: +- repro_base +- _self_ +dataloader_args: + batch_size: 1 + dataset_len: 1000 + dataset_len_val: 1 + num_workers: 8 + images_per_seq_options: + - 2 +dataset_args: + dataset_name: co3d_singlesequence + assert_single_seq: true + n_frames_per_sequence: -1 + test_restrict_sequence_id: 0 + test_on_train: false +generic_model_args: + render_image_height: 800 + render_image_width: 800 + log_vars: + - loss_rgb_psnr_fg + - loss_rgb_psnr + - loss_eikonal + - loss_prev_stage_rgb_psnr + - loss_mask_bce + - loss_prev_stage_mask_bce + - loss_rgb_mse + - loss_prev_stage_rgb_mse + - loss_depth_abs + - loss_depth_abs_fg + - loss_kl + - loss_mask_neg_iou + - objective + - epoch + - sec/it +solver_args: + lr: 0.0005 + max_epochs: 400 + milestones: + - 200 + - 300 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml b/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml new file mode 100644 index 00000000..28553693 --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_singleseq_idr.yaml @@ -0,0 +1,57 @@ +defaults: +- repro_singleseq_base +- _self_ +generic_model_args: + loss_weights: + loss_mask_bce: 100.0 + loss_kl: 0.0 + loss_rgb_mse: 1.0 + loss_eikonal: 0.1 + chunk_size_grid: 65536 + num_passes: 1 + view_pool: false + implicit_function_IdrFeatureField_args: + n_harmonic_functions_xyz: 6 + bias: 0.6 + d_in: 3 + d_out: 1 + dims: + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + geometric_init: true + pooled_feature_dim: 0 + skip_in: + - 6 + weight_norm: true + renderer_SignedDistanceFunctionRenderer_args: + ray_tracer_args: + line_search_step: 0.5 + line_step_iters: 3 + n_secant_steps: 8 + n_steps: 100 + object_bounding_sphere: 8.0 + sdf_threshold: 5.0e-05 + ray_normal_coloring_network_args: + d_in: 9 + d_out: 3 + dims: + - 512 + - 512 + - 512 + - 512 + mode: idr + n_harmonic_functions_dir: 4 + pooled_feature_dim: 0 + weight_norm: true + raysampler_args: + n_rays_per_image_sampled_from_mask: 1024 + n_pts_per_ray_training: 0 + n_pts_per_ray_evaluation: 0 + renderer_class_type: SignedDistanceFunctionRenderer + implicit_function_class_type: IdrFeatureField diff --git a/projects/implicitron_trainer/configs/repro_singleseq_nerf.yaml b/projects/implicitron_trainer/configs/repro_singleseq_nerf.yaml new file mode 100644 index 00000000..d6d45585 --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_singleseq_nerf.yaml @@ -0,0 +1,4 @@ +defaults: +- repro_singleseq_base +- _self_ +exp_dir: ./data/nerf_single_apple/ diff --git a/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml b/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml new file mode 100644 index 00000000..93f3ff5c --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_singleseq_nerf_wce.yaml @@ -0,0 +1,9 @@ +defaults: +- repro_singleseq_wce_base.yaml +- repro_feat_extractor_unnormed.yaml +- _self_ +generic_model_args: + chunk_size_grid: 16000 + view_pool: true + raysampler_args: + n_rays_per_image_sampled_from_mask: 850 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml b/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml new file mode 100644 index 00000000..215a6477 --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_singleseq_nerformer.yaml @@ -0,0 +1,16 @@ +defaults: +- repro_singleseq_wce_base.yaml +- repro_feat_extractor_transformer.yaml +- _self_ +generic_model_args: + chunk_size_grid: 16000 + view_pool: true + implicit_function_class_type: NeRFormerImplicitFunction + raysampler_args: + n_rays_per_image_sampled_from_mask: 800 + n_pts_per_ray_training: 32 + n_pts_per_ray_evaluation: 32 + renderer_MultiPassEmissionAbsorptionRenderer_args: + n_pts_per_ray_fine_training: 16 + n_pts_per_ray_fine_evaluation: 16 + feature_aggregator_class_type: IdentityFeatureAggregator diff --git a/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml b/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml new file mode 100644 index 00000000..98575daf --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_singleseq_srn.yaml @@ -0,0 +1,28 @@ +defaults: +- repro_singleseq_base.yaml +- _self_ +generic_model_args: + num_passes: 1 + chunk_size_grid: 32000 + view_pool: false + loss_weights: + loss_rgb_mse: 200.0 + loss_prev_stage_rgb_mse: 0.0 + loss_mask_bce: 1.0 + loss_prev_stage_mask_bce: 0.0 + loss_autodecoder_norm: 0.0 + depth_neg_penalty: 10000.0 + raysampler_args: + n_rays_per_image_sampled_from_mask: 2048 + min_depth: 0.05 + max_depth: 0.05 + scene_extent: 0.0 + n_pts_per_ray_training: 1 + n_pts_per_ray_evaluation: 1 + stratified_point_sampling_training: false + stratified_point_sampling_evaluation: false + renderer_class_type: LSTMRenderer + implicit_function_class_type: SRNImplicitFunction +solver_args: + breed: adam + lr: 5.0e-05 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_srn_noharm.yaml b/projects/implicitron_trainer/configs/repro_singleseq_srn_noharm.yaml new file mode 100644 index 00000000..dd81241c --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_singleseq_srn_noharm.yaml @@ -0,0 +1,10 @@ +defaults: +- repro_singleseq_srn.yaml +- _self_ +generic_model_args: + num_passes: 1 + implicit_function_SRNImplicitFunction_args: + pixel_generator_args: + n_harmonic_functions: 0 + raymarch_function_args: + n_harmonic_functions: 0 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml b/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml new file mode 100644 index 00000000..57e10183 --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_singleseq_srn_wce.yaml @@ -0,0 +1,29 @@ +defaults: +- repro_singleseq_wce_base +- repro_feat_extractor_normed.yaml +- _self_ +generic_model_args: + num_passes: 1 + chunk_size_grid: 32000 + view_pool: true + loss_weights: + loss_rgb_mse: 200.0 + loss_prev_stage_rgb_mse: 0.0 + loss_mask_bce: 1.0 + loss_prev_stage_mask_bce: 0.0 + loss_autodecoder_norm: 0.0 + depth_neg_penalty: 10000.0 + raysampler_args: + n_rays_per_image_sampled_from_mask: 2048 + min_depth: 0.05 + max_depth: 0.05 + scene_extent: 0.0 + n_pts_per_ray_training: 1 + n_pts_per_ray_evaluation: 1 + stratified_point_sampling_training: false + stratified_point_sampling_evaluation: false + renderer_class_type: LSTMRenderer + implicit_function_class_type: SRNImplicitFunction +solver_args: + breed: adam + lr: 5.0e-05 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_srn_wce_noharm.yaml b/projects/implicitron_trainer/configs/repro_singleseq_srn_wce_noharm.yaml new file mode 100644 index 00000000..2a0c3fd9 --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_singleseq_srn_wce_noharm.yaml @@ -0,0 +1,10 @@ +defaults: +- repro_singleseq_srn_wce.yaml +- _self_ +generic_model_args: + num_passes: 1 + implicit_function_SRNImplicitFunction_args: + pixel_generator_args: + n_harmonic_functions: 0 + raymarch_function_args: + n_harmonic_functions: 0 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml b/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml new file mode 100644 index 00000000..f8ae682a --- /dev/null +++ b/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml @@ -0,0 +1,18 @@ +defaults: +- repro_singleseq_base +- _self_ +dataloader_args: + batch_size: 10 + dataset_len: 1000 + dataset_len_val: 1 + num_workers: 8 + images_per_seq_options: + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 10 diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py new file mode 100755 index 00000000..64090153 --- /dev/null +++ b/projects/implicitron_trainer/experiment.py @@ -0,0 +1,714 @@ +#!/usr/bin/env python +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""" +This file is the entry point for launching experiments with Implicitron. + +Main functions +--------------- +- `run_training` is the wrapper for the train, val, test loops + and checkpointing +- `trainvalidate` is the inner loop which runs the model forward/backward + pass, visualizations and metric printing + +Launch Training +--------------- +Experiment config .yaml files are located in the +`projects/implicitron_trainer/configs` folder. To launch +an experiment, specify the name of the file. Specific config values can +also be overridden from the command line, for example: + +``` +./experiment.py --config-name base_config.yaml override.param.one=42 override.param.two=84 +``` + +To run an experiment on a specific GPU, specify the `gpu_idx` key +in the config file / CLI. To run on a different device, specify the +device in `run_training`. + +Outputs +-------- +The outputs of the experiment are saved and logged in multiple ways: + - Checkpoints: + Model, optimizer and stats are stored in the directory + named by the `exp_dir` key from the config file / CLI parameters. + - Stats + Stats are logged and plotted to the file "train_stats.pdf" in the + same directory. The stats are also saved as part of the checkpoint file. + - Visualizations + Prredictions are plotted to a visdom server running at the + port specified by the `visdom_server` and `visdom_port` keys in the + config file. + +""" + +import copy +import json +import logging +import os +import random +import time +import warnings +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, Tuple + +import hydra +import lpips +import numpy as np +import torch +import tqdm +from omegaconf import DictConfig, OmegaConf +from packaging import version +from pytorch3d.implicitron.dataset import utils as ds_utils +from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo +from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo +from pytorch3d.implicitron.dataset.implicitron_dataset import ( + ImplicitronDataset, + FrameData, +) +from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate +from pytorch3d.implicitron.models.base import EvaluationMode, GenericModel +from pytorch3d.implicitron.tools import model_io, vis_utils +from pytorch3d.implicitron.tools.config import ( + get_default_args_field, + remove_unused_components, +) +from pytorch3d.implicitron.tools.stats import Stats +from pytorch3d.renderer.cameras import CamerasBase + +logger = logging.getLogger(__name__) + +if version.parse(hydra.__version__) < version.Version("1.1"): + raise ValueError( + f"Hydra version {hydra.__version__} is too old." + " (Implicitron requires version 1.1 or later.)" + ) + +try: + # only makes sense in FAIR cluster + import pytorch3d.implicitron.fair_cluster.slurm # noqa: F401 +except ModuleNotFoundError: + pass + + +def init_model( + cfg: DictConfig, + force_load: bool = False, + clear_stats: bool = False, + load_model_only: bool = False, +) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]: + """ + Returns an instance of `GenericModel`. + + If `cfg.resume` is set or `force_load` is true, + attempts to load the last checkpoint from `cfg.exp_dir`. Failure to do so + will return the model with initial weights, unless `force_load` is passed, + in which case a FileNotFoundError is raised. + + Args: + force_load: If true, force load model from checkpoint even if + cfg.resume is false. + clear_stats: If true, clear the stats object loaded from checkpoint + load_model_only: If true, load only the model weights from checkpoint + and do not load the state of the optimizer and stats. + + Returns: + model: The model with optionally loaded weights from checkpoint + stats: The stats structure (optionally loaded from checkpoint) + optimizer_state: The optimizer state dict containing + `state` and `param_groups` keys (optionally loaded from checkpoint) + + Raise: + FileNotFoundError if `force_load` is passed but checkpoint is not found. + """ + + # Initialize the model + if cfg.architecture == "generic": + model = GenericModel(**cfg.generic_model_args) + else: + raise ValueError(f"No such arch {cfg.architecture}.") + + # Determine the network outputs that should be logged + if hasattr(model, "log_vars"): + log_vars = copy.deepcopy(list(model.log_vars)) + else: + log_vars = ["objective"] + + visdom_env_charts = vis_utils.get_visdom_env(cfg) + "_charts" + + # Init the stats struct + stats = Stats( + log_vars, + visdom_env=visdom_env_charts, + verbose=False, + visdom_server=cfg.visdom_server, + visdom_port=cfg.visdom_port, + ) + + # Retrieve the last checkpoint + if cfg.resume_epoch > 0: + model_path = model_io.get_checkpoint(cfg.exp_dir, cfg.resume_epoch) + else: + model_path = model_io.find_last_checkpoint(cfg.exp_dir) + + optimizer_state = None + if model_path is not None: + logger.info("found previous model %s" % model_path) + if force_load or cfg.resume: + logger.info(" -> resuming") + if load_model_only: + model_state_dict = torch.load(model_io.get_model_path(model_path)) + stats_load, optimizer_state = None, None + else: + model_state_dict, stats_load, optimizer_state = model_io.load_model( + model_path + ) + + # Determine if stats should be reset + if not clear_stats: + if stats_load is None: + logger.info("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n") + last_epoch = model_io.parse_epoch_from_model_path(model_path) + logger.info(f"Estimated resume epoch = {last_epoch}") + + # Reset the stats struct + for _ in range(last_epoch + 1): + stats.new_epoch() + assert last_epoch == stats.epoch + else: + stats = stats_load + + # Update stats properties incase it was reset on load + stats.visdom_env = visdom_env_charts + stats.visdom_server = cfg.visdom_server + stats.visdom_port = cfg.visdom_port + stats.plot_file = os.path.join(cfg.exp_dir, "train_stats.pdf") + stats.synchronize_logged_vars(log_vars) + else: + logger.info(" -> clearing stats") + + try: + # TODO: fix on creation of the buffers + # after the hack above, this will not pass in most cases + # ... but this is fine for now + model.load_state_dict(model_state_dict, strict=True) + except RuntimeError as e: + logger.error(e) + logger.info("Cant load state dict in strict mode! -> trying non-strict") + model.load_state_dict(model_state_dict, strict=False) + model.log_vars = log_vars + else: + logger.info(" -> but not resuming -> starting from scratch") + elif force_load: + raise FileNotFoundError(f"Cannot find a checkpoint in {cfg.exp_dir}!") + + return model, stats, optimizer_state + + +def init_optimizer( + model: GenericModel, + optimizer_state: Optional[Dict[str, Any]], + last_epoch: int, + breed: bool = "adam", + weight_decay: float = 0.0, + lr_policy: str = "multistep", + lr: float = 0.0005, + gamma: float = 0.1, + momentum: float = 0.9, + betas: Tuple[float] = (0.9, 0.999), + milestones: tuple = (), + max_epochs: int = 1000, +): + """ + Initialize the optimizer (optionally from checkpoint state) + and the learning rate scheduler. + + Args: + model: The model with optionally loaded weights + optimizer_state: The state dict for the optimizer. If None + it has not been loaded from checkpoint + last_epoch: If the model was loaded from checkpoint this will be the + number of the last epoch that was saved + breed: The type of optimizer to use e.g. adam + weight_decay: The optimizer weight_decay (L2 penalty on model weights) + lr_policy: The policy to use for learning rate. Currently, only "multistep: + is supported. + lr: The value for the initial learning rate + gamma: Multiplicative factor of learning rate decay + momentum: Momentum factor for SGD optimizer + betas: Coefficients used for computing running averages of gradient and its square + in the Adam optimizer + milestones: List of increasing epoch indices at which the learning rate is + modified + max_epochs: The maximum number of epochs to run the optimizer for + + Returns: + optimizer: Optimizer module, optionally loaded from checkpoint + scheduler: Learning rate scheduler module + + Raise: + ValueError if `breed` or `lr_policy` are not supported. + """ + + # Get the parameters to optimize + if hasattr(model, "_get_param_groups"): # use the model function + p_groups = model._get_param_groups(lr, wd=weight_decay) + else: + allprm = [prm for prm in model.parameters() if prm.requires_grad] + p_groups = [{"params": allprm, "lr": lr}] + + # Intialize the optimizer + if breed == "sgd": + optimizer = torch.optim.SGD( + p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay + ) + elif breed == "adagrad": + optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay) + elif breed == "adam": + optimizer = torch.optim.Adam( + p_groups, lr=lr, betas=betas, weight_decay=weight_decay + ) + else: + raise ValueError("no such solver type %s" % breed) + logger.info(" -> solver type = %s" % breed) + + # Load state from checkpoint + if optimizer_state is not None: + logger.info(" -> setting loaded optimizer state") + optimizer.load_state_dict(optimizer_state) + + # Initialize the learning rate scheduler + if lr_policy == "multistep": + scheduler = torch.optim.lr_scheduler.MultiStepLR( + optimizer, + milestones=milestones, + gamma=gamma, + ) + else: + raise ValueError("no such lr policy %s" % lr_policy) + + # When loading from checkpoint, this will make sure that the + # lr is correctly set even after returning + for _ in range(last_epoch): + scheduler.step() + + # Add the max epochs here + scheduler.max_epochs = max_epochs + + optimizer.zero_grad() + return optimizer, scheduler + + +def trainvalidate( + model, + stats, + epoch, + loader, + optimizer, + validation, + bp_var: str = "objective", + metric_print_interval: int = 5, + visualize_interval: int = 100, + visdom_env_root: str = "trainvalidate", + clip_grad: float = 0.0, + device: str = "cuda:0", + **kwargs, +) -> None: + """ + This is the main loop for training and evaluation including: + model forward pass, loss computation, backward pass and visualization. + + Args: + model: The model module optionally loaded from checkpoint + stats: The stats struct, also optionally loaded from checkpoint + epoch: The index of the current epoch + loader: The dataloader to use for the loop + optimizer: The optimizer module optionally loaded from checkpoint + validation: If true, run the loop with the model in eval mode + and skip the backward pass + bp_var: The name of the key in the model output `preds` dict which + should be used as the loss for the backward pass. + metric_print_interval: The batch interval at which the stats should be + logged. + visualize_interval: The batch interval at which the visualizations + should be plotted + visdom_env_root: The name of the visdom environment to use for plotting + clip_grad: Optionally clip the gradient norms. + If set to a value <=0.0, no clipping + device: The device on which to run the model. + + Returns: + None + """ + + if validation: + model.eval() + trainmode = "val" + else: + model.train() + trainmode = "train" + + t_start = time.time() + + # get the visdom env name + visdom_env_imgs = visdom_env_root + "_images_" + trainmode + viz = vis_utils.get_visdom_connection( + server=stats.visdom_server, + port=stats.visdom_port, + ) + + # Iterate through the batches + n_batches = len(loader) + for it, batch in enumerate(loader): + last_iter = it == n_batches - 1 + + # move to gpu where possible (in place) + net_input = batch.to(device) + + # run the forward pass + if not validation: + optimizer.zero_grad() + preds = model(**{**net_input, "evaluation_mode": EvaluationMode.TRAINING}) + else: + with torch.no_grad(): + preds = model( + **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION} + ) + + # make sure we dont overwrite something + assert all(k not in preds for k in net_input.keys()) + # merge everything into one big dict + preds.update(net_input) + + # update the stats logger + stats.update(preds, time_start=t_start, stat_set=trainmode) + assert stats.it[trainmode] == it, "inconsistent stat iteration number!" + + # print textual status update + if it % metric_print_interval == 0 or last_iter: + stats.print(stat_set=trainmode, max_it=n_batches) + + # visualize results + if visualize_interval > 0 and it % visualize_interval == 0: + prefix = f"e{stats.epoch}_it{stats.it[trainmode]}" + + model.visualize( + viz, + visdom_env_imgs, + preds, + prefix, + ) + + # optimizer step + if not validation: + loss = preds[bp_var] + assert torch.isfinite(loss).all(), "Non-finite loss!" + # backprop + loss.backward() + if clip_grad > 0.0: + # Optionally clip the gradient norms. + total_norm = torch.nn.utils.clip_grad_norm( + model.parameters(), clip_grad + ) + if total_norm > clip_grad: + logger.info( + f"Clipping gradient: {total_norm}" + + f" with coef {clip_grad / total_norm}." + ) + + optimizer.step() + + +def run_training(cfg: DictConfig, device: str = "cpu"): + """ + Entry point to run the training and validation loops + based on the specified config file. + """ + + # set the debug mode + if cfg.detect_anomaly: + logger.info("Anomaly detection!") + torch.autograd.set_detect_anomaly(cfg.detect_anomaly) + + # create the output folder + os.makedirs(cfg.exp_dir, exist_ok=True) + _seed_all_random_engines(cfg.seed) + remove_unused_components(cfg) + + # dump the exp config to the exp dir + try: + cfg_filename = os.path.join(cfg.exp_dir, "expconfig.yaml") + OmegaConf.save(config=cfg, f=cfg_filename) + except PermissionError: + warnings.warn("Cant dump config due to insufficient permissions!") + + # setup datasets + datasets = dataset_zoo(**cfg.dataset_args) + cfg.dataloader_args["dataset_name"] = cfg.dataset_args["dataset_name"] + dataloaders = dataloader_zoo(datasets, **cfg.dataloader_args) + + # init the model + model, stats, optimizer_state = init_model(cfg) + start_epoch = stats.epoch + 1 + + # move model to gpu + model.to(device) + + # only run evaluation on the test dataloader + if cfg.eval_only: + _eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device) + return + + # init the optimizer + optimizer, scheduler = init_optimizer( + model, + optimizer_state=optimizer_state, + last_epoch=start_epoch, + **cfg.solver_args, + ) + + # check the scheduler and stats have been initialized correctly + assert scheduler.last_epoch == stats.epoch + 1 + assert scheduler.last_epoch == start_epoch + + past_scheduler_lrs = [] + # loop through epochs + for epoch in range(start_epoch, cfg.solver_args.max_epochs): + # automatic new_epoch and plotting of stats at every epoch start + with stats: + + # Make sure to re-seed random generators to ensure reproducibility + # even after restart. + _seed_all_random_engines(cfg.seed + epoch) + + cur_lr = float(scheduler.get_last_lr()[-1]) + logger.info(f"scheduler lr = {cur_lr:1.2e}") + past_scheduler_lrs.append(cur_lr) + + # train loop + trainvalidate( + model, + stats, + epoch, + dataloaders["train"], + optimizer, + False, + visdom_env_root=vis_utils.get_visdom_env(cfg), + device=device, + **cfg, + ) + + # val loop (optional) + if "val" in dataloaders and epoch % cfg.validation_interval == 0: + trainvalidate( + model, + stats, + epoch, + dataloaders["val"], + optimizer, + True, + visdom_env_root=vis_utils.get_visdom_env(cfg), + device=device, + **cfg, + ) + + # eval loop (optional) + if ( + "test" in dataloaders + and cfg.test_interval > 0 + and epoch % cfg.test_interval == 0 + ): + run_eval(cfg, model, stats, dataloaders["test"], device=device) + + assert stats.epoch == epoch, "inconsistent stats!" + + # delete previous models if required + # save model + if cfg.store_checkpoints: + if cfg.store_checkpoints_purge > 0: + for prev_epoch in range(epoch - cfg.store_checkpoints_purge): + model_io.purge_epoch(cfg.exp_dir, prev_epoch) + outfile = model_io.get_checkpoint(cfg.exp_dir, epoch) + model_io.safe_save_model(model, stats, outfile, optimizer=optimizer) + + scheduler.step() + + new_lr = float(scheduler.get_last_lr()[-1]) + if new_lr != cur_lr: + logger.info(f"LR change! {cur_lr} -> {new_lr}") + + if cfg.test_when_finished: + _eval_and_dump(cfg, datasets, dataloaders, model, stats, device=device) + + +def _eval_and_dump(cfg, datasets, dataloaders, model, stats, device): + """ + Run the evaluation loop with the test data loader and + save the predictions to the `exp_dir`. + """ + + if "test" not in dataloaders: + raise ValueError('Dataloaders have to contain the "test" entry for eval!') + + eval_task = cfg.dataset_args["dataset_name"].split("_")[-1] + all_source_cameras = ( + _get_all_source_cameras(datasets["train"]) + if eval_task == "singlesequence" + else None + ) + results = run_eval( + cfg, model, all_source_cameras, dataloaders["test"], eval_task, device=device + ) + + # add the evaluation epoch to the results + for r in results: + r["eval_epoch"] = int(stats.epoch) + + logger.info("Evaluation results") + evaluate.pretty_print_nvs_metrics(results) + + with open(os.path.join(cfg.exp_dir, "results_test.json"), "w") as f: + json.dump(results, f) + + +def _get_eval_frame_data(frame_data): + """ + Masks the unknown image data to make sure we cannot use it at model evaluation time. + """ + frame_data_for_eval = copy.deepcopy(frame_data) + is_known = ds_utils.is_known_frame(frame_data.frame_type).type_as( + frame_data.image_rgb + )[:, None, None, None] + for k in ("image_rgb", "depth_map", "fg_probability", "mask_crop"): + value_masked = getattr(frame_data_for_eval, k).clone() * is_known + setattr(frame_data_for_eval, k, value_masked) + return frame_data_for_eval + + +def run_eval(cfg, model, all_source_cameras, loader, task, device): + """ + Run the evaluation loop on the test dataloader + """ + lpips_model = lpips.LPIPS(net="vgg") + lpips_model = lpips_model.to(device) + + model.eval() + + per_batch_eval_results = [] + logger.info("Evaluating model ...") + for frame_data in tqdm.tqdm(loader): + frame_data = frame_data.to(device) + + # mask out the unknown images so that the model does not see them + frame_data_for_eval = _get_eval_frame_data(frame_data) + + with torch.no_grad(): + preds = model( + **{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION} + ) + nvs_prediction = copy.deepcopy(preds["nvs_prediction"]) + per_batch_eval_results.append( + evaluate.eval_batch( + frame_data, + nvs_prediction, + bg_color="black", + lpips_model=lpips_model, + source_cameras=all_source_cameras, + ) + ) + + _, category_result = evaluate.summarize_nvs_eval_results( + per_batch_eval_results, task + ) + + return category_result["results"] + + +def _get_all_source_cameras( + dataset: ImplicitronDataset, + num_workers: int = 8, +) -> CamerasBase: + """ + Load and return all the source cameras in the training dataset + """ + + all_frame_data = next( + iter( + torch.utils.data.DataLoader( + dataset, + shuffle=False, + batch_size=len(dataset), + num_workers=num_workers, + collate_fn=FrameData.collate, + ) + ) + ) + + is_source = ds_utils.is_known_frame(all_frame_data.frame_type) + source_cameras = all_frame_data.camera[torch.where(is_source)[0]] + return source_cameras + + +def _seed_all_random_engines(seed: int): + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) + + +@dataclass(eq=False) +class ExperimentConfig: + generic_model_args: DictConfig = get_default_args_field(GenericModel) + solver_args: DictConfig = get_default_args_field(init_optimizer) + dataset_args: DictConfig = get_default_args_field(dataset_zoo) + dataloader_args: DictConfig = get_default_args_field(dataloader_zoo) + architecture: str = "generic" + detect_anomaly: bool = False + eval_only: bool = False + exp_dir: str = "./data/default_experiment/" + exp_idx: int = 0 + gpu_idx: int = 0 + metric_print_interval: int = 5 + resume: bool = True + resume_epoch: int = -1 + seed: int = 0 + store_checkpoints: bool = True + store_checkpoints_purge: int = 1 + test_interval: int = -1 + test_when_finished: bool = False + validation_interval: int = 1 + visdom_env: str = "" + visdom_port: int = 8097 + visdom_server: str = "http://127.0.0.1" + visualize_interval: int = 1000 + clip_grad: float = 0.0 + + hydra: dict = field( + default_factory=lambda: { + "run": {"dir": "."}, # Make hydra not change the working dir. + "output_subdir": None, # disable storing the .hydra logs + } + ) + + +cs = hydra.core.config_store.ConfigStore.instance() +cs.store(name="default_config", node=ExperimentConfig) + + +@hydra.main(config_path="./configs/", config_name="default_config") +def experiment(cfg: DictConfig) -> None: + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx) + # Set the device + device = "cpu" + if torch.cuda.is_available() and cfg.gpu_idx < torch.cuda.device_count(): + device = f"cuda:{cfg.gpu_idx}" + logger.info(f"Running experiment on device: {device}") + run_training(cfg, device) + + +if __name__ == "__main__": + experiment() diff --git a/projects/implicitron_trainer/visualize_reconstruction.py b/projects/implicitron_trainer/visualize_reconstruction.py new file mode 100644 index 00000000..28fa9727 --- /dev/null +++ b/projects/implicitron_trainer/visualize_reconstruction.py @@ -0,0 +1,382 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Script to visualize a previously trained model. Example call: + + projects/implicitron_trainer/visualize_reconstruction.py + exp_dir='./exps/checkpoint_dir' visdom_show_preds=True visdom_port=8097 + n_eval_cameras=40 render_size="[64,64]" video_size="[256,256]" +""" + +import math +import os +import random +import sys +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as Fu +from experiment import init_model +from omegaconf import OmegaConf +from pytorch3d.implicitron.dataset.dataset_zoo import dataset_zoo +from pytorch3d.implicitron.dataset.implicitron_dataset import ( + FrameData, + ImplicitronDataset, +) +from pytorch3d.implicitron.dataset.utils import is_train_frame +from pytorch3d.implicitron.models.base import EvaluationMode +from pytorch3d.implicitron.tools.configurable import get_default_args +from pytorch3d.implicitron.tools.eval_video_trajectory import ( + generate_eval_video_cameras, +) +from pytorch3d.implicitron.tools.video_writer import VideoWriter +from pytorch3d.implicitron.tools.vis_utils import ( + get_visdom_connection, + make_depth_image, +) +from tqdm import tqdm + + +def render_sequence( + dataset: ImplicitronDataset, + sequence_name: str, + model: torch.nn.Module, + video_path, + n_eval_cameras=40, + fps=20, + max_angle=2 * math.pi, + trajectory_type="circular_lsq_fit", + trajectory_scale=1.1, + scene_center=(0.0, 0.0, 0.0), + up=(0.0, -1.0, 0.0), + traj_offset=0.0, + n_source_views=9, + viz_env="debug", + visdom_show_preds=False, + visdom_server="http://127.0.0.1", + visdom_port=8097, + num_workers=10, + seed=None, + video_resize=None, +): + if seed is None: + seed = hash(sequence_name) + print(f"Loading all data of sequence '{sequence_name}'.") + seq_idx = dataset.seq_to_idx[sequence_name] + train_data = _load_whole_dataset(dataset, seq_idx, num_workers=num_workers) + assert all(train_data.sequence_name[0] == sn for sn in train_data.sequence_name) + sequence_set_name = "train" if is_train_frame(train_data.frame_type)[0] else "test" + print(f"Sequence set = {sequence_set_name}.") + train_cameras = train_data.camera + time = torch.linspace(0, max_angle, n_eval_cameras + 1)[:n_eval_cameras] + test_cameras = generate_eval_video_cameras( + train_cameras, + time=time, + n_eval_cams=n_eval_cameras, + trajectory_type=trajectory_type, + trajectory_scale=trajectory_scale, + scene_center=scene_center, + up=up, + focal_length=None, + principal_point=torch.zeros(n_eval_cameras, 2), + traj_offset_canonical=[0.0, 0.0, traj_offset], + ) + + # sample the source views reproducibly + with torch.random.fork_rng(): + torch.manual_seed(seed) + source_views_i = torch.randperm(len(seq_idx))[:n_source_views] + # add the first dummy view that will get replaced with the target camera + source_views_i = Fu.pad(source_views_i, [1, 0]) + source_views = [seq_idx[i] for i in source_views_i.tolist()] + batch = _load_whole_dataset(dataset, source_views, num_workers=num_workers) + assert all(batch.sequence_name[0] == sn for sn in batch.sequence_name) + + preds_total = [] + for n in tqdm(range(n_eval_cameras), total=n_eval_cameras): + # set the first batch camera to the target camera + for k in ("R", "T", "focal_length", "principal_point"): + getattr(batch.camera, k)[0] = getattr(test_cameras[n], k) + + # Move to cuda + net_input = batch.cuda() + with torch.no_grad(): + preds = model(**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}) + + # make sure we dont overwrite something + assert all(k not in preds for k in net_input.keys()) + preds.update(net_input) # merge everything into one big dict + + # Render the predictions to images + rendered_pred = images_from_preds(preds) + preds_total.append(rendered_pred) + + # show the preds every 5% of the export iterations + if visdom_show_preds and ( + n % max(n_eval_cameras // 20, 1) == 0 or n == n_eval_cameras - 1 + ): + viz = get_visdom_connection(server=visdom_server, port=visdom_port) + show_predictions( + preds_total, + sequence_name=batch.sequence_name[0], + viz=viz, + viz_env=viz_env, + ) + + print(f"Exporting videos for sequence {sequence_name} ...") + generate_prediction_videos( + preds_total, + sequence_name=batch.sequence_name[0], + viz=viz, + viz_env=viz_env, + fps=fps, + video_path=video_path, + resize=video_resize, + ) + + +def _load_whole_dataset(dataset, idx, num_workers=10): + load_all_dataloader = torch.utils.data.DataLoader( + torch.utils.data.Subset(dataset, idx), + batch_size=len(idx), + num_workers=num_workers, + shuffle=False, + collate_fn=FrameData.collate, + ) + return next(iter(load_all_dataloader)) + + +def images_from_preds(preds): + imout = {} + for k in ( + "image_rgb", + "images_render", + "fg_probability", + "masks_render", + "depths_render", + "depth_map", + "_all_source_images", + ): + if k == "_all_source_images" and "image_rgb" in preds: + src_ims = preds["image_rgb"][1:].cpu().detach().clone() + v = _stack_images(src_ims, None)[None] + else: + if k not in preds or preds[k] is None: + print(f"cant show {k}") + continue + v = preds[k].cpu().detach().clone() + if k.startswith("depth"): + mask_resize = Fu.interpolate( + preds["masks_render"], + size=preds[k].shape[2:], + mode="nearest", + ) + v = make_depth_image(preds[k], mask_resize) + if v.shape[1] == 1: + v = v.repeat(1, 3, 1, 1) + imout[k] = v.detach().cpu() + + return imout + + +def _stack_images(ims, size): + ba = ims.shape[0] + H = int(np.ceil(np.sqrt(ba))) + W = H + n_add = H * W - ba + if n_add > 0: + ims = torch.cat((ims, torch.zeros_like(ims[:1]).repeat(n_add, 1, 1, 1))) + + ims = ims.view(H, W, *ims.shape[1:]) + cated = torch.cat([torch.cat(list(row), dim=2) for row in ims], dim=1) + if size is not None: + cated = Fu.interpolate(cated[None], size=size, mode="bilinear")[0] + return cated.clamp(0.0, 1.0) + + +def show_predictions( + preds, + sequence_name, + viz, + viz_env="visualizer", + predicted_keys=( + "images_render", + "masks_render", + "depths_render", + "_all_source_images", + ), + n_samples=10, + one_image_width=200, +): + """Given a list of predictions visualize them into a single image using visdom.""" + assert isinstance(preds, list) + + pred_all = [] + # Randomly choose a subset of the rendered images, sort by ordr in the sequence + n_samples = min(n_samples, len(preds)) + pred_idx = sorted(random.sample(list(range(len(preds))), n_samples)) + for predi in pred_idx: + # Make the concatentation for the same camera vertically + pred_all.append( + torch.cat( + [ + torch.nn.functional.interpolate( + preds[predi][k].cpu(), + scale_factor=one_image_width / preds[predi][k].shape[3], + mode="bilinear", + ).clamp(0.0, 1.0) + for k in predicted_keys + ], + dim=2, + ) + ) + # Concatenate the images horizontally + pred_all_cat = torch.cat(pred_all, dim=3)[0] + viz.image( + pred_all_cat, + win="show_predictions", + env=viz_env, + opts={"title": f"pred_{sequence_name}"}, + ) + + +def generate_prediction_videos( + preds, + sequence_name, + viz, + viz_env="visualizer", + predicted_keys=( + "images_render", + "masks_render", + "depths_render", + "_all_source_images", + ), + fps=20, + video_path="/tmp/video", + resize=None, +): + """Given a list of predictions create and visualize rotating videos of the + objects using visdom. + """ + assert isinstance(preds, list) + + # make sure the target video directory exists + os.makedirs(os.path.dirname(video_path), exist_ok=True) + + # init a video writer for each predicted key + vws = {} + for k in predicted_keys: + vws[k] = VideoWriter(out_path=f"{video_path}_{sequence_name}_{k}.mp4", fps=fps) + + for rendered_pred in tqdm(preds): + for k in predicted_keys: + vws[k].write_frame( + rendered_pred[k][0].detach().cpu().numpy(), + resize=resize, + ) + + for k in predicted_keys: + vws[k].get_video(quiet=True) + print(f"Generated {vws[k].out_path}.") + viz.video( + videofile=vws[k].out_path, + env=viz_env, + win=k, # we reuse the same window otherwise visdom dies + opts={"title": sequence_name + " " + k}, + ) + + +def export_scenes( + exp_dir: str = "", + restrict_sequence_name: Optional[str] = None, + output_directory: Optional[str] = None, + render_size: Tuple[int, int] = (512, 512), + video_size: Optional[Tuple[int, int]] = None, + split: str = "train", # train | test + n_source_views: int = 9, + n_eval_cameras: int = 40, + visdom_server="http://127.0.0.1", + visdom_port=8097, + visdom_show_preds: bool = False, + visdom_env: Optional[str] = None, + gpu_idx: int = 0, +): + # In case an output directory is specified use it. If no output_directory + # is specified create a vis folder inside the experiment directory + if output_directory is None: + output_directory = os.path.join(exp_dir, "vis") + else: + output_directory = output_directory + if not os.path.exists(output_directory): + os.makedirs(output_directory) + + # Set the random seeds + torch.manual_seed(0) + np.random.seed(0) + + # Get the config from the experiment_directory, + # and overwrite relevant fields + config = _get_config_from_experiment_directory(exp_dir) + config.gpu_idx = gpu_idx + config.exp_dir = exp_dir + # important so that the CO3D dataset gets loaded in full + config.dataset_args.test_on_train = False + # Set the rendering image size + config.generic_model_args.render_image_width = render_size[0] + config.generic_model_args.render_image_height = render_size[1] + if restrict_sequence_name is not None: + config.dataset_args.restrict_sequence_name = restrict_sequence_name + + # Set up the CUDA env for the visualization + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = str(config.gpu_idx) + + # Load the previously trained model + model, _, _ = init_model(config, force_load=True, load_model_only=True) + model.cuda() + model.eval() + + # Setup the dataset + dataset = dataset_zoo(**config.dataset_args)[split] + + # iterate over the sequences in the dataset + for sequence_name in dataset.seq_to_idx.keys(): + with torch.no_grad(): + render_sequence( + dataset, + sequence_name, + model, + video_path="{}/video".format(output_directory), + n_source_views=n_source_views, + visdom_show_preds=visdom_show_preds, + n_eval_cameras=n_eval_cameras, + visdom_server=visdom_server, + visdom_port=visdom_port, + viz_env=f"visualizer_{config.visdom_env}" + if visdom_env is None + else visdom_env, + video_resize=video_size, + ) + + +def _get_config_from_experiment_directory(experiment_directory): + cfg_file = os.path.join(experiment_directory, "expconfig.yaml") + config = OmegaConf.load(cfg_file) + return config + + +def main(argv): + # automatically parses arguments of export_scenes + cfg = OmegaConf.create(get_default_args(export_scenes)) + cfg.update(OmegaConf.from_cli()) + with torch.no_grad(): + export_scenes(**cfg) + + +if __name__ == "__main__": + main(sys.argv) diff --git a/pytorch3d/implicitron/dataset/dataloader_zoo.py b/pytorch3d/implicitron/dataset/dataloader_zoo.py new file mode 100644 index 00000000..a49a815b --- /dev/null +++ b/pytorch3d/implicitron/dataset/dataloader_zoo.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, Sequence + +import torch + +from .implicitron_dataset import FrameData, ImplicitronDatasetBase +from .scene_batch_sampler import SceneBatchSampler + + +def dataloader_zoo( + datasets: Dict[str, ImplicitronDatasetBase], + dataset_name: str = "co3d_singlesequence", + batch_size: int = 1, + num_workers: int = 0, + dataset_len: int = 1000, + dataset_len_val: int = 1, + images_per_seq_options: Sequence[int] = (2,), + sample_consecutive_frames: bool = False, + consecutive_frames_max_gap: int = 0, + consecutive_frames_max_gap_seconds: float = 0.1, +) -> Dict[str, torch.utils.data.DataLoader]: + """ + Returns a set of dataloaders for a given set of datasets. + + Args: + datasets: A dictionary containing the + `"dataset_subset_name": torch_dataset_object` key, value pairs. + dataset_name: The name of the returned dataset. + batch_size: The size of the batch of the dataloader. + num_workers: Number data-loading threads. + dataset_len: The number of batches in a training epoch. + dataset_len_val: The number of batches in a validation epoch. + images_per_seq_options: Possible numbers of images sampled per sequence. + sample_consecutive_frames: if True, will sample a contiguous interval of frames + in the sequence. It first sorts the frames by timestimps when available, + otherwise by frame numbers, finds the connected segments within the sequence + of sufficient length, then samples a random pivot element among them and + ideally uses it as a middle of the temporal window, shifting the borders + where necessary. This strategy mitigates the bias against shorter segments + and their boundaries. + consecutive_frames_max_gap: if a number > 0, then used to define the maximum + difference in frame_number of neighbouring frames when forming connected + segments; if both this and consecutive_frames_max_gap_seconds are 0s, + the whole sequence is considered a segment regardless of frame numbers. + consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the + maximum difference in frame_timestamp of neighbouring frames when forming + connected segments; if both this and consecutive_frames_max_gap are 0s, + the whole sequence is considered a segment regardless of frame timestamps. + + Returns: + dataloaders: A dictionary containing the + `"dataset_subset_name": torch_dataloader_object` key, value pairs. + """ + + if dataset_name not in ["co3d_singlesequence", "co3d_multisequence"]: + raise ValueError(f"Unsupported dataset: {dataset_name}") + + dataloaders = {} + + if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]: + for dataset_set, dataset in datasets.items(): + num_samples = { + "train": dataset_len, + "val": dataset_len_val, + "test": None, + }[dataset_set] + + if dataset_set == "test": + batch_sampler = dataset.get_eval_batches() + else: + assert num_samples is not None + num_samples = len(dataset) if num_samples <= 0 else num_samples + batch_sampler = SceneBatchSampler( + dataset, + batch_size, + num_batches=num_samples, + images_per_seq_options=images_per_seq_options, + sample_consecutive_frames=sample_consecutive_frames, + consecutive_frames_max_gap=consecutive_frames_max_gap, + ) + + dataloaders[dataset_set] = torch.utils.data.DataLoader( + dataset, + num_workers=num_workers, + batch_sampler=batch_sampler, + collate_fn=FrameData.collate, + ) + + else: + raise ValueError(f"Unsupported dataset: {dataset_name}") + + return dataloaders diff --git a/pytorch3d/implicitron/dataset/dataset_zoo.py b/pytorch3d/implicitron/dataset/dataset_zoo.py new file mode 100644 index 00000000..cf96cf6a --- /dev/null +++ b/pytorch3d/implicitron/dataset/dataset_zoo.py @@ -0,0 +1,260 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +import json +import os +from typing import Any, Dict, List, Optional, Sequence + +from iopath.common.file_io import PathManager + +from .implicitron_dataset import ImplicitronDataset, ImplicitronDatasetBase +from .utils import ( + DATASET_TYPE_KNOWN, + DATASET_TYPE_TEST, + DATASET_TYPE_TRAIN, + DATASET_TYPE_UNKNOWN, +) + + +# TODO from dataset.dataset_configs import DATASET_CONFIGS +DATASET_CONFIGS: Dict[str, Dict[str, Any]] = { + "default": { + "box_crop": True, + "box_crop_context": 0.3, + "image_width": 800, + "image_height": 800, + "remove_empty_masks": True, + } +} + +# fmt: off +CO3D_CATEGORIES: List[str] = list(reversed([ + "baseballbat", "banana", "bicycle", "microwave", "tv", + "cellphone", "toilet", "hairdryer", "couch", "kite", "pizza", + "umbrella", "wineglass", "laptop", + "hotdog", "stopsign", "frisbee", "baseballglove", + "cup", "parkingmeter", "backpack", "toyplane", "toybus", + "handbag", "chair", "keyboard", "car", "motorcycle", + "carrot", "bottle", "sandwich", "remote", "bowl", "skateboard", + "toaster", "mouse", "toytrain", "book", "toytruck", + "orange", "broccoli", "plant", "teddybear", + "suitcase", "bench", "ball", "cake", + "vase", "hydrant", "apple", "donut", +])) +# fmt: on + +_CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "") + + +def dataset_zoo( + dataset_name: str = "co3d_singlesequence", + dataset_root: str = _CO3D_DATASET_ROOT, + category: str = "DEFAULT", + limit_to: int = -1, + limit_sequences_to: int = -1, + n_frames_per_sequence: int = -1, + test_on_train: bool = False, + load_point_clouds: bool = False, + mask_images: bool = False, + mask_depths: bool = False, + restrict_sequence_name: Sequence[str] = (), + test_restrict_sequence_id: int = -1, + assert_single_seq: bool = False, + only_test_set: bool = False, + aux_dataset_kwargs: dict = DATASET_CONFIGS["default"], + path_manager: Optional[PathManager] = None, +) -> Dict[str, ImplicitronDatasetBase]: + """ + Generates the training / validation and testing dataset objects. + + Args: + dataset_name: The name of the returned dataset. + dataset_root: The root folder of the dataset. + category: The object category of the dataset. + limit_to: Limit the dataset to the first #limit_to frames. + limit_sequences_to: Limit the dataset to the first + #limit_sequences_to sequences. + n_frames_per_sequence: Randomly sample #n_frames_per_sequence frames + in each sequence. + test_on_train: Construct validation and test datasets from + the training subset. + load_point_clouds: Enable returning scene point clouds from the dataset. + mask_images: Mask the loaded images with segmentation masks. + mask_depths: Mask the loaded depths with segmentation masks. + restrict_sequence_name: Restrict the dataset sequences to the ones + present in the given list of names. + test_restrict_sequence_id: The ID of the loaded sequence. + Active for dataset_name='co3d_singlesequence'. + assert_single_seq: Assert that only frames from a single sequence + are present in all generated datasets. + only_test_set: Load only the test set. + aux_dataset_kwargs: Specifies additional arguments to the + ImplicitronDataset constructor call. + + Returns: + datasets: A dictionary containing the + `"dataset_subset_name": torch_dataset_object` key, value pairs. + """ + + datasets = {} + + # TODO: + # - implement loading multiple categories + + if dataset_name in ["co3d_singlesequence", "co3d_multisequence"]: + # This maps the common names of the dataset subsets ("train"/"val"/"test") + # to the names of the subsets in the CO3D dataset. + set_names_mapping = _get_co3d_set_names_mapping( + dataset_name, + test_on_train, + only_test_set, + ) + + # load the evaluation batches + task = dataset_name.split("_")[-1] + batch_indices_path = os.path.join( + dataset_root, + category, + f"eval_batches_{task}.json", + ) + if not os.path.isfile(batch_indices_path): + # The batch indices file does not exist. + # Most probably the user has not specified the root folder. + raise ValueError("Please specify a correct dataset_root folder.") + + with open(batch_indices_path, "r") as f: + eval_batch_index = json.load(f) + + if task == "singlesequence": + assert ( + test_restrict_sequence_id is not None and test_restrict_sequence_id >= 0 + ), ( + "Please specify an integer id 'test_restrict_sequence_id'" + + " of the sequence considered for 'singlesequence'" + + " training and evaluation." + ) + assert len(restrict_sequence_name) == 0, ( + "For the 'singlesequence' task, the restrict_sequence_name has" + " to be unset while test_restrict_sequence_id has to be set to an" + " integer defining the order of the evaluation sequence." + ) + # a sort-stable set() equivalent: + eval_batches_sequence_names = list( + {b[0][0]: None for b in eval_batch_index}.keys() + ) + eval_sequence_name = eval_batches_sequence_names[test_restrict_sequence_id] + eval_batch_index = [ + b for b in eval_batch_index if b[0][0] == eval_sequence_name + ] + # overwrite the restrict_sequence_name + restrict_sequence_name = [eval_sequence_name] + + for dataset, subsets in set_names_mapping.items(): + frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz") + assert os.path.isfile(frame_file) + + sequence_file = os.path.join( + dataset_root, category, "sequence_annotations.jgz" + ) + assert os.path.isfile(sequence_file) + + subset_lists_file = os.path.join(dataset_root, category, "set_lists.json") + assert os.path.isfile(subset_lists_file) + + # TODO: maybe directly in param list + params = { + **copy.deepcopy(aux_dataset_kwargs), + "frame_annotations_file": frame_file, + "sequence_annotations_file": sequence_file, + "subset_lists_file": subset_lists_file, + "dataset_root": dataset_root, + "limit_to": limit_to, + "limit_sequences_to": limit_sequences_to, + "n_frames_per_sequence": n_frames_per_sequence + if dataset == "train" + else -1, + "subsets": subsets, + "load_point_clouds": load_point_clouds, + "mask_images": mask_images, + "mask_depths": mask_depths, + "pick_sequence": restrict_sequence_name, + "path_manager": path_manager, + } + + datasets[dataset] = ImplicitronDataset(**params) + if dataset == "test": + if len(restrict_sequence_name) > 0: + eval_batch_index = [ + b for b in eval_batch_index if b[0][0] in restrict_sequence_name + ] + + datasets[dataset].eval_batches = datasets[ + dataset + ].seq_frame_index_to_dataset_index(eval_batch_index) + + if assert_single_seq: + # check theres only one sequence in all datasets + assert ( + len( + { + e["frame_annotation"].sequence_name + for dset in datasets.values() + for e in dset.frame_annots + } + ) + <= 1 + ), "Multiple sequences loaded but expected one" + + else: + raise ValueError(f"Unsupported dataset: {dataset_name}") + + if test_on_train: + datasets["val"] = datasets["train"] + datasets["test"] = datasets["train"] + + return datasets + + +def _get_co3d_set_names_mapping( + dataset_name: str, + test_on_train: bool, + only_test: bool, +) -> Dict[str, List[str]]: + """ + Returns the mapping of the common dataset subset names ("train"/"val"/"test") + to the names of the corresponding subsets in the CO3D dataset + ("test_known"/"test_unseen"/"train_known"/"train_unseen"). + """ + single_seq = dataset_name == "co3d_singlesequence" + + if only_test: + set_names_mapping = {} + else: + set_names_mapping = { + "train": [ + (DATASET_TYPE_TEST if single_seq else DATASET_TYPE_TRAIN) + + "_" + + DATASET_TYPE_KNOWN + ] + } + if not test_on_train: + prefixes = [DATASET_TYPE_TEST] + if not single_seq: + prefixes.append(DATASET_TYPE_TRAIN) + set_names_mapping.update( + { + dset: [ + p + "_" + t + for p in prefixes + for t in [DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN] + ] + for dset in ["val", "test"] + } + ) + + return set_names_mapping diff --git a/pytorch3d/implicitron/dataset/implicitron_dataset.py b/pytorch3d/implicitron/dataset/implicitron_dataset.py new file mode 100644 index 00000000..c397ff6f --- /dev/null +++ b/pytorch3d/implicitron/dataset/implicitron_dataset.py @@ -0,0 +1,988 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import gzip +import hashlib +import json +import os +import random +import warnings +from collections import defaultdict +from dataclasses import dataclass, field, fields +from itertools import islice +from pathlib import Path +from typing import ( + ClassVar, + Dict, + List, + Optional, + Sequence, + Tuple, + Type, + TypedDict, + Union, +) + +import numpy as np +import torch +from iopath.common.file_io import PathManager +from PIL import Image +from pytorch3d.io import IO +from pytorch3d.renderer.camera_utils import join_cameras_as_batch +from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras +from pytorch3d.structures.pointclouds import Pointclouds, join_pointclouds_as_batch + +from . import types + + +@dataclass +class FrameData: + """ + A type of the elements returned by indexing the dataset object. + It can represent both individual frames and batches of thereof; + in this documentation, the sizes of tensors refer to single frames; + add the first batch dimension for the collation result. + + Args: + frame_number: The number of the frame within its sequence. + 0-based continuous integers. + frame_timestamp: The time elapsed since the start of a sequence in sec. + sequence_name: The unique name of the frame's sequence. + sequence_category: The object category of the sequence. + image_size_hw: The size of the image in pixels; (height, width) tuple. + image_path: The qualified path to the loaded image (with dataset_root). + image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image + of the frame; elements are floats in [0, 1]. + mask_crop: A binary mask of shape `(1, H, W)` denoting the valid image + regions. Regions can be invalid (mask_crop[i,j]=0) in case they + are a result of zero-padding of the image after cropping around + the object bounding box; elements are floats in {0.0, 1.0}. + depth_path: The qualified path to the frame's depth map. + depth_map: A float Tensor of shape `(1, H, W)` holding the depth map + of the frame; values correspond to distances from the camera; + use `depth_mask` and `mask_crop` to filter for valid pixels. + depth_mask: A binary mask of shape `(1, H, W)` denoting pixels of the + depth map that are valid for evaluation, they have been checked for + consistency across views; elements are floats in {0.0, 1.0}. + mask_path: A qualified path to the foreground probability mask. + fg_probability: A Tensor of `(1, H, W)` denoting the probability of the + pixels belonging to the captured object; elements are floats + in [0, 1]. + bbox_xywh: The bounding box capturing the object in the + format (x0, y0, width, height). + camera: A PyTorch3D camera object corresponding the frame's viewpoint, + corrected for cropping if it happened. + camera_quality_score: The score proportional to the confidence of the + frame's camera estimation (the higher the more accurate). + point_cloud_quality_score: The score proportional to the accuracy of the + frame's sequence point cloud (the higher the more accurate). + sequence_point_cloud_path: The path to the sequence's point cloud. + sequence_point_cloud: A PyTorch3D Pointclouds object holding the + point cloud corresponding to the frame's sequence. When the object + represents a batch of frames, point clouds may be deduplicated; + see `sequence_point_cloud_idx`. + sequence_point_cloud_idx: Integer indices mapping frame indices to the + corresponding point clouds in `sequence_point_cloud`; to get the + corresponding point cloud to `image_rgb[i]`, use + `sequence_point_cloud[sequence_point_cloud_idx[i]]`. + frame_type: The type of the loaded frame specified in + `subset_lists_file`, if provided. + meta: A dict for storing additional frame information. + """ + + frame_number: Optional[torch.LongTensor] + frame_timestamp: Optional[torch.Tensor] + sequence_name: Union[str, List[str]] + sequence_category: Union[str, List[str]] + image_size_hw: Optional[torch.Tensor] = None + image_path: Union[str, List[str], None] = None + image_rgb: Optional[torch.Tensor] = None + # masks out padding added due to cropping the square bit + mask_crop: Optional[torch.Tensor] = None + depth_path: Union[str, List[str], None] = None + depth_map: Optional[torch.Tensor] = None + depth_mask: Optional[torch.Tensor] = None + mask_path: Union[str, List[str], None] = None + fg_probability: Optional[torch.Tensor] = None + bbox_xywh: Optional[torch.Tensor] = None + camera: Optional[PerspectiveCameras] = None + camera_quality_score: Optional[torch.Tensor] = None + point_cloud_quality_score: Optional[torch.Tensor] = None + sequence_point_cloud_path: Union[str, List[str], None] = None + sequence_point_cloud: Optional[Pointclouds] = None + sequence_point_cloud_idx: Optional[torch.Tensor] = None + frame_type: Union[str, List[str], None] = None # seen | unseen + meta: dict = field(default_factory=lambda: {}) + + def to(self, *args, **kwargs): + new_params = {} + for f in fields(self): + value = getattr(self, f.name) + if isinstance(value, (torch.Tensor, Pointclouds, CamerasBase)): + new_params[f.name] = value.to(*args, **kwargs) + else: + new_params[f.name] = value + return type(self)(**new_params) + + def cpu(self): + return self.to(device=torch.device("cpu")) + + def cuda(self): + return self.to(device=torch.device("cuda")) + + # the following functions make sure **frame_data can be passed to functions + def keys(self): + for f in fields(self): + yield f.name + + def __getitem__(self, key): + return getattr(self, key) + + @classmethod + def collate(cls, batch): + """ + Given a list objects `batch` of class `cls`, collates them into a batched + representation suitable for processing with deep networks. + """ + + elem = batch[0] + + if isinstance(elem, cls): + pointcloud_ids = [id(el.sequence_point_cloud) for el in batch] + id_to_idx = defaultdict(list) + for i, pc_id in enumerate(pointcloud_ids): + id_to_idx[pc_id].append(i) + + sequence_point_cloud = [] + sequence_point_cloud_idx = -np.ones((len(batch),)) + for i, ind in enumerate(id_to_idx.values()): + sequence_point_cloud_idx[ind] = i + sequence_point_cloud.append(batch[ind[0]].sequence_point_cloud) + assert (sequence_point_cloud_idx >= 0).all() + + override_fields = { + "sequence_point_cloud": sequence_point_cloud, + "sequence_point_cloud_idx": sequence_point_cloud_idx.tolist(), + } + # note that the pre-collate value of sequence_point_cloud_idx is unused + + collated = {} + for f in fields(elem): + list_values = override_fields.get( + f.name, [getattr(d, f.name) for d in batch] + ) + collated[f.name] = ( + cls.collate(list_values) + if all(list_value is not None for list_value in list_values) + else None + ) + return cls(**collated) + + elif isinstance(elem, Pointclouds): + return join_pointclouds_as_batch(batch) + + elif isinstance(elem, CamerasBase): + # TODO: don't store K; enforce working in NDC space + return join_cameras_as_batch(batch) + else: + return torch.utils.data._utils.collate.default_collate(batch) + + +@dataclass(eq=False) +class ImplicitronDatasetBase(torch.utils.data.Dataset[FrameData]): + """ + Base class to describe a dataset to be used with Implicitron. + + The dataset is made up of frames, and the frames are grouped into sequences. + Each sequence has a name (a string). + (A sequence could be a video, or a set of images of one scene.) + + This means they have a __getitem__ which returns an instance of a FrameData, + which will describe one frame in one sequence. + + Members: + seq_to_idx: For each sequence, the indices of its frames. + """ + + seq_to_idx: Dict[str, List[int]] = field(init=False) + + def __len__(self) -> int: + raise NotImplementedError + + def get_frame_numbers_and_timestamps( + self, idxs: Sequence[int] + ) -> List[Tuple[int, float]]: + """ + If the sequences in the dataset are videos rather than + unordered views, then the dataset should override this method to + return the index and timestamp in their videos of the frames whose + indices are given in `idxs`. In addition, + the values in seq_to_idx should be in ascending order. + If timestamps are absent, they should be replaced with a constant. + + This is used for letting SceneBatchSampler identify consecutive + frames. + + Args: + idx: frame index in self + + Returns: + tuple of + - frame index in video + - timestamp of frame in video + """ + raise ValueError("This dataset does not contain videos.") + + def get_eval_batches(self) -> Optional[List[List[int]]]: + return None + + +class FrameAnnotsEntry(TypedDict): + subset: Optional[str] + frame_annotation: types.FrameAnnotation + + +@dataclass(eq=False) +class ImplicitronDataset(ImplicitronDatasetBase): + """ + A class for the Common Objects in 3D (CO3D) dataset. + + Args: + frame_annotations_file: A zipped json file containing metadata of the + frames in the dataset, serialized List[types.FrameAnnotation]. + sequence_annotations_file: A zipped json file containing metadata of the + sequences in the dataset, serialized List[types.SequenceAnnotation]. + subset_lists_file: A json file containing the lists of frames corresponding + corresponding to different subsets (e.g. train/val/test) of the dataset; + format: {subset: (sequence_name, frame_id, file_path)}. + subsets: Restrict frames/sequences only to the given list of subsets + as defined in subset_lists_file (see above). + limit_to: Limit the dataset to the first #limit_to frames (after other + filters have been applied). + limit_sequences_to: Limit the dataset to the first + #limit_sequences_to sequences (after other sequence filters have been + applied but before frame-based filters). + pick_sequence: A list of sequence names to restrict the dataset to. + exclude_sequence: A list of the names of the sequences to exclude. + limit_category_to: Restrict the dataset to the given list of categories. + dataset_root: The root folder of the dataset; all the paths in jsons are + specified relative to this root (but not json paths themselves). + load_images: Enable loading the frame RGB data. + load_depths: Enable loading the frame depth maps. + load_depth_masks: Enable loading the frame depth map masks denoting the + depth values used for evaluation (the points consistent across views). + load_masks: Enable loading frame foreground masks. + load_point_clouds: Enable loading sequence-level point clouds. + max_points: Cap on the number of loaded points in the point cloud; + if reached, they are randomly sampled without replacement. + mask_images: Whether to mask the images with the loaded foreground masks; + 0 value is used for background. + mask_depths: Whether to mask the depth maps with the loaded foreground + masks; 0 value is used for background. + image_height: The height of the returned images, masks, and depth maps; + aspect ratio is preserved during cropping/resizing. + image_width: The width of the returned images, masks, and depth maps; + aspect ratio is preserved during cropping/resizing. + box_crop: Enable cropping of the image around the bounding box inferred + from the foreground region of the loaded segmentation mask; masks + and depth maps are cropped accordingly; cameras are corrected. + box_crop_mask_thr: The threshold used to separate pixels into foreground + and background based on the foreground_probability mask; if no value + is greater than this threshold, the loader lowers it and repeats. + box_crop_context: The amount of additional padding added to each + dimension of the cropping bounding box, relative to box size. + remove_empty_masks: Removes the frames with no active foreground pixels + in the segmentation mask after thresholding (see box_crop_mask_thr). + n_frames_per_sequence: If > 0, randomly samples #n_frames_per_sequence + frames in each sequences uniformly without replacement if it has + more frames than that; applied before other frame-level filters. + seed: The seed of the random generator sampling #n_frames_per_sequence + random frames per sequence. + sort_frames: Enable frame annotations sorting to group frames from the + same sequences together and order them by timestamps + eval_batches: A list of batches that form the evaluation set; + list of batch-sized lists of indices corresponding to __getitem__ + of this class, thus it can be used directly as a batch sampler. + """ + + frame_annotations_type: ClassVar[ + Type[types.FrameAnnotation] + ] = types.FrameAnnotation + + path_manager: Optional[PathManager] = None + frame_annotations_file: str = "" + sequence_annotations_file: str = "" + subset_lists_file: str = "" + subsets: Optional[List[str]] = None + limit_to: int = 0 + limit_sequences_to: int = 0 + pick_sequence: Sequence[str] = () + exclude_sequence: Sequence[str] = () + limit_category_to: Sequence[int] = () + dataset_root: str = "" + load_images: bool = True + load_depths: bool = True + load_depth_masks: bool = True + load_masks: bool = True + load_point_clouds: bool = False + max_points: int = 0 + mask_images: bool = False + mask_depths: bool = False + image_height: Optional[int] = 256 + image_width: Optional[int] = 256 + box_crop: bool = False + box_crop_mask_thr: float = 0.4 + box_crop_context: float = 1.0 + remove_empty_masks: bool = False + n_frames_per_sequence: int = -1 + seed: int = 0 + sort_frames: bool = False + eval_batches: Optional[List[List[int]]] = None + frame_annots: List[FrameAnnotsEntry] = field(init=False) + seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False) + + def __post_init__(self) -> None: + # pyre-fixme[16]: `ImplicitronDataset` has no attribute `subset_to_image_path`. + self.subset_to_image_path = None + self._load_frames() + self._load_sequences() + if self.sort_frames: + self._sort_frames() + self._load_subset_lists() + self._filter_db() # also computes sequence indices + print(str(self)) + + def seq_frame_index_to_dataset_index( + self, + seq_frame_index: Union[ + List[List[Union[Tuple[str, int, str], Tuple[str, int]]]], + ], + ) -> List[List[int]]: + """ + Obtain indices into the dataset object given a list of frames specified as + `seq_frame_index = List[List[Tuple[sequence_name:str, frame_number:int]]]`. + """ + # TODO: check the frame numbers are unique + _dataset_seq_frame_n_index = { + seq: { + self.frame_annots[idx]["frame_annotation"].frame_number: idx + for idx in seq_idx + } + for seq, seq_idx in self.seq_to_idx.items() + } + + def _get_batch_idx(seq_name, frame_no, path=None) -> int: + idx = _dataset_seq_frame_n_index[seq_name][frame_no] + if path is not None: + # Check that the loaded frame path is consistent + # with the one stored in self.frame_annots. + assert os.path.normpath( + self.frame_annots[idx]["frame_annotation"].image.path + ) == os.path.normpath( + path + ), f"Inconsistent batch {seq_name, frame_no, path}." + return idx + + batches_idx = [[_get_batch_idx(*b) for b in batch] for batch in seq_frame_index] + return batches_idx + + def __str__(self) -> str: + return f"ImplicitronDataset #frames={len(self.frame_annots)}" + + def __len__(self) -> int: + return len(self.frame_annots) + + def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]: + return entry["subset"] + + def __getitem__(self, index) -> FrameData: + if index >= len(self.frame_annots): + raise IndexError(f"index {index} out of range {len(self.frame_annots)}") + + entry = self.frame_annots[index]["frame_annotation"] + point_cloud = self.seq_annots[entry.sequence_name].point_cloud + frame_data = FrameData( + frame_number=_safe_as_tensor(entry.frame_number, torch.long), + frame_timestamp=_safe_as_tensor(entry.frame_timestamp, torch.float), + sequence_name=entry.sequence_name, + sequence_category=self.seq_annots[entry.sequence_name].category, + camera_quality_score=_safe_as_tensor( + self.seq_annots[entry.sequence_name].viewpoint_quality_score, + torch.float, + ), + point_cloud_quality_score=_safe_as_tensor( + point_cloud.quality_score, torch.float + ) + if point_cloud is not None + else None, + ) + + # The rest of the fields are optional + frame_data.frame_type = self._get_frame_type(self.frame_annots[index]) + + ( + frame_data.fg_probability, + frame_data.mask_path, + frame_data.bbox_xywh, + clamp_bbox_xyxy, + ) = self._load_crop_fg_probability(entry) + + scale = 1.0 + if self.load_images and entry.image is not None: + # original image size + frame_data.image_size_hw = _safe_as_tensor(entry.image.size, torch.long) + + ( + frame_data.image_rgb, + frame_data.image_path, + frame_data.mask_crop, + scale, + ) = self._load_crop_images( + entry, frame_data.fg_probability, clamp_bbox_xyxy + ) + + if self.load_depths and entry.depth is not None: + ( + frame_data.depth_map, + frame_data.depth_path, + frame_data.depth_mask, + ) = self._load_mask_depth(entry, clamp_bbox_xyxy, frame_data.fg_probability) + + if entry.viewpoint is not None: + frame_data.camera = self._get_pytorch3d_camera( + entry, + scale, + clamp_bbox_xyxy, + ) + + if self.load_point_clouds and point_cloud is not None: + frame_data.sequence_point_cloud_path = pcl_path = os.path.join( + self.dataset_root, point_cloud.path + ) + frame_data.sequence_point_cloud = _load_pointcloud( + self._local_path(pcl_path), max_points=self.max_points + ) + + return frame_data + + def _load_crop_fg_probability( + self, entry: types.FrameAnnotation + ) -> Tuple[ + Optional[torch.Tensor], + Optional[str], + Optional[torch.Tensor], + Optional[torch.Tensor], + ]: + fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy = ( + None, + None, + None, + None, + ) + if (self.load_masks or self.box_crop) and entry.mask is not None: + full_path = os.path.join(self.dataset_root, entry.mask.path) + mask = _load_mask(self._local_path(full_path)) + + if mask.shape[-2:] != entry.image.size: + raise ValueError( + f"bad mask size: {mask.shape[-2:]} vs {entry.image.size}!" + ) + + bbox_xywh = torch.tensor(_get_bbox_from_mask(mask, self.box_crop_mask_thr)) + + if self.box_crop: + clamp_bbox_xyxy = _get_clamp_bbox(bbox_xywh, self.box_crop_context) + mask = _crop_around_box(mask, clamp_bbox_xyxy, full_path) + + fg_probability, _, _ = self._resize_image(mask, mode="nearest") + return fg_probability, full_path, bbox_xywh, clamp_bbox_xyxy + + def _load_crop_images( + self, + entry: types.FrameAnnotation, + fg_probability: Optional[torch.Tensor], + clamp_bbox_xyxy: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, str, torch.Tensor, float]: + assert self.dataset_root is not None and entry.image is not None + path = os.path.join(self.dataset_root, entry.image.path) + image_rgb = _load_image(self._local_path(path)) + + if image_rgb.shape[-2:] != entry.image.size: + raise ValueError( + f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" + ) + + if self.box_crop: + assert clamp_bbox_xyxy is not None + image_rgb = _crop_around_box(image_rgb, clamp_bbox_xyxy, path) + + image_rgb, scale, mask_crop = self._resize_image(image_rgb) + + if self.mask_images: + assert fg_probability is not None + image_rgb *= fg_probability + + return image_rgb, path, mask_crop, scale + + def _load_mask_depth( + self, + entry: types.FrameAnnotation, + clamp_bbox_xyxy: Optional[torch.Tensor], + fg_probability: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, str, torch.Tensor]: + entry_depth = entry.depth + assert entry_depth is not None + path = os.path.join(self.dataset_root, entry_depth.path) + depth_map = _load_depth(self._local_path(path), entry_depth.scale_adjustment) + + if self.box_crop: + assert clamp_bbox_xyxy is not None + depth_bbox_xyxy = _rescale_bbox( + clamp_bbox_xyxy, entry.image.size, depth_map.shape[-2:] + ) + depth_map = _crop_around_box(depth_map, depth_bbox_xyxy, path) + + depth_map, _, _ = self._resize_image(depth_map, mode="nearest") + + if self.mask_depths: + assert fg_probability is not None + depth_map *= fg_probability + + if self.load_depth_masks: + assert entry_depth.mask_path is not None + mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) + depth_mask = _load_depth_mask(self._local_path(mask_path)) + + if self.box_crop: + assert clamp_bbox_xyxy is not None + depth_mask_bbox_xyxy = _rescale_bbox( + clamp_bbox_xyxy, entry.image.size, depth_mask.shape[-2:] + ) + depth_mask = _crop_around_box( + depth_mask, depth_mask_bbox_xyxy, mask_path + ) + + depth_mask, _, _ = self._resize_image(depth_mask, mode="nearest") + else: + depth_mask = torch.ones_like(depth_map) + + return depth_map, path, depth_mask + + def _get_pytorch3d_camera( + self, + entry: types.FrameAnnotation, + scale: float, + clamp_bbox_xyxy: Optional[torch.Tensor], + ) -> PerspectiveCameras: + entry_viewpoint = entry.viewpoint + assert entry_viewpoint is not None + # principal point and focal length + principal_point = torch.tensor( + entry_viewpoint.principal_point, dtype=torch.float + ) + focal_length = torch.tensor(entry_viewpoint.focal_length, dtype=torch.float) + + half_image_size_wh_orig = ( + torch.tensor(list(reversed(entry.image.size)), dtype=torch.float) / 2.0 + ) + + # first, we convert from the dataset's NDC convention to pixels + format = entry_viewpoint.intrinsics_format + if format.lower() == "ndc_norm_image_bounds": + # this is e.g. currently used in CO3D for storing intrinsics + rescale = half_image_size_wh_orig + elif format.lower() == "ndc_isotropic": + rescale = half_image_size_wh_orig.min() + else: + raise ValueError(f"Unknown intrinsics format: {format}") + + # principal point and focal length in pixels + principal_point_px = half_image_size_wh_orig - principal_point * rescale + focal_length_px = focal_length * rescale + if self.box_crop: + assert clamp_bbox_xyxy is not None + principal_point_px -= clamp_bbox_xyxy[:2] + + # now, convert from pixels to PyTorch3D v0.5+ NDC convention + if self.image_height is None or self.image_width is None: + out_size = list(reversed(entry.image.size)) + else: + out_size = [self.image_width, self.image_height] + + half_image_size_output = torch.tensor(out_size, dtype=torch.float) / 2.0 + half_min_image_size_output = half_image_size_output.min() + + # rescaled principal point and focal length in ndc + principal_point = ( + half_image_size_output - principal_point_px * scale + ) / half_min_image_size_output + focal_length = focal_length_px * scale / half_min_image_size_output + + return PerspectiveCameras( + focal_length=focal_length[None], + principal_point=principal_point[None], + R=torch.tensor(entry_viewpoint.R, dtype=torch.float)[None], + T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None], + ) + + def _load_frames(self) -> None: + print(f"Loading Co3D frames from {self.frame_annotations_file}.") + local_file = self._local_path(self.frame_annotations_file) + with gzip.open(local_file, "rt", encoding="utf8") as zipfile: + frame_annots_list = types.load_dataclass( + zipfile, List[self.frame_annotations_type] + ) + if not frame_annots_list: + raise ValueError("Empty dataset!") + self.frame_annots = [ + FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list + ] + + def _load_sequences(self) -> None: + print(f"Loading Co3D sequences from {self.sequence_annotations_file}.") + local_file = self._local_path(self.sequence_annotations_file) + with gzip.open(local_file, "rt", encoding="utf8") as zipfile: + seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation]) + if not seq_annots: + raise ValueError("Empty sequences file!") + self.seq_annots = {entry.sequence_name: entry for entry in seq_annots} + + def _load_subset_lists(self) -> None: + print(f"Loading Co3D subset lists from {self.subset_lists_file}.") + if not self.subset_lists_file: + return + + with open(self._local_path(self.subset_lists_file), "r") as f: + subset_to_seq_frame = json.load(f) + + frame_path_to_subset = { + path: subset + for subset, frames in subset_to_seq_frame.items() + for _, _, path in frames + } + + for frame in self.frame_annots: + frame["subset"] = frame_path_to_subset.get( + frame["frame_annotation"].image.path, None + ) + if frame["subset"] is None: + warnings.warn( + "Subset lists are given but don't include " + + frame["frame_annotation"].image.path + ) + + def _sort_frames(self) -> None: + # Sort frames to have them grouped by sequence, ordered by timestamp + self.frame_annots = sorted( + self.frame_annots, + key=lambda f: ( + f["frame_annotation"].sequence_name, + f["frame_annotation"].frame_timestamp or 0, + ), + ) + + def _filter_db(self) -> None: + if self.remove_empty_masks: + print("Removing images with empty masks.") + old_len = len(self.frame_annots) + + msg = "remove_empty_masks needs every MaskAnnotation.mass to be set." + + def positive_mass(frame_annot: types.FrameAnnotation) -> bool: + mask = frame_annot.mask + if mask is None: + return False + if mask.mass is None: + raise ValueError(msg) + return mask.mass > 1 + + self.frame_annots = [ + frame + for frame in self.frame_annots + if positive_mass(frame["frame_annotation"]) + ] + print("... filtered %d -> %d" % (old_len, len(self.frame_annots))) + + # this has to be called after joining with categories!! + subsets = self.subsets + if subsets: + if not self.subset_lists_file: + raise ValueError( + "Subset filter is on but subset_lists_file was not given" + ) + + print(f"Limitting Co3D dataset to the '{subsets}' subsets.") + + # truncate the list of subsets to the valid one + self.frame_annots = [ + entry for entry in self.frame_annots if entry["subset"] in subsets + ] + if len(self.frame_annots) == 0: + raise ValueError(f"There are no frames in the '{subsets}' subsets!") + + self._invalidate_indexes(filter_seq_annots=True) + + if len(self.limit_category_to) > 0: + print(f"Limitting dataset to categories: {self.limit_category_to}") + self.seq_annots = { + name: entry + for name, entry in self.seq_annots.items() + if entry.category in self.limit_category_to + } + + # sequence filters + for prefix in ("pick", "exclude"): + orig_len = len(self.seq_annots) + attr = f"{prefix}_sequence" + arr = getattr(self, attr) + if len(arr) > 0: + print(f"{attr}: {str(arr)}") + self.seq_annots = { + name: entry + for name, entry in self.seq_annots.items() + if (name in arr) == (prefix == "pick") + } + print("... filtered %d -> %d" % (orig_len, len(self.seq_annots))) + + if self.limit_sequences_to > 0: + self.seq_annots = dict( + islice(self.seq_annots.items(), self.limit_sequences_to) + ) + + # retain only frames from retained sequences + self.frame_annots = [ + f + for f in self.frame_annots + if f["frame_annotation"].sequence_name in self.seq_annots + ] + + self._invalidate_indexes() + + if self.n_frames_per_sequence > 0: + print(f"Taking max {self.n_frames_per_sequence} per sequence.") + keep_idx = [] + for seq, seq_indices in self.seq_to_idx.items(): + # infer the seed from the sequence name, this is reproducible + # and makes the selection differ for different sequences + seed = _seq_name_to_seed(seq) + self.seed + seq_idx_shuffled = random.Random(seed).sample( + sorted(seq_indices), len(seq_indices) + ) + keep_idx.extend(seq_idx_shuffled[: self.n_frames_per_sequence]) + + print("... filtered %d -> %d" % (len(self.frame_annots), len(keep_idx))) + self.frame_annots = [self.frame_annots[i] for i in keep_idx] + self._invalidate_indexes(filter_seq_annots=False) + # sequences are not decimated, so self.seq_annots is valid + + if self.limit_to > 0 and self.limit_to < len(self.frame_annots): + print( + "limit_to: filtered %d -> %d" % (len(self.frame_annots), self.limit_to) + ) + self.frame_annots = self.frame_annots[: self.limit_to] + self._invalidate_indexes(filter_seq_annots=True) + + def _invalidate_indexes(self, filter_seq_annots: bool = False) -> None: + # update seq_to_idx and filter seq_meta according to frame_annots change + # if filter_seq_annots, also uldates seq_annots based on the changed seq_to_idx + self._invalidate_seq_to_idx() + + if filter_seq_annots: + self.seq_annots = { + k: v for k, v in self.seq_annots.items() if k in self.seq_to_idx + } + + def _invalidate_seq_to_idx(self) -> None: + seq_to_idx = defaultdict(list) + for idx, entry in enumerate(self.frame_annots): + seq_to_idx[entry["frame_annotation"].sequence_name].append(idx) + self.seq_to_idx = seq_to_idx + + def _resize_image( + self, image, mode="bilinear" + ) -> Tuple[torch.Tensor, float, torch.Tensor]: + image_height, image_width = self.image_height, self.image_width + if image_height is None or image_width is None: + # skip the resizing + imre_ = torch.from_numpy(image) + return imre_, 1.0, torch.ones_like(imre_[:1]) + # takes numpy array, returns pytorch tensor + minscale = min( + image_height / image.shape[-2], + image_width / image.shape[-1], + ) + imre = torch.nn.functional.interpolate( + torch.from_numpy(image)[None], + # pyre-ignore[6] + scale_factor=minscale, + mode=mode, + align_corners=False if mode == "bilinear" else None, + recompute_scale_factor=True, + )[0] + imre_ = torch.zeros(image.shape[0], self.image_height, self.image_width) + imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre + mask = torch.zeros(1, self.image_height, self.image_width) + mask[:, 0 : imre.shape[1] - 1, 0 : imre.shape[2] - 1] = 1.0 + return imre_, minscale, mask + + def _local_path(self, path: str) -> str: + if self.path_manager is None: + return path + return self.path_manager.get_local_path(path) + + def get_frame_numbers_and_timestamps( + self, idxs: Sequence[int] + ) -> List[Tuple[int, float]]: + out: List[Tuple[int, float]] = [] + for idx in idxs: + frame_annotation = self.frame_annots[idx]["frame_annotation"] + out.append( + (frame_annotation.frame_number, frame_annotation.frame_timestamp) + ) + return out + + def get_eval_batches(self) -> Optional[List[List[int]]]: + return self.eval_batches + + +def _seq_name_to_seed(seq_name) -> int: + return int(hashlib.sha1(seq_name.encode("utf-8")).hexdigest(), 16) + + +def _load_image(path) -> np.ndarray: + with Image.open(path) as pil_im: + im = np.array(pil_im.convert("RGB")) + im = im.transpose((2, 0, 1)) + im = im.astype(np.float32) / 255.0 + return im + + +def _load_16big_png_depth(depth_png) -> np.ndarray: + with Image.open(depth_png) as depth_pil: + # the image is stored with 16-bit depth but PIL reads it as I (32 bit). + # we cast it to uint16, then reinterpret as float16, then cast to float32 + depth = ( + np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16) + .astype(np.float32) + .reshape((depth_pil.size[1], depth_pil.size[0])) + ) + return depth + + +def _load_1bit_png_mask(file: str) -> np.ndarray: + with Image.open(file) as pil_im: + mask = (np.array(pil_im.convert("L")) > 0.0).astype(np.float32) + return mask + + +def _load_depth_mask(path) -> np.ndarray: + if not path.lower().endswith(".png"): + raise ValueError('unsupported depth mask file name "%s"' % path) + m = _load_1bit_png_mask(path) + return m[None] # fake feature channel + + +def _load_depth(path, scale_adjustment) -> np.ndarray: + if not path.lower().endswith(".png"): + raise ValueError('unsupported depth file name "%s"' % path) + + d = _load_16big_png_depth(path) * scale_adjustment + d[~np.isfinite(d)] = 0.0 + return d[None] # fake feature channel + + +def _load_mask(path) -> np.ndarray: + with Image.open(path) as pil_im: + mask = np.array(pil_im) + mask = mask.astype(np.float32) / 255.0 + return mask[None] # fake feature channel + + +def _get_1d_bounds(arr) -> Tuple[int, int]: + nz = np.flatnonzero(arr) + return nz[0], nz[-1] + + +def _get_bbox_from_mask( + mask, thr, decrease_quant: float = 0.05 +) -> Tuple[int, int, int, int]: + # bbox in xywh + masks_for_box = np.zeros_like(mask) + while masks_for_box.sum() <= 1.0: + masks_for_box = (mask > thr).astype(np.float32) + thr -= decrease_quant + if thr <= 0.0: + warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.") + + x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2)) + y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1)) + + return x0, y0, x1 - x0, y1 - y0 + + +def _get_clamp_bbox( + bbox: torch.Tensor, box_crop_context: float = 0.0, impath: str = "" +) -> torch.Tensor: + # box_crop_context: rate of expansion for bbox + # returns possibly expanded bbox xyxy as float + + # increase box size + if box_crop_context > 0.0: + c = box_crop_context + bbox = bbox.float() + bbox[0] -= bbox[2] * c / 2 + bbox[1] -= bbox[3] * c / 2 + bbox[2] += bbox[2] * c + bbox[3] += bbox[3] * c + + if (bbox[2:] <= 1.0).any(): + raise ValueError( + f"squashed image {impath}!! The bounding box contains no pixels." + ) + + bbox[2:] = torch.clamp(bbox[2:], 2) + bbox[2:] += bbox[0:2] + 1 # convert to [xmin, ymin, xmax, ymax] + # +1 because upper bound is not inclusive + + return bbox + + +def _crop_around_box(tensor, bbox, impath: str = ""): + # bbox is xyxy, where the upper bound is corrected with +1 + bbox[[0, 2]] = torch.clamp(bbox[[0, 2]], 0.0, tensor.shape[-1]) + bbox[[1, 3]] = torch.clamp(bbox[[1, 3]], 0.0, tensor.shape[-2]) + bbox = bbox.round().long() + tensor = tensor[..., bbox[1] : bbox[3], bbox[0] : bbox[2]] + assert all(c > 0 for c in tensor.shape), f"squashed image {impath}" + + return tensor + + +def _rescale_bbox(bbox: torch.Tensor, orig_res, new_res) -> torch.Tensor: + assert bbox is not None + assert np.prod(orig_res) > 1e-8 + # average ratio of dimensions + rel_size = (new_res[0] / orig_res[0] + new_res[1] / orig_res[1]) / 2.0 + return bbox * rel_size + + +def _safe_as_tensor(data, dtype): + if data is None: + return None + return torch.tensor(data, dtype=dtype) + + +# NOTE this cache is per-worker; they are implemented as processes. +# each batch is loaded and collated by a single worker; +# since sequences tend to co-occur within batches, this is useful. +@functools.lru_cache(maxsize=256) +def _load_pointcloud(pcl_path: Union[str, Path], max_points: int = 0) -> Pointclouds: + pcl = IO().load_pointcloud(pcl_path) + if max_points > 0: + pcl = pcl.subsample(max_points) + + return pcl diff --git a/pytorch3d/implicitron/dataset/scene_batch_sampler.py b/pytorch3d/implicitron/dataset/scene_batch_sampler.py new file mode 100644 index 00000000..29588079 --- /dev/null +++ b/pytorch3d/implicitron/dataset/scene_batch_sampler.py @@ -0,0 +1,203 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import warnings +from dataclasses import dataclass, field +from typing import Iterator, List, Sequence, Tuple + +import numpy as np +from torch.utils.data.sampler import Sampler + +from .implicitron_dataset import ImplicitronDatasetBase + + +@dataclass(eq=False) # TODO: do we need this if not init from config? +class SceneBatchSampler(Sampler[List[int]]): + """ + A class for sampling training batches with a controlled composition + of sequences. + """ + + dataset: ImplicitronDatasetBase + batch_size: int + num_batches: int + # the sampler first samples a random element k from this list and then + # takes k random frames per sequence + images_per_seq_options: Sequence[int] + + # if True, will sample a contiguous interval of frames in the sequence + # it first finds the connected segments within the sequence of sufficient length, + # then samples a random pivot element among them and ideally uses it as a middle + # of the temporal window, shifting the borders where necessary. + # This strategy mitigates the bias against shorter segments and their boundaries. + sample_consecutive_frames: bool = False + # if a number > 0, then used to define the maximum difference in frame_number + # of neighbouring frames when forming connected segments; otherwise the whole + # sequence is considered a segment regardless of frame numbers + consecutive_frames_max_gap: int = 0 + # same but for timestamps if they are available + consecutive_frames_max_gap_seconds: float = 0.1 + + seq_names: List[str] = field(init=False) + + def __post_init__(self) -> None: + if self.batch_size <= 0: + raise ValueError( + "batch_size should be a positive integral value, " + f"but got batch_size={self.batch_size}" + ) + + if len(self.images_per_seq_options) < 1: + raise ValueError("n_per_seq_posibilities list cannot be empty") + + self.seq_names = list(self.dataset.seq_to_idx.keys()) + + def __len__(self) -> int: + return self.num_batches + + def __iter__(self) -> Iterator[List[int]]: + for batch_idx in range(len(self)): + batch = self._sample_batch(batch_idx) + yield batch + + def _sample_batch(self, batch_idx) -> List[int]: + n_per_seq = np.random.choice(self.images_per_seq_options) + n_seqs = -(-self.batch_size // n_per_seq) # round up + chosen_seq = _capped_random_choice(self.seq_names, n_seqs, replace=False) + + if self.sample_consecutive_frames: + frame_idx = [] + for seq in chosen_seq: + segment_index = self._build_segment_index( + list(self.dataset.seq_to_idx[seq]), n_per_seq + ) + + segment, idx = segment_index[np.random.randint(len(segment_index))] + if len(segment) <= n_per_seq: + frame_idx.append(segment) + else: + start = np.clip(idx - n_per_seq // 2, 0, len(segment) - n_per_seq) + frame_idx.append(segment[start : start + n_per_seq]) + + else: + frame_idx = [ + _capped_random_choice( + self.dataset.seq_to_idx[seq], n_per_seq, replace=False + ) + for seq in chosen_seq + ] + frame_idx = np.concatenate(frame_idx)[: self.batch_size].tolist() + if len(frame_idx) < self.batch_size: + warnings.warn( + "Batch size smaller than self.batch_size!" + + " (This is fine for experiments with a single scene and viewpooling)" + ) + return frame_idx + + def _build_segment_index( + self, seq_frame_indices: List[int], size: int + ) -> List[Tuple[List[int], int]]: + """ + Returns a list of (segment, index) tuples, one per eligible frame, where + segment is a list of frame indices in the contiguous segment the frame + belongs to index is the frame's index within that segment. + Segment references are repeated but the memory is shared. + """ + if ( + self.consecutive_frames_max_gap > 0 + or self.consecutive_frames_max_gap_seconds > 0.0 + ): + sequence_timestamps = _sort_frames_by_timestamps_then_numbers( + seq_frame_indices, self.dataset + ) + # TODO: use new API to access frame numbers / timestamps + segments = self._split_to_segments(sequence_timestamps) + segments = _cull_short_segments(segments, size) + if not segments: + raise AssertionError("Empty segments after culling") + else: + segments = [seq_frame_indices] + + # build an index of segment for random selection of a pivot frame + segment_index = [ + (segment, i) for segment in segments for i in range(len(segment)) + ] + + return segment_index + + def _split_to_segments( + self, sequence_timestamps: List[Tuple[float, int, int]] + ) -> List[List[int]]: + if ( + self.consecutive_frames_max_gap <= 0 + and self.consecutive_frames_max_gap_seconds <= 0.0 + ): + raise AssertionError("This function is only needed for non-trivial max_gap") + + segments = [] + last_no = -self.consecutive_frames_max_gap - 1 # will trigger a new segment + last_ts = -self.consecutive_frames_max_gap_seconds - 1.0 + for ts, no, idx in sequence_timestamps: + if ts <= 0.0 and no <= last_no: + raise AssertionError( + "Frames are not ordered in seq_to_idx while timestamps are not given" + ) + + if ( + no - last_no > self.consecutive_frames_max_gap > 0 + or ts - last_ts > self.consecutive_frames_max_gap_seconds > 0.0 + ): # new group + segments.append([idx]) + else: + segments[-1].append(idx) + + last_no = no + last_ts = ts + + return segments + + +def _sort_frames_by_timestamps_then_numbers( + seq_frame_indices: List[int], dataset: ImplicitronDatasetBase +) -> List[Tuple[float, int, int]]: + """Build the list of triplets (timestamp, frame_no, dataset_idx). + We attempt to first sort by timestamp, then by frame number. + Timestamps are coalesced with 0s. + """ + nos_timestamps = dataset.get_frame_numbers_and_timestamps(seq_frame_indices) + + return sorted( + [ + (timestamp, frame_no, idx) + for idx, (frame_no, timestamp) in zip(seq_frame_indices, nos_timestamps) + ] + ) + + +def _cull_short_segments(segments: List[List[int]], min_size: int) -> List[List[int]]: + lengths = [(len(segment), segment) for segment in segments] + max_len, longest_segment = max(lengths) + + if max_len < min_size: + return [longest_segment] + + return [segment for segment in segments if len(segment) >= min_size] + + +def _capped_random_choice(x, size, replace: bool = True): + """ + if replace==True + randomly chooses from x `size` elements without replacement if len(x)>size + else allows replacement and selects `size` elements again. + if replace==False + randomly chooses from x `min(len(x), size)` elements without replacement + """ + len_x = x if isinstance(x, int) else len(x) + if replace: + return np.random.choice(x, size=size, replace=len_x < size) + else: + return np.random.choice(x, size=min(size, len_x), replace=False) diff --git a/pytorch3d/implicitron/dataset/types.py b/pytorch3d/implicitron/dataset/types.py new file mode 100644 index 00000000..1264dfb9 --- /dev/null +++ b/pytorch3d/implicitron/dataset/types.py @@ -0,0 +1,331 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import dataclasses +import gzip +import json +import sys +from dataclasses import MISSING, Field, dataclass +from typing import IO, Any, Optional, Tuple, Type, TypeVar, Union, cast + +import numpy as np + + +_X = TypeVar("_X") + + +if sys.version_info >= (3, 8, 0): + from typing import get_args, get_origin +elif sys.version_info >= (3, 7, 0): + + def get_origin(cls): + return getattr(cls, "__origin__", None) + + def get_args(cls): + return getattr(cls, "__args__", None) + + +else: + raise ImportError("This module requires Python 3.7+") + + +TF3 = Tuple[float, float, float] + + +@dataclass +class ImageAnnotation: + # path to jpg file, relative w.r.t. dataset_root + path: str + # H x W + size: Tuple[int, int] # TODO: rename size_hw? + + +@dataclass +class DepthAnnotation: + # path to png file, relative w.r.t. dataset_root, storing `depth / scale_adjustment` + path: str + # a factor to convert png values to actual depth: `depth = png * scale_adjustment` + scale_adjustment: float + # path to png file, relative w.r.t. dataset_root, storing binary `depth` mask + mask_path: Optional[str] + + +@dataclass +class MaskAnnotation: + # path to png file storing (Prob(fg | pixel) * 255) + path: str + # (soft) number of pixels in the mask; sum(Prob(fg | pixel)) + mass: Optional[float] = None + + +@dataclass +class ViewpointAnnotation: + # In right-multiply (PyTorch3D) format. X_cam = X_world @ R + T + R: Tuple[TF3, TF3, TF3] + T: TF3 + + focal_length: Tuple[float, float] + principal_point: Tuple[float, float] + + intrinsics_format: str = "ndc_norm_image_bounds" + # Defines the co-ordinate system where focal_length and principal_point live. + # Possible values: ndc_isotropic | ndc_norm_image_bounds (default) + # ndc_norm_image_bounds: legacy PyTorch3D NDC format, where image boundaries + # correspond to [-1, 1] x [-1, 1], and the scale along x and y may differ + # ndc_isotropic: PyTorch3D 0.5+ NDC convention where the shorter side has + # the range [-1, 1], and the longer one has the range [-s, s]; s >= 1, + # where s is the aspect ratio. The scale is same along x and y. + + +@dataclass +class FrameAnnotation: + """A dataclass used to load annotations from json.""" + + # can be used to join with `SequenceAnnotation` + sequence_name: str + # 0-based, continuous frame number within sequence + frame_number: int + # timestamp in seconds from the video start + frame_timestamp: float + + image: ImageAnnotation + depth: Optional[DepthAnnotation] = None + mask: Optional[MaskAnnotation] = None + viewpoint: Optional[ViewpointAnnotation] = None + + +@dataclass +class PointCloudAnnotation: + # path to ply file with points only, relative w.r.t. dataset_root + path: str + # the bigger the better + quality_score: float + n_points: Optional[int] + + +@dataclass +class VideoAnnotation: + # path to the original video file, relative w.r.t. dataset_root + path: str + # length of the video in seconds + length: float + + +@dataclass +class SequenceAnnotation: + sequence_name: str + category: str + video: Optional[VideoAnnotation] = None + point_cloud: Optional[PointCloudAnnotation] = None + # the bigger the better + viewpoint_quality_score: Optional[float] = None + + +def dump_dataclass(obj: Any, f: IO, binary: bool = False) -> None: + """ + Args: + f: Either a path to a file, or a file opened for writing. + obj: A @dataclass or collection hierarchy including dataclasses. + binary: Set to True if `f` is a file handle, else False. + """ + if binary: + f.write(json.dumps(_asdict_rec(obj)).encode("utf8")) + else: + json.dump(_asdict_rec(obj), f) + + +def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X: + """ + Loads to a @dataclass or collection hierarchy including dataclasses + from a json recursively. + Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]). + raises KeyError if json has keys not mapping to the dataclass fields. + + Args: + f: Either a path to a file, or a file opened for writing. + cls: The class of the loaded dataclass. + binary: Set to True if `f` is a file handle, else False. + """ + if binary: + asdict = json.loads(f.read().decode("utf8")) + else: + asdict = json.load(f) + + if isinstance(asdict, list): + # in the list case, run a faster "vectorized" version + cls = get_args(cls)[0] + res = list(_dataclass_list_from_dict_list(asdict, cls)) + else: + res = _dataclass_from_dict(asdict, cls) + + return res + + +def _dataclass_list_from_dict_list(dlist, typeannot): + """ + Vectorised version of `_dataclass_from_dict`. + The output should be equivalent to + `[_dataclass_from_dict(d, typeannot) for d in dlist]`. + + Args: + dlist: list of objects to convert. + typeannot: type of each of those objects. + Returns: + iterator or list over converted objects of the same length as `dlist`. + + Raises: + ValueError: it assumes the objects have None's in consistent places across + objects, otherwise it would ignore some values. This generally holds for + auto-generated annotations, but otherwise use `_dataclass_from_dict`. + """ + + cls = get_origin(typeannot) or typeannot + + if all(obj is None for obj in dlist): # 1st recursion base: all None nodes + return dlist + elif any(obj is None for obj in dlist): + # filter out Nones and recurse on the resulting list + idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None] + idx, notnone = zip(*idx_notnone) + converted = _dataclass_list_from_dict_list(notnone, typeannot) + res = [None] * len(dlist) + for i, obj in zip(idx, converted): + res[i] = obj + return res + # otherwise, we dispatch by the type of the provided annotation to convert to + elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple + # For namedtuple, call the function recursively on the lists of corresponding keys + types = cls._field_types.values() + dlist_T = zip(*dlist) + res_T = [ + _dataclass_list_from_dict_list(key_list, tp) + for key_list, tp in zip(dlist_T, types) + ] + return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)] + elif issubclass(cls, (list, tuple)): + # For list/tuple, call the function recursively on the lists of corresponding positions + types = get_args(typeannot) + if len(types) == 1: # probably List; replicate for all items + types = types * len(dlist[0]) + dlist_T = zip(*dlist) + res_T = ( + _dataclass_list_from_dict_list(pos_list, tp) + for pos_list, tp in zip(dlist_T, types) + ) + if issubclass(cls, tuple): + return list(zip(*res_T)) + else: + return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)] + elif issubclass(cls, dict): + # For the dictionary, call the function recursively on concatenated keys and vertices + key_t, val_t = get_args(typeannot) + all_keys_res = _dataclass_list_from_dict_list( + [k for obj in dlist for k in obj.keys()], key_t + ) + all_vals_res = _dataclass_list_from_dict_list( + [k for obj in dlist for k in obj.values()], val_t + ) + indices = np.cumsum([len(obj) for obj in dlist]) + assert indices[-1] == len(all_keys_res) + + keys = np.split(list(all_keys_res), indices[:-1]) + vals = np.split(list(all_vals_res), indices[:-1]) + return [cls(zip(*k, v)) for k, v in zip(keys, vals)] + elif not dataclasses.is_dataclass(typeannot): + return dlist + + # dataclass node: 2nd recursion base; call the function recursively on the lists + # of the corresponding fields + assert dataclasses.is_dataclass(cls) + fieldtypes = { + f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f)) + for f in dataclasses.fields(typeannot) + } + + # NOTE the default object is shared here + key_lists = ( + _dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_) + for k, (type_, default) in fieldtypes.items() + ) + transposed = zip(*key_lists) + return [cls(*vals_as_tuple) for vals_as_tuple in transposed] + + +def _dataclass_from_dict(d, typeannot): + cls = get_origin(typeannot) or typeannot + if d is None: + return d + elif issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple + types = cls._field_types.values() + return cls(*[_dataclass_from_dict(v, tp) for v, tp in zip(d, types)]) + elif issubclass(cls, (list, tuple)): + types = get_args(typeannot) + if len(types) == 1: # probably List; replicate for all items + types = types * len(d) + return cls(_dataclass_from_dict(v, tp) for v, tp in zip(d, types)) + elif issubclass(cls, dict): + key_t, val_t = get_args(typeannot) + return cls( + (_dataclass_from_dict(k, key_t), _dataclass_from_dict(v, val_t)) + for k, v in d.items() + ) + elif not dataclasses.is_dataclass(typeannot): + return d + + assert dataclasses.is_dataclass(cls) + fieldtypes = {f.name: _unwrap_type(f.type) for f in dataclasses.fields(typeannot)} + return cls(**{k: _dataclass_from_dict(v, fieldtypes[k]) for k, v in d.items()}) + + +def _unwrap_type(tp): + # strips Optional wrapper, if any + if get_origin(tp) is Union: + args = get_args(tp) + if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721 + # this is typing.Optional + return args[0] if args[1] is type(None) else args[1] # noqa: E721 + return tp + + +def _get_dataclass_field_default(field: Field) -> Any: + if field.default_factory is not MISSING: + return field.default_factory() + elif field.default is not MISSING: + return field.default + else: + return None + + +def _asdict_rec(obj): + return dataclasses._asdict_inner(obj, dict) + + +def dump_dataclass_jgzip(outfile: str, obj: Any) -> None: + """ + Dumps obj to a gzipped json outfile. + + Args: + obj: A @dataclass or collection hiererchy including dataclasses. + outfile: The path to the output file. + """ + with gzip.GzipFile(outfile, "wb") as f: + dump_dataclass(obj, cast(IO, f), binary=True) + + +def load_dataclass_jgzip(outfile, cls): + """ + Loads a dataclass from a gzipped json outfile. + + Args: + outfile: The path to the loaded file. + cls: The type annotation of the loaded dataclass. + + Returns: + loaded_dataclass: The loaded dataclass. + """ + with gzip.GzipFile(outfile, "rb") as f: + return load_dataclass(cast(IO, f), cls, binary=True) diff --git a/pytorch3d/implicitron/dataset/utils.py b/pytorch3d/implicitron/dataset/utils.py new file mode 100644 index 00000000..2bd06500 --- /dev/null +++ b/pytorch3d/implicitron/dataset/utils.py @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List, Optional + +import torch + + +DATASET_TYPE_TRAIN = "train" +DATASET_TYPE_TEST = "test" +DATASET_TYPE_KNOWN = "known" +DATASET_TYPE_UNKNOWN = "unseen" + + +def is_known_frame( + frame_type: List[str], device: Optional[str] = None +) -> torch.BoolTensor: + """ + Given a list `frame_type` of frame types in a batch, return a tensor + of boolean flags expressing whether the corresponding frame is a known frame. + """ + return torch.tensor( + [ft.endswith(DATASET_TYPE_KNOWN) for ft in frame_type], + dtype=torch.bool, + device=device, + ) + + +def is_train_frame( + frame_type: List[str], device: Optional[str] = None +) -> torch.BoolTensor: + """ + Given a list `frame_type` of frame types in a batch, return a tensor + of boolean flags expressing whether the corresponding frame is a training frame. + """ + return torch.tensor( + [ft.startswith(DATASET_TYPE_TRAIN) for ft in frame_type], + dtype=torch.bool, + device=device, + ) diff --git a/pytorch3d/implicitron/dataset/visualize.py b/pytorch3d/implicitron/dataset/visualize.py new file mode 100644 index 00000000..37290ecd --- /dev/null +++ b/pytorch3d/implicitron/dataset/visualize.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple, cast + +import torch +from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud +from pytorch3d.structures import Pointclouds + +from .implicitron_dataset import FrameData, ImplicitronDataset + + +def get_implicitron_sequence_pointcloud( + dataset: ImplicitronDataset, + sequence_name: Optional[str] = None, + mask_points: bool = True, + max_frames: int = -1, + num_workers: int = 0, + load_dataset_point_cloud: bool = False, +) -> Tuple[Pointclouds, FrameData]: + """ + Make a point cloud by sampling random points from each frame the dataset. + """ + + if len(dataset) == 0: + raise ValueError("The dataset is empty.") + + if not dataset.load_depths: + raise ValueError("The dataset has to load depths (dataset.load_depths=True).") + + if mask_points and not dataset.load_masks: + raise ValueError( + "For mask_points=True, the dataset has to load masks" + + " (dataset.load_masks=True)." + ) + + # setup the indices of frames loaded from the dataset db + sequence_entries = list(range(len(dataset))) + if sequence_name is not None: + sequence_entries = [ + ei + for ei in sequence_entries + if dataset.frame_annots[ei]["frame_annotation"].sequence_name + == sequence_name + ] + if len(sequence_entries) == 0: + raise ValueError( + f'There are no dataset entries for sequence name "{sequence_name}".' + ) + + # subsample loaded frames if needed + if (max_frames > 0) and (len(sequence_entries) > max_frames): + sequence_entries = [ + sequence_entries[i] + for i in torch.randperm(len(sequence_entries))[:max_frames].sort().values + ] + + # take only the part of the dataset corresponding to the sequence entries + sequence_dataset = torch.utils.data.Subset(dataset, sequence_entries) + + # load the required part of the dataset + loader = torch.utils.data.DataLoader( + sequence_dataset, + batch_size=len(sequence_dataset), + shuffle=False, + num_workers=num_workers, + collate_fn=FrameData.collate, + ) + + frame_data = next(iter(loader)) # there's only one batch + + # scene point cloud + if load_dataset_point_cloud: + if not dataset.load_point_clouds: + raise ValueError( + "For load_dataset_point_cloud=True, the dataset has to" + + " load point clouds (dataset.load_point_clouds=True)." + ) + point_cloud = frame_data.sequence_point_cloud + + else: + point_cloud = get_rgbd_point_cloud( + frame_data.camera, + frame_data.image_rgb, + frame_data.depth_map, + (cast(torch.Tensor, frame_data.fg_probability) > 0.5).float() + if frame_data.fg_probability is not None + else None, + mask_points=mask_points, + ) + + return point_cloud, frame_data diff --git a/pytorch3d/implicitron/eval_demo.py b/pytorch3d/implicitron/eval_demo.py new file mode 100644 index 00000000..d85158c8 --- /dev/null +++ b/pytorch3d/implicitron/eval_demo.py @@ -0,0 +1,216 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +import dataclasses +import os +from typing import Optional, cast + +import lpips +import torch +from pytorch3d.implicitron.dataset.dataloader_zoo import dataloader_zoo +from pytorch3d.implicitron.dataset.dataset_zoo import CO3D_CATEGORIES, dataset_zoo +from pytorch3d.implicitron.dataset.implicitron_dataset import ( + FrameData, + ImplicitronDataset, + ImplicitronDatasetBase, +) +from pytorch3d.implicitron.dataset.utils import is_known_frame +from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import ( + aggregate_nvs_results, + eval_batch, + pretty_print_nvs_metrics, + summarize_nvs_eval_results, +) +from pytorch3d.implicitron.models.model_dbir import ModelDBIR +from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_ +from tqdm import tqdm + + +def main() -> None: + """ + Evaluates new view synthesis metrics of a simple depth-based image rendering + (DBIR) model for multisequence/singlesequence tasks for several categories. + + The evaluation is conducted on the same data as in [1] and, hence, the results + are directly comparable to the numbers reported in [1]. + + References: + [1] J. Reizenstein, R. Shapovalov, P. Henzler, L. Sbordone, + P. Labatut, D. Novotny: + Common Objects in 3D: Large-Scale Learning + and Evaluation of Real-life 3D Category Reconstruction + """ + + task_results = {} + for task in ("singlesequence", "multisequence"): + task_results[task] = [] + for category in CO3D_CATEGORIES[: (20 if task == "singlesequence" else 10)]: + for single_sequence_id in (0, 1) if task == "singlesequence" else (None,): + category_result = evaluate_dbir_for_category( + category, task=task, single_sequence_id=single_sequence_id + ) + print("") + print( + f"Results for task={task}; category={category};" + + ( + f" sequence={single_sequence_id}:" + if single_sequence_id is not None + else ":" + ) + ) + pretty_print_nvs_metrics(category_result) + print("") + + task_results[task].append(category_result) + _print_aggregate_results(task, task_results) + + for task in task_results: + _print_aggregate_results(task, task_results) + + +def evaluate_dbir_for_category( + category: str = "apple", + bg_color: float = 0.0, + task: str = "singlesequence", + single_sequence_id: Optional[int] = None, + num_workers: int = 16, +): + """ + Evaluates new view synthesis metrics of a simple depth-based image rendering + (DBIR) model for a given task, category, and sequence (in case task=='singlesequence'). + + Args: + category: Object category. + bg_color: Background color of the renders. + task: Evaluation task. Either singlesequence or multisequence. + single_sequence_id: The ID of the evaluiation sequence for the singlesequence task. + num_workers: The number of workers for the employed dataloaders. + + Returns: + category_result: A dictionary of quantitative metrics. + """ + + single_sequence_id = single_sequence_id if single_sequence_id is not None else -1 + + torch.manual_seed(42) + + if task not in ["multisequence", "singlesequence"]: + raise ValueError("'task' has to be either 'multisequence' or 'singlesequence'") + + datasets = dataset_zoo( + category=category, + dataset_root=os.environ["CO3D_DATASET_ROOT"], + assert_single_seq=task == "singlesequence", + dataset_name=f"co3d_{task}", + test_on_train=False, + load_point_clouds=True, + test_restrict_sequence_id=single_sequence_id, + ) + + dataloaders = dataloader_zoo( + datasets, + dataset_name=f"co3d_{task}", + ) + + test_dataset = datasets["test"] + test_dataloader = dataloaders["test"] + + if task == "singlesequence": + # all_source_cameras are needed for evaluation of the + # target camera difficulty + # pyre-fixme[16]: `ImplicitronDataset` has no attribute `frame_annots`. + sequence_name = test_dataset.frame_annots[0]["frame_annotation"].sequence_name + all_source_cameras = _get_all_source_cameras( + test_dataset, sequence_name, num_workers=num_workers + ) + else: + all_source_cameras = None + + image_size = cast(ImplicitronDataset, test_dataset).image_width + + if image_size is None: + raise ValueError("Image size should be set in the dataset") + + # init the simple DBIR model + model = ModelDBIR( + image_size=image_size, + bg_color=bg_color, + max_points=int(1e5), + ) + model.cuda() + + # init the lpips model for eval + lpips_model = lpips.LPIPS(net="vgg") + lpips_model = lpips_model.cuda() + + per_batch_eval_results = [] + print("Evaluating DBIR model ...") + for frame_data in tqdm(test_dataloader): + frame_data = dataclass_to_cuda_(frame_data) + preds = model(**dataclasses.asdict(frame_data)) + nvs_prediction = copy.deepcopy(preds["nvs_prediction"]) + per_batch_eval_results.append( + eval_batch( + frame_data, + nvs_prediction, + bg_color=bg_color, + lpips_model=lpips_model, + source_cameras=all_source_cameras, + ) + ) + + category_result_flat, category_result = summarize_nvs_eval_results( + per_batch_eval_results, task + ) + + return category_result["results"] + + +def _print_aggregate_results(task, task_results) -> None: + """ + Prints the aggregate metrics for a given task. + """ + aggregate_task_result = aggregate_nvs_results(task_results[task]) + print("") + print(f"Aggregate results for task={task}:") + pretty_print_nvs_metrics(aggregate_task_result) + print("") + + +def _get_all_source_cameras( + dataset: ImplicitronDatasetBase, sequence_name: str, num_workers: int = 8 +): + """ + Loads all training cameras of a given sequence. + + The set of all seen cameras is needed for evaluating the viewpoint difficulty + for the singlescene evaluation. + + Args: + dataset: Co3D dataset object. + sequence_name: The name of the sequence. + num_workers: The number of for the utilized dataloader. + """ + + # load all source cameras of the sequence + seq_idx = dataset.seq_to_idx[sequence_name] + dataset_for_loader = torch.utils.data.Subset(dataset, seq_idx) + (all_frame_data,) = torch.utils.data.DataLoader( + dataset_for_loader, + shuffle=False, + batch_size=len(dataset_for_loader), + num_workers=num_workers, + collate_fn=FrameData.collate, + ) + is_known = is_known_frame(all_frame_data.frame_type) + source_cameras = all_frame_data.camera[torch.where(is_known)[0]] + return source_cameras + + +if __name__ == "__main__": + main() diff --git a/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py b/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py new file mode 100644 index 00000000..df8fbaec --- /dev/null +++ b/pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py @@ -0,0 +1,649 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +import warnings +from collections import OrderedDict +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union + +import numpy as np +import torch +from pytorch3d.implicitron.dataset.implicitron_dataset import FrameData +from pytorch3d.implicitron.dataset.utils import is_known_frame, is_train_frame +from pytorch3d.implicitron.tools import vis_utils +from pytorch3d.implicitron.tools.camera_utils import volumetric_camera_overlaps +from pytorch3d.implicitron.tools.image_utils import mask_background +from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth, iou, rgb_l1 +from pytorch3d.implicitron.tools.point_cloud_utils import get_rgbd_point_cloud +from pytorch3d.implicitron.tools.vis_utils import make_depth_image +from pytorch3d.renderer.camera_utils import join_cameras_as_batch +from pytorch3d.renderer.cameras import CamerasBase, PerspectiveCameras +from pytorch3d.vis.plotly_vis import plot_scene +from tabulate import tabulate +from visdom import Visdom + + +EVAL_N_SRC_VIEWS = [1, 3, 5, 7, 9] + + +@dataclass +class NewViewSynthesisPrediction: + """ + Holds the tensors that describe a result of synthesizing new views. + """ + + depth_render: Optional[torch.Tensor] = None + image_render: Optional[torch.Tensor] = None + mask_render: Optional[torch.Tensor] = None + camera_distance: Optional[torch.Tensor] = None + + +@dataclass +class _Visualizer: + image_render: torch.Tensor + image_rgb_masked: torch.Tensor + depth_render: torch.Tensor + depth_map: torch.Tensor + depth_mask: torch.Tensor + + visdom_env: str = "eval_debug" + + _viz: Visdom = field(init=False) + + def __post_init__(self): + self._viz = vis_utils.get_visdom_connection() + + def show_rgb( + self, loss_value: float, metric_name: str, loss_mask_now: torch.Tensor + ): + self._viz.images( + torch.cat( + ( + self.image_render, + self.image_rgb_masked, + loss_mask_now.repeat(1, 3, 1, 1), + ), + dim=3, + ), + env=self.visdom_env, + win=metric_name, + opts={"title": f"{metric_name}_{loss_value:1.2f}"}, + ) + + def show_depth( + self, depth_loss: float, name_postfix: str, loss_mask_now: torch.Tensor + ): + self._viz.images( + torch.cat( + ( + make_depth_image(self.depth_render, loss_mask_now), + make_depth_image(self.depth_map, loss_mask_now), + ), + dim=3, + ), + env=self.visdom_env, + win="depth_abs" + name_postfix, + opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}"}, + ) + self._viz.images( + loss_mask_now, + env=self.visdom_env, + win="depth_abs" + name_postfix + "_mask", + opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_mask"}, + ) + self._viz.images( + self.depth_mask, + env=self.visdom_env, + win="depth_abs" + name_postfix + "_maskd", + opts={"title": f"depth_abs_{name_postfix}_{depth_loss:1.2f}_maskd"}, + ) + + # show the 3D plot + # pyre-fixme[9]: viewpoint_trivial has type `PerspectiveCameras`; used as + # `TensorProperties`. + viewpoint_trivial: PerspectiveCameras = PerspectiveCameras().to( + loss_mask_now.device + ) + pcl_pred = get_rgbd_point_cloud( + viewpoint_trivial, + self.image_render, + self.depth_render, + # mask_crop, + torch.ones_like(self.depth_render), + # loss_mask_now, + ) + pcl_gt = get_rgbd_point_cloud( + viewpoint_trivial, + self.image_rgb_masked, + self.depth_map, + # mask_crop, + torch.ones_like(self.depth_map), + # loss_mask_now, + ) + _pcls = { + pn: p + for pn, p in zip(("pred_depth", "gt_depth"), (pcl_pred, pcl_gt)) + if int(p.num_points_per_cloud()) > 0 + } + plotlyplot = plot_scene( + {f"pcl{name_postfix}": _pcls}, + camera_scale=1.0, + pointcloud_max_points=10000, + pointcloud_marker_size=1, + ) + self._viz.plotlyplot( + plotlyplot, + env=self.visdom_env, + win=f"pcl{name_postfix}", + ) + + +def eval_batch( + frame_data: FrameData, + nvs_prediction: NewViewSynthesisPrediction, + bg_color: Union[torch.Tensor, str, float] = "black", + mask_thr: float = 0.5, + lpips_model=None, + visualize: bool = False, + visualize_visdom_env: str = "eval_debug", + break_after_visualising: bool = True, + source_cameras: Optional[List[CamerasBase]] = None, +) -> Dict[str, Any]: + """ + Produce performance metrics for a single batch of new-view synthesis + predictions. + + Given a set of known views (for which frame_data.frame_type.endswith('known') + is True), a new-view synthesis method (NVS) is tasked to generate new views + of the scene from the viewpoint of the target views (for which + frame_data.frame_type.endswith('known') is False). The resulting + synthesized new views, stored in `nvs_prediction`, are compared to the + target ground truth in `frame_data` in terms of geometry and appearance + resulting in a dictionary of metrics returned by the `eval_batch` function. + + Args: + frame_data: A FrameData object containing the input to the new view + synthesis method. + nvs_prediction: The data describing the synthesized new views. + bg_color: The background color of the generated new views and the + ground truth. + lpips_model: A pre-trained model for evaluating the LPIPS metric. + visualize: If True, visualizes the results to Visdom. + source_cameras: A list of all training cameras for evaluating the + difficulty of the target views. + + Returns: + results: A dictionary holding evaluation metrics. + + Throws: + ValueError if frame_data does not have frame_type, camera, or image_rgb + ValueError if the batch has a mix of training and test samples + ValueError if the batch frames are not [unseen, known, known, ...] + ValueError if one of the required fields in nvs_prediction is missing + """ + REQUIRED_NVS_PREDICTION_FIELDS = ["mask_render", "image_render", "depth_render"] + frame_type = frame_data.frame_type + if frame_type is None: + raise ValueError("Frame type has not been set.") + + # we check that all those fields are not None but Pyre can't infer that properly + # TODO: assign to local variables + if frame_data.image_rgb is None: + raise ValueError("Image is not in the evaluation batch.") + + if frame_data.camera is None: + raise ValueError("Camera is not in the evaluation batch.") + + if any(not hasattr(nvs_prediction, k) for k in REQUIRED_NVS_PREDICTION_FIELDS): + raise ValueError("One of the required predicted fields is missing") + + # obtain copies to make sure we dont edit the original data + nvs_prediction = copy.deepcopy(nvs_prediction) + frame_data = copy.deepcopy(frame_data) + + # mask the ground truth depth in case frame_data contains the depth mask + if frame_data.depth_map is not None and frame_data.depth_mask is not None: + frame_data.depth_map *= frame_data.depth_mask + + if not isinstance(frame_type, list): # not batch FrameData + frame_type = [frame_type] + + is_train = is_train_frame(frame_type) + if not (is_train[0] == is_train).all(): + raise ValueError("All frames in the eval batch have to be either train/test.") + + # pyre-fixme[16]: `Optional` has no attribute `device`. + is_known = is_known_frame(frame_type, device=frame_data.image_rgb.device) + + if not ((is_known[1:] == 1).all() and (is_known[0] == 0).all()): + raise ValueError( + "For evaluation the first element of the batch has to be" + + " a target view while the rest should be source views." + ) # TODO: do we need to enforce this? + + # take only the first (target image) + for k in REQUIRED_NVS_PREDICTION_FIELDS: + setattr(nvs_prediction, k, getattr(nvs_prediction, k)[:1]) + for k in [ + "depth_map", + "image_rgb", + "fg_probability", + "mask_crop", + ]: + if not hasattr(frame_data, k) or getattr(frame_data, k) is None: + continue + setattr(frame_data, k, getattr(frame_data, k)[:1]) + + if frame_data.depth_map is None or frame_data.depth_map.sum() <= 0: + warnings.warn("Empty or missing depth map in evaluation!") + + # eval all results in the resolution of the frame_data image + # pyre-fixme[16]: `Optional` has no attribute `shape`. + image_resol = list(frame_data.image_rgb.shape[2:]) + + # threshold the masks to make ground truth binary masks + mask_fg, mask_crop = [ + (getattr(frame_data, k) >= mask_thr) for k in ("fg_probability", "mask_crop") + ] + image_rgb_masked = mask_background( + # pyre-fixme[6]: Expected `Tensor` for 1st param but got + # `Optional[torch.Tensor]`. + frame_data.image_rgb, + mask_fg, + bg_color=bg_color, + ) + + # resize to the target resolution + for k in REQUIRED_NVS_PREDICTION_FIELDS: + imode = "bilinear" if k == "image_render" else "nearest" + val = getattr(nvs_prediction, k) + setattr( + nvs_prediction, + k, + # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got + # `List[typing.Any]`. + torch.nn.functional.interpolate(val, size=image_resol, mode=imode), + ) + + # clamp predicted images + # pyre-fixme[16]: `Optional` has no attribute `clamp`. + image_render = nvs_prediction.image_render.clamp(0.0, 1.0) + + if visualize: + visualizer = _Visualizer( + image_render=image_render, + image_rgb_masked=image_rgb_masked, + # pyre-fixme[6]: Expected `Tensor` for 3rd param but got + # `Optional[torch.Tensor]`. + depth_render=nvs_prediction.depth_render, + # pyre-fixme[6]: Expected `Tensor` for 4th param but got + # `Optional[torch.Tensor]`. + depth_map=frame_data.depth_map, + # pyre-fixme[16]: `Optional` has no attribute `__getitem__`. + depth_mask=frame_data.depth_mask[:1], + visdom_env=visualize_visdom_env, + ) + + results: Dict[str, Any] = {} + + results["iou"] = iou( + # pyre-fixme[6]: Expected `Tensor` for 1st param but got + # `Optional[torch.Tensor]`. + nvs_prediction.mask_render, + mask_fg, + mask=mask_crop, + ) + + for loss_fg_mask, name_postfix in zip((mask_crop, mask_fg), ("", "_fg")): + + loss_mask_now = mask_crop * loss_fg_mask + + for rgb_metric_name, rgb_metric_fun in zip( + ("psnr", "rgb_l1"), (calc_psnr, rgb_l1) + ): + metric_name = rgb_metric_name + name_postfix + results[metric_name] = rgb_metric_fun( + image_render, + image_rgb_masked, + mask=loss_mask_now, + ) + + if visualize: + visualizer.show_rgb( + results[metric_name].item(), metric_name, loss_mask_now + ) + + if name_postfix == "_fg": + # only record depth metrics for the foreground + _, abs_ = eval_depth( + # pyre-fixme[6]: Expected `Tensor` for 1st param but got + # `Optional[torch.Tensor]`. + nvs_prediction.depth_render, + # pyre-fixme[6]: Expected `Tensor` for 2nd param but got + # `Optional[torch.Tensor]`. + frame_data.depth_map, + get_best_scale=True, + mask=loss_mask_now, + crop=5, + ) + results["depth_abs" + name_postfix] = abs_.mean() + + if visualize: + visualizer.show_depth(abs_.mean().item(), name_postfix, loss_mask_now) + if break_after_visualising: + import pdb + + pdb.set_trace() + + if lpips_model is not None: + im1, im2 = [ + 2.0 * im.clamp(0.0, 1.0) - 1.0 + for im in (image_rgb_masked, nvs_prediction.image_render) + ] + results["lpips"] = lpips_model.forward(im1, im2).item() + + # convert all metrics to floats + results = {k: float(v) for k, v in results.items()} + + if source_cameras is None: + # pyre-fixme[16]: Optional has no attribute __getitem__ + source_cameras = frame_data.camera[torch.where(is_known)[0]] + + results["meta"] = { + # calculate the camera difficulties and add to results + "camera_difficulty": calculate_camera_difficulties( + frame_data.camera[0], + source_cameras, + )[0].item(), + # store the size of the batch (corresponds to n_src_views+1) + "batch_size": int(is_known.numel()), + # store the type of the target frame + # pyre-fixme[16]: `None` has no attribute `__getitem__`. + "frame_type": str(frame_data.frame_type[0]), + } + + return results + + +def average_per_batch_results( + results_per_batch: List[Dict[str, Any]], + idx: Optional[torch.Tensor] = None, +) -> dict: + """ + Average a list of per-batch metrics `results_per_batch`. + Optionally, if `idx` is given, only a subset of the per-batch + metrics, indexed by `idx`, is averaged. + """ + result_keys = list(results_per_batch[0].keys()) + result_keys.remove("meta") + if idx is not None: + results_per_batch = [results_per_batch[i] for i in idx] + if len(results_per_batch) == 0: + return {k: float("NaN") for k in result_keys} + return { + k: float(np.array([r[k] for r in results_per_batch]).mean()) + for k in result_keys + } + + +def calculate_camera_difficulties( + cameras_target: CamerasBase, + cameras_source: CamerasBase, +) -> torch.Tensor: + """ + Calculate the difficulties of the target cameras, given a set of known + cameras `cameras_source`. + + Returns: + a tensor of shape (len(cameras_target),) + """ + ious = [ + volumetric_camera_overlaps( + join_cameras_as_batch( + # pyre-fixme[6]: Expected `CamerasBase` for 1st param but got + # `Optional[pytorch3d.renderer.utils.TensorProperties]`. + [cameras_target[cami], cameras_source.to(cameras_target.device)] + ) + )[0, :] + for cami in range(cameras_target.R.shape[0]) + ] + camera_difficulties = torch.stack( + [_reduce_camera_iou_overlap(iou[1:]) for iou in ious] + ) + return camera_difficulties + + +def _reduce_camera_iou_overlap(ious: torch.Tensor, topk: int = 2) -> torch.Tensor: + """ + Calculate the final camera difficulty by computing the average of the + ious of the two most similar cameras. + + Returns: + single-element Tensor + """ + # pyre-ignore[16] topk not recognized + return ious.topk(k=min(topk, len(ious) - 1)).values.mean() + + +def get_camera_difficulty_bin_edges(task: str): + """ + Get the edges of camera difficulty bins. + """ + _eps = 1e-5 + if task == "multisequence": + # TODO: extract those to constants + diff_bin_edges = torch.linspace(0.5, 1.0 + _eps, 4) + diff_bin_edges[0] = 0.0 - _eps + elif task == "singlesequence": + diff_bin_edges = torch.tensor([0.0 - _eps, 0.97, 0.98, 1.0 + _eps]).float() + else: + raise ValueError(f"No such eval task {task}.") + diff_bin_names = ["hard", "medium", "easy"] + return diff_bin_edges, diff_bin_names + + +def summarize_nvs_eval_results( + per_batch_eval_results: List[Dict[str, Any]], + task: str = "singlesequence", +): + """ + Compile the per-batch evaluation results `per_batch_eval_results` into + a set of aggregate metrics. The produced metrics depend on the task. + + Args: + per_batch_eval_results: Metrics of each per-batch evaluation. + task: The type of the new-view synthesis task. + Either 'singlesequence' or 'multisequence'. + + Returns: + nvs_results_flat: A flattened dict of all aggregate metrics. + aux_out: A dictionary holding a set of auxiliary results. + """ + n_batches = len(per_batch_eval_results) + eval_sets: List[Optional[str]] = [] + if task == "singlesequence": + eval_sets = [None] + # assert n_batches==100 + elif task == "multisequence": + eval_sets = ["train", "test"] + # assert n_batches==1000 + else: + raise ValueError(task) + batch_sizes = torch.tensor( + [r["meta"]["batch_size"] for r in per_batch_eval_results] + ).long() + camera_difficulty = torch.tensor( + [r["meta"]["camera_difficulty"] for r in per_batch_eval_results] + ).float() + is_train = is_train_frame([r["meta"]["frame_type"] for r in per_batch_eval_results]) + + # init the result database dict + results = [] + + diff_bin_edges, diff_bin_names = get_camera_difficulty_bin_edges(task) + n_diff_edges = diff_bin_edges.numel() + + # add per set averages + for SET in eval_sets: + if SET is None: + # task=='singlesequence' + ok_set = torch.ones(n_batches, dtype=torch.bool) + set_name = "test" + else: + # task=='multisequence' + ok_set = is_train == int(SET == "train") + set_name = SET + + # eval each difficulty bin, including a full average result (diff_bin=None) + for diff_bin in [None, *list(range(n_diff_edges - 1))]: + if diff_bin is None: + # average over all results + in_bin = ok_set + diff_bin_name = "all" + else: + b1, b2 = diff_bin_edges[diff_bin : (diff_bin + 2)] + in_bin = ok_set & (camera_difficulty > b1) & (camera_difficulty <= b2) + diff_bin_name = diff_bin_names[diff_bin] + bin_results = average_per_batch_results( + per_batch_eval_results, idx=torch.where(in_bin)[0] + ) + results.append( + { + "subset": set_name, + "subsubset": f"diff={diff_bin_name}", + "metrics": bin_results, + } + ) + + if task == "multisequence": + # split based on n_src_views + n_src_views = batch_sizes - 1 + for n_src in EVAL_N_SRC_VIEWS: + ok_src = ok_set & (n_src_views == n_src) + n_src_results = average_per_batch_results( + per_batch_eval_results, + idx=torch.where(ok_src)[0], + ) + results.append( + { + "subset": set_name, + "subsubset": f"n_src={int(n_src)}", + "metrics": n_src_results, + } + ) + + aux_out = {"results": results} + return flatten_nvs_results(results), aux_out + + +def _get_flat_nvs_metric_key(result, metric_name) -> str: + metric_key_postfix = f"|subset={result['subset']}|{result['subsubset']}" + metric_key = f"{metric_name}{metric_key_postfix}" + return metric_key + + +def flatten_nvs_results(results): + """ + Takes input `results` list of dicts of the form: + ``` + [ + { + 'subset':'train/test/...', + 'subsubset': 'src=1/src=2/...', + 'metrics': nvs_eval_metrics} + }, + ... + ] + ``` + And converts to a flat dict as follows: + { + 'subset=train/test/...|subsubset=src=1/src=2/...': nvs_eval_metrics, + ... + } + """ + results_flat = {} + for result in results: + for metric_name, metric_val in result["metrics"].items(): + metric_key = _get_flat_nvs_metric_key(result, metric_name) + assert metric_key not in results_flat + results_flat[metric_key] = metric_val + return results_flat + + +def pretty_print_nvs_metrics(results) -> None: + subsets, subsubsets = [ + _ordered_set([r[k] for r in results]) for k in ("subset", "subsubset") + ] + metrics = _ordered_set([metric for r in results for metric in r["metrics"]]) + + for subset in subsets: + tab = {} + for metric in metrics: + tab[metric] = [] + header = ["metric"] + for subsubset in subsubsets: + metric_vals = [ + r["metrics"][metric] + for r in results + if r["subsubset"] == subsubset and r["subset"] == subset + ] + if len(metric_vals) > 0: + tab[metric].extend(metric_vals) + header.extend(subsubsets) + + if any(len(v) > 0 for v in tab.values()): + print(f"===== NVS results; subset={subset} =====") + print( + tabulate( + [[metric, *v] for metric, v in tab.items()], + # pyre-fixme[61]: `header` is undefined, or not always defined. + headers=header, + ) + ) + + +def _ordered_set(list_): + return list(OrderedDict((i, 0) for i in list_).keys()) + + +def aggregate_nvs_results(task_results): + """ + Aggregate nvs results. + For singlescene, this averages over all categories and scenes, + for multiscene, the average is over all per-category results. + """ + task_results_cat = [r_ for r in task_results for r_ in r] + subsets, subsubsets = [ + _ordered_set([r[k] for r in task_results_cat]) for k in ("subset", "subsubset") + ] + metrics = _ordered_set( + [metric for r in task_results_cat for metric in r["metrics"]] + ) + average_results = [] + for subset in subsets: + for subsubset in subsubsets: + metrics_lists = [ + r["metrics"] + for r in task_results_cat + if r["subsubset"] == subsubset and r["subset"] == subset + ] + avg_metrics = {} + for metric in metrics: + avg_metrics[metric] = float( + np.nanmean( + np.array([metric_list[metric] for metric_list in metrics_lists]) + ) + ) + average_results.append( + { + "subset": subset, + "subsubset": subsubset, + "metrics": avg_metrics, + } + ) + return average_results diff --git a/pytorch3d/implicitron/models/autodecoder.py b/pytorch3d/implicitron/models/autodecoder.py new file mode 100644 index 00000000..8b1dd4fb --- /dev/null +++ b/pytorch3d/implicitron/models/autodecoder.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import warnings +from collections import defaultdict +from typing import Dict, List, Optional, Union + +import torch +from pytorch3d.implicitron.tools.config import Configurable + + +# TODO: probabilistic embeddings? +class Autodecoder(Configurable, torch.nn.Module): + """ + Autodecoder module + + Settings: + encoding_dim: Embedding dimension for the decoder. + n_instances: The maximum number of instances stored by the autodecoder. + init_scale: Scale factor for the initial autodecoder weights. + ignore_input: If `True`, optimizes a single code for any input. + """ + + encoding_dim: int = 0 + n_instances: int = 0 + init_scale: float = 1.0 + ignore_input: bool = False + + def __post_init__(self): + super().__init__() + if self.n_instances <= 0: + # Do not init the codes at all in case we have 0 instances. + return + self._autodecoder_codes = torch.nn.Embedding( + self.n_instances, + self.encoding_dim, + scale_grad_by_freq=True, + ) + with torch.no_grad(): + # weight has been initialised from Normal(0, 1) + self._autodecoder_codes.weight *= self.init_scale + + self._sequence_map = self._build_sequence_map() + # Make sure to register hooks for correct handling of saving/loading + # the module's _sequence_map. + self._register_load_state_dict_pre_hook(self._load_sequence_map_hook) + self._register_state_dict_hook(_save_sequence_map_hook) + + def _build_sequence_map( + self, sequence_map_dict: Optional[Dict[str, int]] = None + ) -> Dict[str, int]: + """ + Args: + sequence_map_dict: A dictionary used to initialize the sequence_map. + + Returns: + sequence_map: a dictionary of key: id pairs. + """ + # increments the counter when asked for a new value + sequence_map = defaultdict(iter(range(self.n_instances)).__next__) + if sequence_map_dict is not None: + # Assign all keys from the loaded sequence_map_dict to self._sequence_map. + # Since this is done in the original order, it should generate + # the same set of key:id pairs. We check this with an assert to be sure. + for x, x_id in sequence_map_dict.items(): + x_id_ = sequence_map[x] + assert x_id == x_id_ + return sequence_map + + def calc_squared_encoding_norm(self): + if self.n_instances <= 0: + return None + return (self._autodecoder_codes.weight ** 2).mean() + + def get_encoding_dim(self) -> int: + if self.n_instances <= 0: + return 0 + return self.encoding_dim + + def forward(self, x: Union[torch.LongTensor, List[str]]) -> Optional[torch.Tensor]: + """ + Args: + x: A batch of `N` sequence identifiers. Either a long tensor of size + `(N,)` keys in [0, n_instances), or a list of `N` string keys that + are hashed to codes (without collisions). + + Returns: + codes: A tensor of shape `(N, self.encoding_dim)` containing the + sequence-specific autodecoder codes. + """ + if self.n_instances == 0: + return None + + if self.ignore_input: + x = ["singleton"] + + if isinstance(x[0], str): + try: + x = torch.tensor( + # pyre-ignore[29] + [self._sequence_map[elem] for elem in x], + dtype=torch.long, + device=next(self.parameters()).device, + ) + except StopIteration: + raise ValueError("Not enough n_instances in the autodecoder") + + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + return self._autodecoder_codes(x) + + def _load_sequence_map_hook( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): + """ + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + + Returns: + Constructed sequence_map if it exists in the state_dict + else raises a warning only. + """ + sequence_map_key = prefix + "_sequence_map" + if sequence_map_key in state_dict: + sequence_map_dict = state_dict.pop(sequence_map_key) + self._sequence_map = self._build_sequence_map( + sequence_map_dict=sequence_map_dict + ) + else: + warnings.warn("No sequence map in Autodecoder state dict!") + + +def _save_sequence_map_hook( + self, + state_dict, + prefix, + local_metadata, +) -> None: + """ + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + """ + sequence_map_key = prefix + "_sequence_map" + sequence_map_dict = dict(self._sequence_map.items()) + state_dict[sequence_map_key] = sequence_map_dict diff --git a/pytorch3d/implicitron/models/base.py b/pytorch3d/implicitron/models/base.py new file mode 100644 index 00000000..5ab29783 --- /dev/null +++ b/pytorch3d/implicitron/models/base.py @@ -0,0 +1,883 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import math +import warnings +from dataclasses import field +from typing import Any, Dict, List, Optional, Tuple + +import torch +import tqdm +from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import ( + NewViewSynthesisPrediction, +) +from pytorch3d.implicitron.tools import image_utils, vis_utils +from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation +from pytorch3d.implicitron.tools.rasterize_mc import rasterize_mc_samples +from pytorch3d.implicitron.tools.utils import cat_dataclass +from pytorch3d.renderer import RayBundle, utils as rend_utils +from pytorch3d.renderer.cameras import CamerasBase +from visdom import Visdom + +from .autodecoder import Autodecoder +from .implicit_function.base import ImplicitFunctionBase +from .implicit_function.idr_feature_field import IdrFeatureField # noqa +from .implicit_function.neural_radiance_field import ( # noqa + NeRFormerImplicitFunction, + NeuralRadianceFieldImplicitFunction, +) +from .implicit_function.scene_representation_networks import ( # noqa + SRNHyperNetImplicitFunction, + SRNImplicitFunction, +) +from .metrics import ViewMetrics +from .renderer.base import ( + BaseRenderer, + EvaluationMode, + ImplicitFunctionWrapper, + RendererOutput, + RenderSamplingMode, +) +from .renderer.lstm_renderer import LSTMRenderer # noqa +from .renderer.multipass_ea import MultiPassEmissionAbsorptionRenderer # noqa +from .renderer.ray_sampler import RaySampler +from .renderer.sdf_renderer import SignedDistanceFunctionRenderer # noqa +from .resnet_feature_extractor import ResNetFeatureExtractor +from .view_pooling.feature_aggregation import FeatureAggregatorBase +from .view_pooling.view_sampling import ViewSampler + + +STD_LOG_VARS = ["objective", "epoch", "sec/it"] + + +# pyre-ignore: 13 +class GenericModel(Configurable, torch.nn.Module): + """ + GenericModel is a wrapper for the neural implicit + rendering and reconstruction pipeline which consists + of the following sequence of 7 steps (steps 2–4 are normally + skipped in overfitting scenario, since conditioning on source views + does not add much information; otherwise they should be present altogether): + + + (1) Ray Sampling + ------------------ + Rays are sampled from an image grid based on the target view(s). + │_____________ + │ │ + │ ▼ + │ (2) Feature Extraction (optional) + │ ----------------------- + │ A feature extractor (e.g. a convolutional + │ neural net) is used to extract image features + │ from the source view(s). + │ │ + │ ▼ + │ (3) View Sampling (optional) + │ ------------------ + │ Image features are sampled at the 2D projections + │ of a set of 3D points along each of the sampled + │ target rays from (1). + │ │ + │ ▼ + │ (4) Feature Aggregation (optional) + │ ------------------ + │ Aggregate features and masks sampled from + │ image view(s) in (3). + │ │ + │____________▼ + │ + ▼ + (5) Implicit Function Evaluation + ------------------ + Evaluate the implicit function(s) at the sampled ray points + (optionally pass in the aggregated image features from (4)). + │ + ▼ + (6) Rendering + ------------------ + Render the image into the target cameras by raymarching along + the sampled rays and aggregating the colors and densities + output by the implicit function in (5). + │ + ▼ + (7) Loss Computation + ------------------ + Compute losses based on the predicted target image(s). + + + The `forward` function of GenericModel executes + this sequence of steps. Currently, steps 1, 3, 4, 5, 6 + can be customized by intializing a subclass of the appropriate + baseclass and adding the newly created module to the registry. + Please see https://github.com/fairinternal/pytorch3d/blob/co3d/projects/implicitron_trainer/README.md#custom-plugins + for more details on how to create and register a custom component. + + In the config .yaml files for experiments, the parameters below are + contained in the `generic_model_args` node. As GenericModel + derives from Configurable, the input arguments are + parsed by the run_auto_creation function to initialize the + necessary member modules. Please see implicitron_trainer/README.md + for more details on this process. + + Args: + mask_images: Whether or not to mask the RGB image background given the + foreground mask (the `fg_probability` argument of `GenericModel.forward`) + mask_depths: Whether or not to mask the depth image background given the + foreground mask (the `fg_probability` argument of `GenericModel.forward`) + render_image_width: Width of the output image to render + render_image_height: Height of the output image to render + mask_threshold: If greater than 0.0, the foreground mask is + thresholded by this value before being applied to the RGB/Depth images + output_rasterized_mc: If True, visualize the Monte-Carlo pixel renders by + splatting onto an image grid. Default: False. + bg_color: RGB values for the background color. Default (0.0, 0.0, 0.0) + view_pool: If True, features are sampled from the source image(s) + at the projected 2d locations of the sampled 3d ray points from the target + view(s), i.e. this activates step (3) above. + num_passes: The specified implicit_function is initialized num_passes + times and run sequentially. + chunk_size_grid: The total number of points which can be rendered + per chunk. This is used to compute the number of rays used + per chunk when the chunked version of the renderer is used (in order + to fit rendering on all rays in memory) + render_features_dimensions: The number of output features to render. + Defaults to 3, corresponding to RGB images. + n_train_target_views: The number of cameras to render into at training + time; first `n_train_target_views` in the batch are considered targets, + the rest are sources. + sampling_mode_training: The sampling method to use during training. Must be + a value from the RenderSamplingMode Enum. + sampling_mode_evaluation: Same as above but for evaluation. + sequence_autodecoder: An instance of `Autodecoder`. This is used to generate an encoding + of the image (referred to as the global_code) that can be used to model aspects of + the scene such as multiple objects or morphing objects. It is up to the implicit + function definition how to use it, but the most typical way is to broadcast and + concatenate to the other inputs for the implicit function. + raysampler: An instance of RaySampler which is used to emit + rays from the target view(s). + renderer_class_type: The name of the renderer class which is available in the global + registry. + renderer: A renderer class which inherits from BaseRenderer. This is used to + generate the images from the target view(s). + image_feature_extractor: A module for extrating features from an input image. + view_sampler: An instance of ViewSampler which is used for sampling of + image-based features at the 2D projections of a set + of 3D points. + feature_aggregator_class_type: The name of the feature aggregator class which + is available in the global registry. + feature_aggregator: A feature aggregator class which inherits from + FeatureAggregatorBase. Typically, the aggregated features and their + masks are output by a `ViewSampler` which samples feature tensors extracted + from a set of source images. FeatureAggregator executes step (4) above. + implicit_function_class_type: The type of implicit function to use which + is available in the global registry. + implicit_function: An instance of ImplicitFunctionBase. The actual implicit functions + are initialised to be in self._implicit_functions. + loss_weights: A dictionary with a {loss_name: weight} mapping; see documentation + for `ViewMetrics` class for available loss functions. + log_vars: A list of variable names which should be logged. + The names should correspond to a subset of the keys of the + dict `preds` output by the `forward` function. + """ + + mask_images: bool = True + mask_depths: bool = True + render_image_width: int = 400 + render_image_height: int = 400 + mask_threshold: float = 0.5 + output_rasterized_mc: bool = False + bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0) + view_pool: bool = False + num_passes: int = 1 + chunk_size_grid: int = 4096 + render_features_dimensions: int = 3 + tqdm_trigger_threshold: int = 16 + + n_train_target_views: int = 1 + sampling_mode_training: str = "mask_sample" + sampling_mode_evaluation: str = "full_grid" + + # ---- autodecoder settings + sequence_autodecoder: Autodecoder + + # ---- raysampler + raysampler: RaySampler + + # ---- renderer configs + renderer_class_type: str = "MultiPassEmissionAbsorptionRenderer" + renderer: BaseRenderer + + # ---- view sampling settings - used if view_pool=True + # (This is only created if view_pool is False) + image_feature_extractor: ResNetFeatureExtractor + view_sampler: ViewSampler + # ---- ---- view sampling feature aggregator settings + feature_aggregator_class_type: str = "AngleWeightedReductionFeatureAggregator" + feature_aggregator: FeatureAggregatorBase + + # ---- implicit function settings + implicit_function_class_type: str = "NeuralRadianceFieldImplicitFunction" + # This is just a model, never constructed. + # The actual implicit functions live in self._implicit_functions + implicit_function: ImplicitFunctionBase + + # ---- loss weights + loss_weights: Dict[str, float] = field( + default_factory=lambda: { + "loss_rgb_mse": 1.0, + "loss_prev_stage_rgb_mse": 1.0, + "loss_mask_bce": 0.0, + "loss_prev_stage_mask_bce": 0.0, + } + ) + + # ---- variables to be logged (logger automatically ignores if not computed) + log_vars: List[str] = field( + default_factory=lambda: [ + "loss_rgb_psnr_fg", + "loss_rgb_psnr", + "loss_rgb_mse", + "loss_rgb_huber", + "loss_depth_abs", + "loss_depth_abs_fg", + "loss_mask_neg_iou", + "loss_mask_bce", + "loss_mask_beta_prior", + "loss_eikonal", + "loss_density_tv", + "loss_depth_neg_penalty", + "loss_autodecoder_norm", + # metrics that are only logged in 2+stage renderes + "loss_prev_stage_rgb_mse", + "loss_prev_stage_rgb_psnr_fg", + "loss_prev_stage_rgb_psnr", + "loss_prev_stage_mask_bce", + *STD_LOG_VARS, + ] + ) + + def __post_init__(self): + super().__init__() + self.view_metrics = ViewMetrics() + + self._check_and_preprocess_renderer_configs() + self.raysampler_args["sampling_mode_training"] = self.sampling_mode_training + self.raysampler_args["sampling_mode_evaluation"] = self.sampling_mode_evaluation + self.raysampler_args["image_width"] = self.render_image_width + self.raysampler_args["image_height"] = self.render_image_height + run_auto_creation(self) + + self._implicit_functions = self._construct_implicit_functions() + + self.print_loss_weights() + + def forward( + self, + *, # force keyword-only arguments + image_rgb: Optional[torch.Tensor], + camera: CamerasBase, + fg_probability: Optional[torch.Tensor], + mask_crop: Optional[torch.Tensor], + depth_map: Optional[torch.Tensor], + sequence_name: Optional[List[str]], + evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, + **kwargs, + ) -> Dict[str, Any]: + """ + Args: + image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images; + the first `min(B, n_train_target_views)` images are considered targets and + are used to supervise the renders; the rest corresponding to the source + viewpoints from which features will be extracted. + camera: An instance of CamerasBase containing a batch of `B` cameras corresponding + to the viewpoints of target images, from which the rays will be sampled, + and source images, which will be used for intersecting with target rays. + fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch of + foreground masks. + mask_crop: A binary tensor of shape `(B, 1, H, W)` deonting valid + regions in the input images (i.e. regions that do not correspond + to, e.g., zero-padding). When the `RaySampler`'s sampling mode is set to + "mask_sample", rays will be sampled in the non zero regions. + depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps. + sequence_name: A list of `B` strings corresponding to the sequence names + from which images `image_rgb` were extracted. They are used to match + target frames with relevant source frames. + evaluation_mode: one of EvaluationMode.TRAINING or + EvaluationMode.EVALUATION which determines the settings used for + rendering. + + Returns: + preds: A dictionary containing all outputs of the forward pass including the + rendered images, depths, masks, losses and other metrics. + """ + image_rgb, fg_probability, depth_map = self._preprocess_input( + image_rgb, fg_probability, depth_map + ) + + # Obtain the batch size from the camera as this is the only required input. + batch_size = camera.R.shape[0] + + # Determine the number of target views, i.e. cameras we render into. + n_targets = ( + 1 + if evaluation_mode == EvaluationMode.EVALUATION + else batch_size + if self.n_train_target_views <= 0 + else min(self.n_train_target_views, batch_size) + ) + + # Select the target cameras. + target_cameras = camera[list(range(n_targets))] + + # Determine the used ray sampling mode. + sampling_mode = RenderSamplingMode( + self.sampling_mode_training + if evaluation_mode == EvaluationMode.TRAINING + else self.sampling_mode_evaluation + ) + + # (1) Sample rendering rays with the ray sampler. + ray_bundle: RayBundle = self.raysampler( + target_cameras, + evaluation_mode, + mask=mask_crop[:n_targets] + if mask_crop is not None and sampling_mode == RenderSamplingMode.MASK_SAMPLE + else None, + ) + + # custom_args hold additional arguments to the implicit function. + custom_args = {} + + if self.view_pool: + if sequence_name is None: + raise ValueError("sequence_name must be provided for view pooling") + # (2) Extract features for the image + img_feats = self.image_feature_extractor(image_rgb, fg_probability) + + # (3) Sample features and masks at the ray points + curried_view_sampler = lambda pts: self.view_sampler( # noqa: E731 + pts=pts, + seq_id_pts=sequence_name[:n_targets], + camera=camera, + seq_id_camera=sequence_name, + feats=img_feats, + masks=mask_crop, + ) # returns feats_sampled, masks_sampled + + # (4) Aggregate features from multiple views + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + curried_view_pool = lambda pts: self.feature_aggregator( # noqa: E731 + *curried_view_sampler(pts=pts), + pts=pts, + camera=camera, + ) # TODO: do we need to pass a callback rather than compute here? + # precomputing will be faster for 2 passes + # -> but this is important for non-nerf + custom_args["fun_viewpool"] = curried_view_pool + + global_code = None + if self.sequence_autodecoder.n_instances > 0: + if sequence_name is None: + raise ValueError("sequence_name must be provided for autodecoder.") + global_code = self.sequence_autodecoder(sequence_name[:n_targets]) + custom_args["global_code"] = global_code + + # pyre-fixme[29]: + # `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self, + # torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], torch.Tensor, + # torch.nn.Module]` is not a function. + for func in self._implicit_functions: + func.bind_args(**custom_args) + + object_mask: Optional[torch.Tensor] = None + if fg_probability is not None: + sampled_fb_prob = rend_utils.ndc_grid_sample( + fg_probability[:n_targets], ray_bundle.xys, mode="nearest" + ) + object_mask = sampled_fb_prob > 0.5 + + # (5)-(6) Implicit function evaluation and Rendering + rendered = self._render( + ray_bundle=ray_bundle, + sampling_mode=sampling_mode, + evaluation_mode=evaluation_mode, + implicit_functions=self._implicit_functions, + object_mask=object_mask, + ) + + # Unbind the custom arguments to prevent pytorch from storing + # large buffers of intermediate results due to points in the + # bound arguments. + # pyre-fixme[29]: + # `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self, + # torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], torch.Tensor, + # torch.nn.Module]` is not a function. + for func in self._implicit_functions: + func.unbind_args() + + preds = self._get_view_metrics( + raymarched=rendered, + xys=ray_bundle.xys, + image_rgb=None if image_rgb is None else image_rgb[:n_targets], + depth_map=None if depth_map is None else depth_map[:n_targets], + fg_probability=None + if fg_probability is None + else fg_probability[:n_targets], + mask_crop=None if mask_crop is None else mask_crop[:n_targets], + ) + + if sampling_mode == RenderSamplingMode.MASK_SAMPLE: + if self.output_rasterized_mc: + # Visualize the monte-carlo pixel renders by splatting onto + # an image grid. + ( + preds["images_render"], + preds["depths_render"], + preds["masks_render"], + ) = self._rasterize_mc_samples( + ray_bundle.xys, + rendered.features, + rendered.depths, + masks=rendered.masks, + ) + elif sampling_mode == RenderSamplingMode.FULL_GRID: + preds["images_render"] = rendered.features.permute(0, 3, 1, 2) + preds["depths_render"] = rendered.depths.permute(0, 3, 1, 2) + preds["masks_render"] = rendered.masks.permute(0, 3, 1, 2) + + preds["nvs_prediction"] = NewViewSynthesisPrediction( + image_render=preds["images_render"], + depth_render=preds["depths_render"], + mask_render=preds["masks_render"], + ) + else: + raise AssertionError("Unreachable state") + + # calc the AD penalty, returns None if autodecoder is not active + ad_penalty = self.sequence_autodecoder.calc_squared_encoding_norm() + if ad_penalty is not None: + preds["loss_autodecoder_norm"] = ad_penalty + + # (7) Compute losses + # finally get the optimization objective using self.loss_weights + objective = self._get_objective(preds) + if objective is not None: + preds["objective"] = objective + + return preds + + def _get_objective(self, preds) -> Optional[torch.Tensor]: + """ + A helper function to compute the overall loss as the dot product + of individual loss functions with the corresponding weights. + """ + losses_weighted = [ + preds[k] * float(w) + for k, w in self.loss_weights.items() + if (k in preds and w != 0.0) + ] + if len(losses_weighted) == 0: + warnings.warn("No main objective found.") + return None + loss = sum(losses_weighted) + assert torch.is_tensor(loss) + return loss + + def visualize( + self, + viz: Visdom, + visdom_env_imgs: str, + preds: Dict[str, Any], + prefix: str, + ) -> None: + """ + Helper function to visualize the predictions generated + in the forward pass. + + + Args: + viz: Visdom connection object + visdom_env_imgs: name of visdom environment for the images. + preds: predictions dict like returned by forward() + prefix: prepended to the names of images + """ + if not viz.check_connection(): + print("no visdom server! -> skipping batch vis") + return + + idx_image = 0 + title = f"{prefix}_im{idx_image}" + + vis_utils.visualize_basics(viz, preds, visdom_env_imgs, title=title) + + def _render( + self, + *, + ray_bundle: RayBundle, + object_mask: Optional[torch.Tensor], + sampling_mode: RenderSamplingMode, + **kwargs, + ) -> RendererOutput: + """ + Args: + ray_bundle: A `RayBundle` object containing the parametrizations of the + sampled rendering rays. + object_mask: A tensor of shape `(B, 3, H, W)` denoting the silhouette of the object + in the image. This is required for the SignedDistanceFunctionRenderer. + sampling_mode: The sampling method to use. Must be a value from the + RenderSamplingMode Enum. + Returns: + An instance of RendererOutput + + """ + if sampling_mode == RenderSamplingMode.FULL_GRID and self.chunk_size_grid > 0: + return _apply_chunked( + self.renderer, + _chunk_generator( + self.chunk_size_grid, + ray_bundle, + object_mask, + self.tqdm_trigger_threshold, + **kwargs, + ), + lambda batch: _tensor_collator(batch, ray_bundle.lengths.shape[:-1]), + ) + else: + # pyre-fixme[29]: `BaseRenderer` is not a function. + return self.renderer( + ray_bundle=ray_bundle, + object_mask=object_mask, + **kwargs, + ) + + def _get_viewpooled_feature_dim(self): + return ( + self.feature_aggregator.get_aggregated_feature_dim( + self.image_feature_extractor.get_feat_dims() + ) + if self.view_pool + else 0 + ) + + def _check_and_preprocess_renderer_configs(self): + self.renderer_MultiPassEmissionAbsorptionRenderer_args[ + "stratified_sampling_coarse_training" + ] = self.raysampler_args["stratified_point_sampling_training"] + self.renderer_MultiPassEmissionAbsorptionRenderer_args[ + "stratified_sampling_coarse_evaluation" + ] = self.raysampler_args["stratified_point_sampling_evaluation"] + self.renderer_SignedDistanceFunctionRenderer_args[ + "render_features_dimensions" + ] = self.render_features_dimensions + self.renderer_SignedDistanceFunctionRenderer_args.ray_tracer_args[ + "object_bounding_sphere" + ] = self.raysampler_args["scene_extent"] + + def create_image_feature_extractor(self): + """ + Custom creation function called by run_auto_creation so that the + image_feature_extractor is not created if it is not be needed. + """ + if self.view_pool: + self.image_feature_extractor = ResNetFeatureExtractor( + **self.image_feature_extractor_args + ) + + def create_implicit_function(self) -> None: + """ + No-op called by run_auto_creation so that self.implicit_function + does not get created. __post_init__ creates the implicit function(s) + in wrappers explicitly in self._implicit_functions. + """ + pass + + def _construct_implicit_functions(self): + """ + After run_auto_creation has been called, the arguments + for each of the possible implicit function methods are + available. `GenericModel` arguments are first validated + based on the custom requirements for each specific + implicit function method. Then the required implicit + function(s) are initialized. + """ + # nerf preprocessing + nerf_args = self.implicit_function_NeuralRadianceFieldImplicitFunction_args + nerformer_args = self.implicit_function_NeRFormerImplicitFunction_args + nerf_args["latent_dim"] = nerformer_args["latent_dim"] = ( + self._get_viewpooled_feature_dim() + + self.sequence_autodecoder.get_encoding_dim() + ) + nerf_args["color_dim"] = nerformer_args[ + "color_dim" + ] = self.render_features_dimensions + + # idr preprocessing + idr = self.implicit_function_IdrFeatureField_args + idr["feature_vector_size"] = self.render_features_dimensions + idr["encoding_dim"] = self.sequence_autodecoder.get_encoding_dim() + + # srn preprocessing + srn = self.implicit_function_SRNImplicitFunction_args + srn.raymarch_function_args.latent_dim = ( + self._get_viewpooled_feature_dim() + + self.sequence_autodecoder.get_encoding_dim() + ) + + # srn_hypernet preprocessing + srn_hypernet = self.implicit_function_SRNHyperNetImplicitFunction_args + srn_hypernet_args = srn_hypernet.hypernet_args + srn_hypernet_args.latent_dim_hypernet = ( + self.sequence_autodecoder.get_encoding_dim() + ) + srn_hypernet_args.latent_dim = self._get_viewpooled_feature_dim() + + # check that for srn, srn_hypernet, idr we have self.num_passes=1 + implicit_function_type = registry.get( + ImplicitFunctionBase, self.implicit_function_class_type + ) + if self.num_passes != 1 and not implicit_function_type.allows_multiple_passes(): + raise ValueError( + self.implicit_function_class_type + + f"requires num_passes=1 not {self.num_passes}" + ) + + if implicit_function_type.requires_pooling_without_aggregation(): + has_aggregation = hasattr(self.feature_aggregator, "reduction_functions") + if not self.view_pool or has_aggregation: + raise ValueError( + "Chosen implicit function requires view pooling without aggregation." + ) + config_name = f"implicit_function_{self.implicit_function_class_type}_args" + config = getattr(self, config_name, None) + if config is None: + raise ValueError(f"{config_name} not present") + implicit_functions_list = [ + ImplicitFunctionWrapper(implicit_function_type(**config)) + for _ in range(self.num_passes) + ] + return torch.nn.ModuleList(implicit_functions_list) + + def print_loss_weights(self) -> None: + """ + Print a table of the loss weights. + """ + print("-------\nloss_weights:") + for k, w in self.loss_weights.items(): + print(f"{k:40s}: {w:1.2e}") + print("-------") + + def _preprocess_input( + self, + image_rgb: Optional[torch.Tensor], + fg_probability: Optional[torch.Tensor], + depth_map: Optional[torch.Tensor], + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Helper function to preprocess the input images and optional depth maps + to apply masking if required. + + Args: + image_rgb: A tensor of shape `(B, 3, H, W)` containing a batch of rgb images + corresponding to the source viewpoints from which features will be extracted + fg_probability: A tensor of shape `(B, 1, H, W)` containing a batch + of foreground masks with values in [0, 1]. + depth_map: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps. + + Returns: + Modified image_rgb, fg_mask, depth_map + """ + fg_mask = fg_probability + if fg_mask is not None and self.mask_threshold > 0.0: + # threshold masks + warnings.warn("Thresholding masks!") + fg_mask = (fg_mask >= self.mask_threshold).type_as(fg_mask) + + if self.mask_images and fg_mask is not None and image_rgb is not None: + # mask the image + warnings.warn("Masking images!") + image_rgb = image_utils.mask_background( + image_rgb, fg_mask, dim_color=1, bg_color=torch.tensor(self.bg_color) + ) + + if self.mask_depths and fg_mask is not None and depth_map is not None: + # mask the depths + assert ( + self.mask_threshold > 0.0 + ), "Depths should be masked only with thresholded masks" + warnings.warn("Masking depths!") + depth_map = depth_map * fg_mask + + return image_rgb, fg_mask, depth_map + + def _get_view_metrics( + self, + raymarched: RendererOutput, + xys: torch.Tensor, + image_rgb: Optional[torch.Tensor] = None, + depth_map: Optional[torch.Tensor] = None, + fg_probability: Optional[torch.Tensor] = None, + mask_crop: Optional[torch.Tensor] = None, + keys_prefix: str = "loss_", + ): + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + metrics = self.view_metrics( + image_sampling_grid=xys, + images_pred=raymarched.features, + images=image_rgb, + depths_pred=raymarched.depths, + depths=depth_map, + masks_pred=raymarched.masks, + masks=fg_probability, + masks_crop=mask_crop, + keys_prefix=keys_prefix, + **raymarched.aux, + ) + + if raymarched.prev_stage: + metrics.update( + self._get_view_metrics( + raymarched.prev_stage, + xys, + image_rgb, + depth_map, + fg_probability, + mask_crop, + keys_prefix=(keys_prefix + "prev_stage_"), + ) + ) + + return metrics + + @torch.no_grad() + def _rasterize_mc_samples( + self, + xys: torch.Tensor, + features: torch.Tensor, + depth: torch.Tensor, + masks: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Rasterizes Monte-Carlo features back onto the image. + + Args: + xys: B x ... x 2 2D point locations in PyTorch3D NDC convention + features: B x ... x C tensor containing per-point rendered features. + depth: B x ... x 1 tensor containing per-point rendered depth. + """ + ba = xys.shape[0] + + # Flatten the features and xy locations. + features_depth_ras = torch.cat( + ( + features.reshape(ba, -1, features.shape[-1]), + depth.reshape(ba, -1, 1), + ), + dim=-1, + ) + xys_ras = xys.reshape(ba, -1, 2) + if masks is not None: + masks_ras = masks.reshape(ba, -1, 1) + else: + masks_ras = None + + if min(self.render_image_height, self.render_image_width) <= 0: + raise ValueError( + "Need to specify a positive" + " self.render_image_height and self.render_image_width" + " for MC rasterisation." + ) + + # Estimate the rasterization point radius so that we approximately fill + # the whole image given the number of rasterized points. + pt_radius = 2.0 * math.sqrt(xys.shape[1]) + + # Rasterize the samples. + features_depth_render, masks_render = rasterize_mc_samples( + xys_ras, + features_depth_ras, + (self.render_image_height, self.render_image_width), + radius=pt_radius, + masks=masks_ras, + ) + images_render = features_depth_render[:, :-1] + depths_render = features_depth_render[:, -1:] + return images_render, depths_render, masks_render + + +def _apply_chunked(func, chunk_generator, tensor_collator): + """ + Helper function to apply a function on a sequence of + chunked inputs yielded by a generator and collate + the result. + """ + processed_chunks = [ + func(*chunk_args, **chunk_kwargs) + for chunk_args, chunk_kwargs in chunk_generator + ] + + return cat_dataclass(processed_chunks, tensor_collator) + + +def _tensor_collator(batch, new_dims) -> torch.Tensor: + """ + Helper function to reshape the batch to the desired shape + """ + return torch.cat(batch, dim=1).reshape(*new_dims, -1) + + +def _chunk_generator( + chunk_size: int, + ray_bundle: RayBundle, + object_mask: Optional[torch.Tensor], + tqdm_trigger_threshold: int, + *args, + **kwargs, +): + """ + Helper function which yields chunks of rays from the + input ray_bundle, to be used when the number of rays is + large and will not fit in memory for rendering. + """ + ( + batch_size, + *spatial_dim, + n_pts_per_ray, + ) = ray_bundle.lengths.shape # B x ... x n_pts_per_ray + if n_pts_per_ray > 0 and chunk_size % n_pts_per_ray != 0: + raise ValueError( + f"chunk_size_grid ({chunk_size}) should be divisible " + f"by n_pts_per_ray ({n_pts_per_ray})" + ) + + n_rays = math.prod(spatial_dim) + # special handling for raytracing-based methods + n_chunks = -(-n_rays * max(n_pts_per_ray, 1) // chunk_size) + chunk_size_in_rays = -(-n_rays // n_chunks) + + iter = range(0, n_rays, chunk_size_in_rays) + if len(iter) >= tqdm_trigger_threshold: + iter = tqdm.tqdm(iter) + + for start_idx in iter: + end_idx = min(start_idx + chunk_size_in_rays, n_rays) + ray_bundle_chunk = RayBundle( + origins=ray_bundle.origins.reshape(batch_size, -1, 3)[:, start_idx:end_idx], + directions=ray_bundle.directions.reshape(batch_size, -1, 3)[ + :, start_idx:end_idx + ], + lengths=ray_bundle.lengths.reshape( + batch_size, math.prod(spatial_dim), n_pts_per_ray + )[:, start_idx:end_idx], + xys=ray_bundle.xys.reshape(batch_size, -1, 2)[:, start_idx:end_idx], + ) + extra_args = kwargs.copy() + if object_mask is not None: + extra_args["object_mask"] = object_mask.reshape(batch_size, -1, 1)[ + :, start_idx:end_idx + ] + yield [ray_bundle_chunk, *args], extra_args diff --git a/pytorch3d/implicitron/models/implicit_function/__init__.py b/pytorch3d/implicitron/models/implicit_function/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/pytorch3d/implicitron/models/implicit_function/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/pytorch3d/implicitron/models/implicit_function/base.py b/pytorch3d/implicitron/models/implicit_function/base.py new file mode 100644 index 00000000..742fde16 --- /dev/null +++ b/pytorch3d/implicitron/models/implicit_function/base.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from typing import Optional + +from pytorch3d.implicitron.tools.config import ReplaceableBase +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.renderer.implicit import RayBundle + + +class ImplicitFunctionBase(ABC, ReplaceableBase): + def __init__(self): + super().__init__() + + @abstractmethod + def forward( + self, + ray_bundle: RayBundle, + fun_viewpool=None, + camera: Optional[CamerasBase] = None, + global_code=None, + **kwargs, + ): + raise NotImplementedError() + + @staticmethod + def allows_multiple_passes() -> bool: + """ + Returns True if this implicit function allows + multiple passes. + """ + return False + + @staticmethod + def requires_pooling_without_aggregation() -> bool: + """ + Returns True if this implicit function needs + pooling without aggregation. + """ + return False + + def on_bind_args(self) -> None: + """ + Called when the custom args are fixed in the main model forward pass. + """ + pass diff --git a/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py b/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py new file mode 100644 index 00000000..1cf39275 --- /dev/null +++ b/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py @@ -0,0 +1,142 @@ +# @lint-ignore-every LICENSELINT +# Adapted from https://github.com/lioryariv/idr/blob/main/code/model/ +# implicit_differentiable_renderer.py +# Copyright (c) 2020 Lior Yariv +import math +from typing import Sequence + +import torch +from pytorch3d.implicitron.tools.config import registry +from pytorch3d.renderer.implicit import HarmonicEmbedding +from torch import nn + +from .base import ImplicitFunctionBase + + +@registry.register +class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module): + feature_vector_size: int = 3 + d_in: int = 3 + d_out: int = 1 + dims: Sequence[int] = (512, 512, 512, 512, 512, 512, 512, 512) + geometric_init: bool = True + bias: float = 1.0 + skip_in: Sequence[int] = () + weight_norm: bool = True + n_harmonic_functions_xyz: int = 0 + pooled_feature_dim: int = 0 + encoding_dim: int = 0 + + def __post_init__(self): + super().__init__() + + dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size] + + self.embed_fn = None + if self.n_harmonic_functions_xyz > 0: + self.embed_fn = HarmonicEmbedding( + self.n_harmonic_functions_xyz, append_input=True + ) + dims[0] = self.embed_fn.get_output_dim() + if self.pooled_feature_dim > 0: + dims[0] += self.pooled_feature_dim + if self.encoding_dim > 0: + dims[0] += self.encoding_dim + + self.num_layers = len(dims) + + out_dim = 0 + layers = [] + for layer_idx in range(self.num_layers - 1): + if layer_idx + 1 in self.skip_in: + out_dim = dims[layer_idx + 1] - dims[0] + else: + out_dim = dims[layer_idx + 1] + + lin = nn.Linear(dims[layer_idx], out_dim) + + if self.geometric_init: + if layer_idx == self.num_layers - 2: + torch.nn.init.normal_( + lin.weight, + mean=math.pi ** 0.5 / dims[layer_idx] ** 0.5, + std=0.0001, + ) + torch.nn.init.constant_(lin.bias, -self.bias) + elif self.n_harmonic_functions_xyz > 0 and layer_idx == 0: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.constant_(lin.weight[:, 3:], 0.0) + torch.nn.init.normal_( + lin.weight[:, :3], 0.0, 2 ** 0.5 / out_dim ** 0.5 + ) + elif self.n_harmonic_functions_xyz > 0 and layer_idx in self.skip_in: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, 2 ** 0.5 / out_dim ** 0.5) + torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0) + else: + torch.nn.init.constant_(lin.bias, 0.0) + torch.nn.init.normal_(lin.weight, 0.0, 2 ** 0.5 / out_dim ** 0.5) + + if self.weight_norm: + lin = nn.utils.weight_norm(lin) + + layers.append(lin) + + self.linear_layers = torch.nn.ModuleList(layers) + self.out_dim = out_dim + self.softplus = nn.Softplus(beta=100) + + # pyre-fixme[14]: `forward` overrides method defined in `ImplicitFunctionBase` + # inconsistently. + def forward( + self, + # ray_bundle: RayBundle, + rays_points_world: torch.Tensor, # TODO: unify the APIs + fun_viewpool=None, + global_code=None, + ): + # this field only uses point locations + # rays_points_world = ray_bundle_to_ray_points(ray_bundle) + # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3] + + if rays_points_world.numel() == 0 or ( + self.embed_fn is None and fun_viewpool is None and global_code is None + ): + return torch.tensor( + [], device=rays_points_world.device, dtype=rays_points_world.dtype + ).view(0, self.out_dim) + + embedding = None + if self.embed_fn is not None: + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + embedding = self.embed_fn(rays_points_world) + + if fun_viewpool is not None: + assert rays_points_world.ndim == 2 + pooled_feature = fun_viewpool(rays_points_world[None]) + # TODO: pooled features are 4D! + embedding = torch.cat((embedding, pooled_feature), dim=-1) + + if global_code is not None: + assert embedding.ndim == 2 + assert global_code.shape[0] == 1 # TODO: generalize to batches! + # This will require changing raytracer code + # embedding = embedding[None].expand(global_code.shape[0], *embedding.shape) + embedding = torch.cat( + (embedding, global_code[0, None, :].expand(*embedding.shape[:-1], -1)), + dim=-1, + ) + + x = embedding + for layer_idx in range(self.num_layers - 1): + if layer_idx in self.skip_in: + x = torch.cat([x, embedding], dim=-1) / 2 ** 0.5 + + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + x = self.linear_layers[layer_idx](x) + + if layer_idx < self.num_layers - 2: + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + x = self.softplus(x) + + return x # TODO: unify the APIs diff --git a/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py new file mode 100644 index 00000000..1b283543 --- /dev/null +++ b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py @@ -0,0 +1,542 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import field +from typing import List, Optional + +import torch +from pytorch3d.common.linear_with_repeat import LinearWithRepeat +from pytorch3d.implicitron.tools.config import registry +from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.renderer.implicit import HarmonicEmbedding + +from .base import ImplicitFunctionBase +from .utils import create_embeddings_for_implicit_function + + +class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module): + n_harmonic_functions_xyz: int = 10 + n_harmonic_functions_dir: int = 4 + n_hidden_neurons_dir: int = 128 + latent_dim: int = 0 + input_xyz: bool = True + xyz_ray_dir_in_camera_coords: bool = False + color_dim: int = 3 + """ + Args: + n_harmonic_functions_xyz: The number of harmonic functions + used to form the harmonic embedding of 3D point locations. + n_harmonic_functions_dir: The number of harmonic functions + used to form the harmonic embedding of the ray directions. + n_hidden_neurons_xyz: The number of hidden units in the + fully connected layers of the MLP that accepts the 3D point + locations and outputs the occupancy field with the intermediate + features. + n_hidden_neurons_dir: The number of hidden units in the + fully connected layers of the MLP that accepts the intermediate + features and ray directions and outputs the radiance field + (per-point colors). + n_layers_xyz: The number of layers of the MLP that outputs the + occupancy field. + append_xyz: The list of indices of the skip layers of the occupancy MLP. + """ + + def __post_init__(self): + super().__init__() + # The harmonic embedding layer converts input 3D coordinates + # to a representation that is more suitable for + # processing with a deep neural network. + self.harmonic_embedding_xyz = HarmonicEmbedding( + self.n_harmonic_functions_xyz, append_input=True + ) + self.harmonic_embedding_dir = HarmonicEmbedding( + self.n_harmonic_functions_dir, append_input=True + ) + if not self.input_xyz and self.latent_dim <= 0: + raise ValueError("The latent dimension has to be > 0 if xyz is not input!") + + embedding_dim_dir = self.harmonic_embedding_dir.get_output_dim() + + self.xyz_encoder = self._construct_xyz_encoder( + input_dim=self.get_xyz_embedding_dim() + ) + + self.intermediate_linear = torch.nn.Linear( + self.n_hidden_neurons_xyz, self.n_hidden_neurons_xyz + ) + _xavier_init(self.intermediate_linear) + + self.density_layer = torch.nn.Linear(self.n_hidden_neurons_xyz, 1) + _xavier_init(self.density_layer) + + # Zero the bias of the density layer to avoid + # a completely transparent initialization. + self.density_layer.bias.data[:] = 0.0 # fixme: Sometimes this is not enough + + self.color_layer = torch.nn.Sequential( + LinearWithRepeat( + self.n_hidden_neurons_xyz + embedding_dim_dir, self.n_hidden_neurons_dir + ), + torch.nn.ReLU(True), + torch.nn.Linear(self.n_hidden_neurons_dir, self.color_dim), + torch.nn.Sigmoid(), + ) + + def get_xyz_embedding_dim(self): + return ( + self.harmonic_embedding_xyz.get_output_dim() * int(self.input_xyz) + + self.latent_dim + ) + + def _construct_xyz_encoder(self, input_dim: int): + raise NotImplementedError() + + def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor): + """ + This function takes per-point `features` predicted by `self.xyz_encoder` + and evaluates the color model in order to attach to each + point a 3D vector of its RGB color. + """ + # Normalize the ray_directions to unit l2 norm. + rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1) + # Obtain the harmonic embedding of the normalized ray directions. + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + rays_embedding = self.harmonic_embedding_dir(rays_directions_normed) + + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + return self.color_layer((self.intermediate_linear(features), rays_embedding)) + + @staticmethod + def allows_multiple_passes() -> bool: + """ + Returns True as this implicit function allows + multiple passes. Overridden from ImplicitFunctionBase. + """ + return True + + def forward( + self, + ray_bundle: RayBundle, + fun_viewpool=None, + camera: Optional[CamerasBase] = None, + global_code=None, + **kwargs, + ): + """ + The forward function accepts the parametrizations of + 3D points sampled along projection rays. The forward + pass is responsible for attaching a 3D vector + and a 1D scalar representing the point's + RGB color and opacity respectively. + + Args: + ray_bundle: A RayBundle object containing the following variables: + origins: A tensor of shape `(minibatch, ..., 3)` denoting the + origins of the sampling rays in world coords. + directions: A tensor of shape `(minibatch, ..., 3)` + containing the direction vectors of sampling rays in world coords. + lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)` + containing the lengths at which the rays are sampled. + fun_viewpool: an optional callback with the signature + fun_fiewpool(points) -> pooled_features + where points is a [N_TGT x N x 3] tensor of world coords, + and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor + of the features pooled from the context images. + + Returns: + rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)` + denoting the opacitiy of each ray point. + rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)` + denoting the color of each ray point. + """ + # We first convert the ray parametrizations to world + # coordinates with `ray_bundle_to_ray_points`. + rays_points_world = ray_bundle_to_ray_points(ray_bundle) + # rays_points_world.shape = [minibatch x ... x pts_per_ray x 3] + + embeds = create_embeddings_for_implicit_function( + xyz_world=ray_bundle_to_ray_points(ray_bundle), + # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]` + # for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`. + xyz_embedding_function=self.harmonic_embedding_xyz + if self.input_xyz + else None, + global_code=global_code, + fun_viewpool=fun_viewpool, + xyz_in_camera_coords=self.xyz_ray_dir_in_camera_coords, + camera=camera, + ) + + # embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3] + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + features = self.xyz_encoder(embeds) + # features.shape = [minibatch x ... x self.n_hidden_neurons_xyz] + # NNs operate on the flattenned rays; reshaping to the correct spatial size + # TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape + features = features.reshape(*rays_points_world.shape[:-1], -1) + + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + raw_densities = self.density_layer(features) + # raw_densities.shape = [minibatch x ... x 1] in [0-1] + + if self.xyz_ray_dir_in_camera_coords: + if camera is None: + raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords") + + directions = ray_bundle.directions @ camera.R + else: + directions = ray_bundle.directions + + rays_colors = self._get_colors(features, directions) + # rays_colors.shape = [minibatch x ... x 3] in [0-1] + + return raw_densities, rays_colors, {} + + +@registry.register +class NeuralRadianceFieldImplicitFunction(NeuralRadianceFieldBase): + transformer_dim_down_factor: float = 1.0 + n_hidden_neurons_xyz: int = 256 + n_layers_xyz: int = 8 + append_xyz: List[int] = field(default_factory=lambda: [5]) + + def _construct_xyz_encoder(self, input_dim: int): + return MLPWithInputSkips( + self.n_layers_xyz, + input_dim, + self.n_hidden_neurons_xyz, + input_dim, + self.n_hidden_neurons_xyz, + input_skips=self.append_xyz, + ) + + +@registry.register +class NeRFormerImplicitFunction(NeuralRadianceFieldBase): + transformer_dim_down_factor: float = 2.0 + n_hidden_neurons_xyz: int = 80 + n_layers_xyz: int = 2 + append_xyz: List[int] = field(default_factory=lambda: [1]) + + def _construct_xyz_encoder(self, input_dim: int): + return TransformerWithInputSkips( + self.n_layers_xyz, + input_dim, + self.n_hidden_neurons_xyz, + input_dim, + self.n_hidden_neurons_xyz, + input_skips=self.append_xyz, + dim_down_factor=self.transformer_dim_down_factor, + ) + + @staticmethod + def requires_pooling_without_aggregation() -> bool: + """ + Returns True as this implicit function needs + pooling without aggregation. Overridden from ImplicitFunctionBase. + """ + return True + + +class MLPWithInputSkips(torch.nn.Module): + """ + Implements the multi-layer perceptron architecture of the Neural Radiance Field. + + As such, `MLPWithInputSkips` is a multi layer perceptron consisting + of a sequence of linear layers with ReLU activations. + + Additionally, for a set of predefined layers `input_skips`, the forward pass + appends a skip tensor `z` to the output of the preceding layer. + + Note that this follows the architecture described in the Supplementary + Material (Fig. 7) of [1]. + + References: + [1] Ben Mildenhall and Pratul P. Srinivasan and Matthew Tancik + and Jonathan T. Barron and Ravi Ramamoorthi and Ren Ng: + NeRF: Representing Scenes as Neural Radiance Fields for View + Synthesis, ECCV2020 + """ + + def _make_affine_layer(self, input_dim, hidden_dim): + l1 = torch.nn.Linear(input_dim, hidden_dim * 2) + l2 = torch.nn.Linear(hidden_dim * 2, hidden_dim * 2) + _xavier_init(l1) + _xavier_init(l2) + return torch.nn.Sequential(l1, torch.nn.ReLU(True), l2) + + def _apply_affine_layer(self, layer, x, z): + mu_log_std = layer(z) + mu, log_std = mu_log_std.split(mu_log_std.shape[-1] // 2, dim=-1) + std = torch.nn.functional.softplus(log_std) + return (x - mu) * std + + def __init__( + self, + n_layers: int = 8, + input_dim: int = 39, + output_dim: int = 256, + skip_dim: int = 39, + hidden_dim: int = 256, + input_skips: List[int] = [5], + skip_affine_trans: bool = False, + no_last_relu=False, + ): + """ + Args: + n_layers: The number of linear layers of the MLP. + input_dim: The number of channels of the input tensor. + output_dim: The number of channels of the output. + skip_dim: The number of channels of the tensor `z` appended when + evaluating the skip layers. + hidden_dim: The number of hidden units of the MLP. + input_skips: The list of layer indices at which we append the skip + tensor `z`. + """ + super().__init__() + layers = [] + skip_affine_layers = [] + for layeri in range(n_layers): + dimin = hidden_dim if layeri > 0 else input_dim + dimout = hidden_dim if layeri + 1 < n_layers else output_dim + + if layeri > 0 and layeri in input_skips: + if skip_affine_trans: + skip_affine_layers.append( + self._make_affine_layer(skip_dim, hidden_dim) + ) + else: + dimin = hidden_dim + skip_dim + + linear = torch.nn.Linear(dimin, dimout) + _xavier_init(linear) + layers.append( + torch.nn.Sequential(linear, torch.nn.ReLU(True)) + if not no_last_relu or layeri + 1 < n_layers + else linear + ) + self.mlp = torch.nn.ModuleList(layers) + if skip_affine_trans: + self.skip_affines = torch.nn.ModuleList(skip_affine_layers) + self._input_skips = set(input_skips) + self._skip_affine_trans = skip_affine_trans + + def forward(self, x: torch.Tensor, z: Optional[torch.Tensor] = None): + """ + Args: + x: The input tensor of shape `(..., input_dim)`. + z: The input skip tensor of shape `(..., skip_dim)` which is appended + to layers whose indices are specified by `input_skips`. + Returns: + y: The output tensor of shape `(..., output_dim)`. + """ + y = x + if z is None: + # if the skip tensor is None, we use `x` instead. + z = x + skipi = 0 + for li, layer in enumerate(self.mlp): + if li in self._input_skips: + if self._skip_affine_trans: + y = self._apply_affine_layer(self.skip_affines[skipi], y, z) + else: + y = torch.cat((y, z), dim=-1) + skipi += 1 + y = layer(y) + return y + + +class TransformerWithInputSkips(torch.nn.Module): + def __init__( + self, + n_layers: int = 8, + input_dim: int = 39, + output_dim: int = 256, + skip_dim: int = 39, + hidden_dim: int = 64, + input_skips: List[int] = [5], + dim_down_factor: float = 1, + ): + """ + Args: + n_layers: The number of linear layers of the MLP. + input_dim: The number of channels of the input tensor. + output_dim: The number of channels of the output. + skip_dim: The number of channels of the tensor `z` appended when + evaluating the skip layers. + hidden_dim: The number of hidden units of the MLP. + input_skips: The list of layer indices at which we append the skip + tensor `z`. + """ + super().__init__() + + self.first = torch.nn.Linear(input_dim, hidden_dim) + _xavier_init(self.first) + + self.skip_linear = torch.nn.ModuleList() + + layers_pool, layers_ray = [], [] + dimout = 0 + for layeri in range(n_layers): + dimin = int(round(hidden_dim / (dim_down_factor ** layeri))) + dimout = int(round(hidden_dim / (dim_down_factor ** (layeri + 1)))) + print(f"Tr: {dimin} -> {dimout}") + for _i, l in enumerate((layers_pool, layers_ray)): + l.append( + TransformerEncoderLayer( + d_model=[dimin, dimout][_i], + nhead=4, + dim_feedforward=hidden_dim, + dropout=0.0, + d_model_out=dimout, + ) + ) + + if layeri in input_skips: + self.skip_linear.append(torch.nn.Linear(input_dim, dimin)) + + self.last = torch.nn.Linear(dimout, output_dim) + _xavier_init(self.last) + + self.layers_pool, self.layers_ray = ( + torch.nn.ModuleList(layers_pool), + torch.nn.ModuleList(layers_ray), + ) + self._input_skips = set(input_skips) + + def forward( + self, + x: torch.Tensor, + z: Optional[torch.Tensor] = None, + ): + """ + Args: + x: The input tensor of shape + `(minibatch, n_pooled_feats, ..., n_ray_pts, input_dim)`. + z: The input skip tensor of shape + `(minibatch, n_pooled_feats, ..., n_ray_pts, skip_dim)` + which is appended to layers whose indices are specified by `input_skips`. + Returns: + y: The output tensor of shape + `(minibatch, 1, ..., n_ray_pts, input_dim)`. + """ + + if z is None: + # if the skip tensor is None, we use `x` instead. + z = x + + y = self.first(x) + + B, n_pool, n_rays, n_pts, dim = y.shape + + # y_p in n_pool, n_pts, B x n_rays x dim + y_p = y.permute(1, 3, 0, 2, 4) + + skipi = 0 + dimh = dim + for li, (layer_pool, layer_ray) in enumerate( + zip(self.layers_pool, self.layers_ray) + ): + y_pool_attn = y_p.reshape(n_pool, n_pts * B * n_rays, dimh) + if li in self._input_skips: + z_skip = self.skip_linear[skipi](z) + y_pool_attn = y_pool_attn + z_skip.permute(1, 3, 0, 2, 4).reshape( + n_pool, n_pts * B * n_rays, dimh + ) + skipi += 1 + # n_pool x B*n_rays*n_pts x dim + y_pool_attn, pool_attn = layer_pool(y_pool_attn, src_key_padding_mask=None) + dimh = y_pool_attn.shape[-1] + + y_ray_attn = ( + y_pool_attn.view(n_pool, n_pts, B * n_rays, dimh) + .permute(1, 0, 2, 3) + .reshape(n_pts, n_pool * B * n_rays, dimh) + ) + # n_pts x n_pool*B*n_rays x dim + y_ray_attn, ray_attn = layer_ray( + y_ray_attn, + src_key_padding_mask=None, + ) + + y_p = y_ray_attn.view(n_pts, n_pool, B * n_rays, dimh).permute(1, 0, 2, 3) + + y = y_p.view(n_pool, n_pts, B, n_rays, dimh).permute(2, 0, 3, 1, 4) + + W = torch.softmax(y[..., :1], dim=1) + y = (y * W).sum(dim=1) + y = self.last(y) + + return y + + +class TransformerEncoderLayer(torch.nn.Module): + r"""TransformerEncoderLayer is made up of self-attn and feedforward network. + This standard encoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of intermediate layer, relu or gelu (default=relu). + + Examples:: + >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, d_model, nhead, dim_feedforward=2048, dropout=0.1, d_model_out=-1 + ): + super(TransformerEncoderLayer, self).__init__() + self.self_attn = torch.nn.MultiheadAttention(d_model, nhead, dropout=dropout) + # Implementation of Feedforward model + self.linear1 = torch.nn.Linear(d_model, dim_feedforward) + self.dropout = torch.nn.Dropout(dropout) + d_model_out = d_model if d_model_out <= 0 else d_model_out + self.linear2 = torch.nn.Linear(dim_feedforward, d_model_out) + self.norm1 = torch.nn.LayerNorm(d_model) + self.norm2 = torch.nn.LayerNorm(d_model_out) + self.dropout1 = torch.nn.Dropout(dropout) + self.dropout2 = torch.nn.Dropout(dropout) + + self.activation = torch.nn.functional.relu + + def forward(self, src, src_mask=None, src_key_padding_mask=None): + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + src2, attn = self.self_attn( + src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask + ) + src = src + self.dropout1(src2) + src = self.norm1(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + d_out = src2.shape[-1] + src = src[..., :d_out] + self.dropout2(src2)[..., :d_out] + src = self.norm2(src) + return src, attn + + +def _xavier_init(linear) -> None: + """ + Performs the Xavier weight initialization of the linear layer `linear`. + """ + torch.nn.init.xavier_uniform_(linear.weight.data) diff --git a/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py b/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py new file mode 100644 index 00000000..ce3389e4 --- /dev/null +++ b/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py @@ -0,0 +1,411 @@ +# @lint-ignore-every LICENSELINT +# Adapted from https://github.com/vsitzmann/scene-representation-networks +# Copyright (c) 2019 Vincent Sitzmann +from typing import Any, Optional, Tuple, cast + +import torch +from pytorch3d.common.linear_with_repeat import LinearWithRepeat +from pytorch3d.implicitron.third_party import hyperlayers, pytorch_prototyping +from pytorch3d.implicitron.tools.config import Configurable, registry, run_auto_creation +from pytorch3d.renderer import RayBundle, ray_bundle_to_ray_points +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.renderer.implicit import HarmonicEmbedding + +from .base import ImplicitFunctionBase +from .utils import create_embeddings_for_implicit_function + + +def _kaiming_normal_init(module: torch.nn.Module) -> None: + if isinstance(module, (torch.nn.Linear, LinearWithRepeat)): + torch.nn.init.kaiming_normal_( + module.weight, a=0.0, nonlinearity="relu", mode="fan_in" + ) + + +class SRNRaymarchFunction(Configurable, torch.nn.Module): + n_harmonic_functions: int = 3 # 0 means raw 3D coord inputs + n_hidden_units: int = 256 + n_layers: int = 2 + in_features: int = 3 + out_features: int = 256 + latent_dim: int = 0 + xyz_in_camera_coords: bool = False + + # The internal network can be set as an output of an SRNHyperNet. + # Note that, in order to avoid Pytorch's automatic registering of the + # raymarch_function module on construction, we input the network wrapped + # as a 1-tuple. + + # raymarch_function should ideally be typed as Optional[Tuple[Callable]] + # but Omegaconf.structured doesn't like that. TODO: revisit after new + # release of omegaconf including https://github.com/omry/omegaconf/pull/749 . + raymarch_function: Any = None + + def __post_init__(self): + super().__init__() + self._harmonic_embedding = HarmonicEmbedding( + self.n_harmonic_functions, append_input=True + ) + input_embedding_dim = ( + HarmonicEmbedding.get_output_dim_static( + self.in_features, + self.n_harmonic_functions, + True, + ) + + self.latent_dim + ) + + if self.raymarch_function is not None: + self._net = self.raymarch_function[0] + else: + self._net = pytorch_prototyping.FCBlock( + hidden_ch=self.n_hidden_units, + num_hidden_layers=self.n_layers, + in_features=input_embedding_dim, + out_features=self.out_features, + ) + + def forward( + self, + ray_bundle: RayBundle, + fun_viewpool=None, + camera: Optional[CamerasBase] = None, + global_code=None, + **kwargs, + ): + """ + Args: + ray_bundle: A RayBundle object containing the following variables: + origins: A tensor of shape `(minibatch, ..., 3)` denoting the + origins of the sampling rays in world coords. + directions: A tensor of shape `(minibatch, ..., 3)` + containing the direction vectors of sampling rays in world coords. + lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)` + containing the lengths at which the rays are sampled. + fun_viewpool: an optional callback with the signature + fun_fiewpool(points) -> pooled_features + where points is a [N_TGT x N x 3] tensor of world coords, + and pooled_features is a [N_TGT x ... x N_SRC x latent_dim] tensor + of the features pooled from the context images. + + Returns: + rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)` + denoting the opacitiy of each ray point. + rays_colors: Set to None. + """ + # We first convert the ray parametrizations to world + # coordinates with `ray_bundle_to_ray_points`. + rays_points_world = ray_bundle_to_ray_points(ray_bundle) + + embeds = create_embeddings_for_implicit_function( + xyz_world=ray_bundle_to_ray_points(ray_bundle), + # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]` + # for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`. + xyz_embedding_function=self._harmonic_embedding, + global_code=global_code, + fun_viewpool=fun_viewpool, + xyz_in_camera_coords=self.xyz_in_camera_coords, + camera=camera, + ) + + # Before running the network, we have to resize embeds to ndims=3, + # otherwise the SRN layers consume huge amounts of memory. + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + raymarch_features = self._net( + embeds.view(embeds.shape[0], -1, embeds.shape[-1]) + ) + # raymarch_features.shape = [minibatch x ... x self.n_hidden_neurons_xyz] + + # NNs operate on the flattenned rays; reshaping to the correct spatial size + raymarch_features = raymarch_features.reshape(*rays_points_world.shape[:-1], -1) + + return raymarch_features, None + + +class SRNPixelGenerator(Configurable, torch.nn.Module): + n_harmonic_functions: int = 4 + n_hidden_units: int = 256 + n_hidden_units_color: int = 128 + n_layers: int = 2 + in_features: int = 256 + out_features: int = 3 + ray_dir_in_camera_coords: bool = False + + def __post_init__(self): + super().__init__() + self._harmonic_embedding = HarmonicEmbedding( + self.n_harmonic_functions, append_input=True + ) + self._net = pytorch_prototyping.FCBlock( + hidden_ch=self.n_hidden_units, + num_hidden_layers=self.n_layers, + in_features=self.in_features, + out_features=self.n_hidden_units, + ) + self._density_layer = torch.nn.Linear(self.n_hidden_units, 1) + self._density_layer.apply(_kaiming_normal_init) + embedding_dim_dir = self._harmonic_embedding.get_output_dim(input_dims=3) + self._color_layer = torch.nn.Sequential( + LinearWithRepeat( + self.n_hidden_units + embedding_dim_dir, + self.n_hidden_units_color, + ), + torch.nn.LayerNorm([self.n_hidden_units_color]), + torch.nn.ReLU(inplace=True), + torch.nn.Linear(self.n_hidden_units_color, self.out_features), + ) + self._color_layer.apply(_kaiming_normal_init) + + # TODO: merge with NeuralRadianceFieldBase's _get_colors + def _get_colors(self, features: torch.Tensor, rays_directions: torch.Tensor): + """ + This function takes per-point `features` predicted by `self.net` + and evaluates the color model in order to attach to each + point a 3D vector of its RGB color. + """ + # Normalize the ray_directions to unit l2 norm. + rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1) + # Obtain the harmonic embedding of the normalized ray directions. + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + rays_embedding = self._harmonic_embedding(rays_directions_normed) + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + return self._color_layer((features, rays_embedding)) + + def forward( + self, + raymarch_features: torch.Tensor, + ray_bundle: RayBundle, + camera: Optional[CamerasBase] = None, + **kwargs, + ): + """ + Args: + raymarch_features: Features from the raymarching network of shape + `(minibatch, ..., self.in_features)` + ray_bundle: A RayBundle object containing the following variables: + origins: A tensor of shape `(minibatch, ..., 3)` denoting the + origins of the sampling rays in world coords. + directions: A tensor of shape `(minibatch, ..., 3)` + containing the direction vectors of sampling rays in world coords. + lengths: A tensor of shape `(minibatch, ..., num_points_per_ray)` + containing the lengths at which the rays are sampled. + + Returns: + rays_densities: A tensor of shape `(minibatch, ..., num_points_per_ray, 1)` + denoting the opacitiy of each ray point. + rays_colors: A tensor of shape `(minibatch, ..., num_points_per_ray, 3)` + denoting the color of each ray point. + """ + # raymarch_features.shape = [minibatch x ... x pts_per_ray x 3] + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + features = self._net(raymarch_features) + # features.shape = [minibatch x ... x self.n_hidden_units] + + if self.ray_dir_in_camera_coords: + if camera is None: + raise ValueError("Camera must be given if xyz_ray_dir_in_camera_coords") + + directions = ray_bundle.directions @ camera.R + else: + directions = ray_bundle.directions + + # NNs operate on the flattenned rays; reshaping to the correct spatial size + features = features.reshape(*raymarch_features.shape[:-1], -1) + + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + raw_densities = self._density_layer(features) + + rays_colors = self._get_colors(features, directions) + + return raw_densities, rays_colors + + +class SRNRaymarchHyperNet(Configurable, torch.nn.Module): + """ + This is a raymarching function which has a forward like SRNRaymarchFunction + but instead of the weights being parameters of the module, they + are the output of another network, the hypernet, which takes the global_code + as input. All the dataclass members of SRNRaymarchFunction are here with the + same meaning. In addition, there are members with names ending `_hypernet` + which affect the hypernet. + + Because this class may be called repeatedly for the same global_code, the + output of the hypernet is cached in self.cached_srn_raymarch_function. + This member must be manually set to None whenever the global_code changes. + """ + + n_harmonic_functions: int = 3 # 0 means raw 3D coord inputs + n_hidden_units: int = 256 + n_layers: int = 2 + n_hidden_units_hypernet: int = 256 + n_layers_hypernet: int = 1 + in_features: int = 3 + out_features: int = 256 + latent_dim_hypernet: int = 0 + latent_dim: int = 0 + xyz_in_camera_coords: bool = False + + def __post_init__(self): + super().__init__() + raymarch_input_embedding_dim = ( + HarmonicEmbedding.get_output_dim_static( + self.in_features, + self.n_harmonic_functions, + True, + ) + + self.latent_dim + ) + + self._hypernet = hyperlayers.HyperFC( + hyper_in_ch=self.latent_dim_hypernet, + hyper_num_hidden_layers=self.n_layers_hypernet, + hyper_hidden_ch=self.n_hidden_units_hypernet, + hidden_ch=self.n_hidden_units, + num_hidden_layers=self.n_layers, + in_ch=raymarch_input_embedding_dim, + out_ch=self.n_hidden_units, + ) + + self.cached_srn_raymarch_function: Optional[Tuple[SRNRaymarchFunction]] = None + + def _run_hypernet(self, global_code: torch.Tensor) -> Tuple[SRNRaymarchFunction]: + """ + Runs the hypernet and returns a 1-tuple containing the generated + srn_raymarch_function. + """ + + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + net = self._hypernet(global_code) + + # use the hyper-net generated network to instantiate the raymarch module + srn_raymarch_function = SRNRaymarchFunction( + n_harmonic_functions=self.n_harmonic_functions, + n_hidden_units=self.n_hidden_units, + n_layers=self.n_layers, + in_features=self.in_features, + out_features=self.out_features, + latent_dim=self.latent_dim, + xyz_in_camera_coords=self.xyz_in_camera_coords, + raymarch_function=(net,), + ) + + # move the generated raymarch function to the correct device + srn_raymarch_function.to(global_code.device) + + return (srn_raymarch_function,) + + def forward( + self, + ray_bundle: RayBundle, + fun_viewpool=None, + camera: Optional[CamerasBase] = None, + global_code=None, + **kwargs, + ): + + if global_code is None: + raise ValueError("SRN Hypernetwork requires a non-trivial global code.") + + # The raymarching network is cached in case the function is called repeatedly + # across LSTM iterations for the same global_code. + if self.cached_srn_raymarch_function is None: + # generate the raymarching network from the hypernet + # pyre-fixme[16]: `SRNRaymarchHyperNet` has no attribute + self.cached_srn_raymarch_function = self._run_hypernet(global_code) + (srn_raymarch_function,) = cast( + Tuple[SRNRaymarchFunction], self.cached_srn_raymarch_function + ) + + return srn_raymarch_function( + ray_bundle=ray_bundle, + fun_viewpool=fun_viewpool, + camera=camera, + global_code=None, # the hypernetwork takes the global code + ) + + +@registry.register +# pyre-fixme[13]: Uninitialized attribute +class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module): + raymarch_function: SRNRaymarchFunction + pixel_generator: SRNPixelGenerator + + def __post_init__(self): + super().__init__() + run_auto_creation(self) + + def forward( + self, + ray_bundle: RayBundle, + fun_viewpool=None, + camera: Optional[CamerasBase] = None, + global_code=None, + raymarch_features: Optional[torch.Tensor] = None, + **kwargs, + ): + predict_colors = raymarch_features is not None + if predict_colors: + return self.pixel_generator( + raymarch_features=raymarch_features, + ray_bundle=ray_bundle, + camera=camera, + **kwargs, + ) + else: + return self.raymarch_function( + ray_bundle=ray_bundle, + fun_viewpool=fun_viewpool, + camera=camera, + global_code=global_code, + **kwargs, + ) + + +@registry.register +# pyre-fixme[13]: Uninitialized attribute +class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module): + """ + This implicit function uses a hypernetwork to generate the + SRNRaymarchingFunction, and this is cached. Whenever the + global_code changes, `on_bind_args` must be called to clear + the cache. + """ + + hypernet: SRNRaymarchHyperNet + pixel_generator: SRNPixelGenerator + + def __post_init__(self): + super().__init__() + run_auto_creation(self) + + def forward( + self, + ray_bundle: RayBundle, + fun_viewpool=None, + camera: Optional[CamerasBase] = None, + global_code=None, + raymarch_features: Optional[torch.Tensor] = None, + **kwargs, + ): + predict_colors = raymarch_features is not None + if predict_colors: + return self.pixel_generator( + raymarch_features=raymarch_features, + ray_bundle=ray_bundle, + camera=camera, + **kwargs, + ) + else: + return self.hypernet( + ray_bundle=ray_bundle, + fun_viewpool=fun_viewpool, + camera=camera, + global_code=global_code, + **kwargs, + ) + + def on_bind_args(self): + """ + The global_code may have changed, so we reset the hypernet. + """ + self.hypernet.cached_srn_raymarch_function = None diff --git a/pytorch3d/implicitron/models/implicit_function/utils.py b/pytorch3d/implicitron/models/implicit_function/utils.py new file mode 100644 index 00000000..f4d440e5 --- /dev/null +++ b/pytorch3d/implicitron/models/implicit_function/utils.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Callable, Optional + +import torch +from pytorch3d.renderer.cameras import CamerasBase + + +def broadcast_global_code(embeds: torch.Tensor, global_code: torch.Tensor): + """ + Expands the `global_code` of shape (minibatch, dim) + so that it can be appended to `embeds` of shape (minibatch, ..., dim2), + and appends to the last dimension of `embeds`. + """ + bs = embeds.shape[0] + global_code_broadcast = global_code.view(bs, *([1] * (embeds.ndim - 2)), -1).expand( + *embeds.shape[:-1], + global_code.shape[-1], + ) + return torch.cat([embeds, global_code_broadcast], dim=-1) + + +def create_embeddings_for_implicit_function( + xyz_world: torch.Tensor, + xyz_in_camera_coords: bool, + global_code: Optional[torch.Tensor], + camera: Optional[CamerasBase], + fun_viewpool: Optional[Callable], + xyz_embedding_function: Optional[Callable], +) -> torch.Tensor: + + bs, *spatial_size, pts_per_ray, _ = xyz_world.shape + + if xyz_in_camera_coords: + if camera is None: + raise ValueError("Camera must be given if xyz_in_camera_coords") + + ray_points_for_embed = ( + camera.get_world_to_view_transform() + .transform_points(xyz_world.view(bs, -1, 3)) + .view(xyz_world.shape) + ) + else: + ray_points_for_embed = xyz_world + + if xyz_embedding_function is None: + embeds = torch.empty( + bs, + 1, + math.prod(spatial_size), + pts_per_ray, + 0, + dtype=xyz_world.dtype, + device=xyz_world.device, + ) + else: + embeds = xyz_embedding_function(ray_points_for_embed).reshape( + bs, + 1, + math.prod(spatial_size), + pts_per_ray, + -1, + ) # flatten spatial, add n_src dim + + if fun_viewpool is not None: + # viewpooling + embeds_viewpooled = fun_viewpool(xyz_world.reshape(bs, -1, 3)) + embed_shape = ( + bs, + embeds_viewpooled.shape[1], + math.prod(spatial_size), + pts_per_ray, + -1, + ) + embeds_viewpooled = embeds_viewpooled.reshape(*embed_shape) + if embeds is not None: + embeds = torch.cat([embeds.expand(*embed_shape), embeds_viewpooled], dim=-1) + else: + embeds = embeds_viewpooled + + if global_code is not None: + # append the broadcasted global code to embeds + embeds = broadcast_global_code(embeds, global_code) + + return embeds diff --git a/pytorch3d/implicitron/models/metrics.py b/pytorch3d/implicitron/models/metrics.py new file mode 100644 index 00000000..2d8a978c --- /dev/null +++ b/pytorch3d/implicitron/models/metrics.py @@ -0,0 +1,230 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import warnings +from typing import Dict, Optional + +import torch +from pytorch3d.implicitron.tools import metric_utils as utils +from pytorch3d.renderer import utils as rend_utils + + +class ViewMetrics(torch.nn.Module): + def forward( + self, + image_sampling_grid: torch.Tensor, + images: Optional[torch.Tensor] = None, + images_pred: Optional[torch.Tensor] = None, + depths: Optional[torch.Tensor] = None, + depths_pred: Optional[torch.Tensor] = None, + masks: Optional[torch.Tensor] = None, + masks_pred: Optional[torch.Tensor] = None, + masks_crop: Optional[torch.Tensor] = None, + grad_theta: Optional[torch.Tensor] = None, + density_grid: Optional[torch.Tensor] = None, + keys_prefix: str = "loss_", + mask_renders_by_pred: bool = False, + ) -> Dict[str, torch.Tensor]: + """ + Calculates various differentiable metrics useful for supervising + differentiable rendering pipelines. + + Args: + image_sampling_grid: A tensor of shape `(B, ..., 2)` containing 2D + image locations at which the predictions are defined. + All ground truth inputs are sampled at these + locations in order to extract values that correspond + to the predictions. + images: A tensor of shape `(B, H, W, 3)` containing ground truth + rgb values. + images_pred: A tensor of shape `(B, ..., 3)` containing predicted + rgb values. + depths: A tensor of shape `(B, Hd, Wd, 1)` containing ground truth + depth values. + depths_pred: A tensor of shape `(B, ..., 1)` containing predicted + depth values. + masks: A tensor of shape `(B, Hm, Wm, 1)` containing ground truth + foreground masks. + masks_pred: A tensor of shape `(B, ..., 1)` containing predicted + foreground masks. + grad_theta: A tensor of shape `(B, ..., 3)` containing an evaluation + of a gradient of a signed distance function w.r.t. + input 3D coordinates used to compute the eikonal loss. + density_grid: A tensor of shape `(B, Hg, Wg, Dg, 1)` containing a + `Hg x Wg x Dg` voxel grid of density values. + keys_prefix: A common prefix for all keys in the output dictionary + containing all metrics. + mask_renders_by_pred: If `True`, masks rendered images by the predicted + `masks_pred` prior to computing all rgb metrics. + + Returns: + metrics: A dictionary `{metric_name_i: metric_value_i}` keyed by the + names of the output metrics `metric_name_i` with their corresponding + values `metric_value_i` represented as 0-dimensional float tensors. + + The calculated metrics are: + rgb_huber: A robust huber loss between `image_pred` and `image`. + rgb_mse: Mean squared error between `image_pred` and `image`. + rgb_psnr: Peak signal-to-noise ratio between `image_pred` and `image`. + rgb_psnr_fg: Peak signal-to-noise ratio between the foreground + region of `image_pred` and `image` as defined by `mask`. + rgb_mse_fg: Mean squared error between the foreground + region of `image_pred` and `image` as defined by `mask`. + mask_neg_iou: (1 - intersection-over-union) between `mask_pred` + and `mask`. + mask_bce: Binary cross entropy between `mask_pred` and `mask`. + mask_beta_prior: A loss enforcing strictly binary values + of `mask_pred`: `log(mask_pred) + log(1-mask_pred)` + depth_abs: Mean per-pixel L1 distance between + `depth_pred` and `depth`. + depth_abs_fg: Mean per-pixel L1 distance between the foreground + region of `depth_pred` and `depth` as defined by `mask`. + eikonal: Eikonal regularizer `(||grad_theta|| - 1)**2`. + density_tv: The Total Variation regularizer of density + values in `density_grid` (sum of L1 distances of values + of all 4-neighbouring cells). + depth_neg_penalty: `min(depth_pred, 0)**2` penalizing negative + predicted depth values. + """ + + # TODO: extract functions + + # reshape from B x ... x DIM to B x DIM x -1 x 1 + images_pred, masks_pred, depths_pred = [ + _reshape_nongrid_var(x) for x in [images_pred, masks_pred, depths_pred] + ] + # reshape the sampling grid as well + # TODO: we can get rid of the singular dimension here and in _reshape_nongrid_var + # now that we use rend_utils.ndc_grid_sample + image_sampling_grid = image_sampling_grid.reshape( + image_sampling_grid.shape[0], -1, 1, 2 + ) + + # closure with the given image_sampling_grid + def sample(tensor, mode): + if tensor is None: + return tensor + return rend_utils.ndc_grid_sample(tensor, image_sampling_grid, mode=mode) + + # eval all results in this size + images = sample(images, mode="bilinear") + depths = sample(depths, mode="nearest") + masks = sample(masks, mode="nearest") + masks_crop = sample(masks_crop, mode="nearest") + if masks_crop is None and images_pred is not None: + masks_crop = torch.ones_like(images_pred[:, :1]) + if masks_crop is None and depths_pred is not None: + masks_crop = torch.ones_like(depths_pred[:, :1]) + + preds = {} + if images is not None and images_pred is not None: + # TODO: mask_renders_by_pred is always false; simplify + preds.update( + _rgb_metrics( + images, + images_pred, + masks, + masks_pred, + masks_crop, + mask_renders_by_pred, + ) + ) + + if masks_pred is not None: + preds["mask_beta_prior"] = utils.beta_prior(masks_pred) + if masks is not None and masks_pred is not None: + preds["mask_neg_iou"] = utils.neg_iou_loss( + masks_pred, masks, mask=masks_crop + ) + preds["mask_bce"] = utils.calc_bce(masks_pred, masks, mask=masks_crop) + + if depths is not None and depths_pred is not None: + assert masks_crop is not None + _, abs_ = utils.eval_depth( + depths_pred, depths, get_best_scale=True, mask=masks_crop, crop=0 + ) + preds["depth_abs"] = abs_.mean() + + if masks is not None: + mask = masks * masks_crop + _, abs_ = utils.eval_depth( + depths_pred, depths, get_best_scale=True, mask=mask, crop=0 + ) + preds["depth_abs_fg"] = abs_.mean() + + # regularizers + if grad_theta is not None: + preds["eikonal"] = _get_eikonal_loss(grad_theta) + + if density_grid is not None: + preds["density_tv"] = _get_grid_tv_loss(density_grid) + + if depths_pred is not None: + preds["depth_neg_penalty"] = _get_depth_neg_penalty_loss(depths_pred) + + if keys_prefix is not None: + preds = {(keys_prefix + k): v for k, v in preds.items()} + + return preds + + +def _rgb_metrics( + images, images_pred, masks, masks_pred, masks_crop, mask_renders_by_pred +): + assert masks_crop is not None + if mask_renders_by_pred: + images = images[..., masks_pred.reshape(-1), :] + masks_crop = masks_crop[..., masks_pred.reshape(-1), :] + masks = masks is not None and masks[..., masks_pred.reshape(-1), :] + rgb_squared = ((images_pred - images) ** 2).mean(dim=1, keepdim=True) + rgb_loss = utils.huber(rgb_squared, scaling=0.03) + crop_mass = masks_crop.sum().clamp(1.0) + # print("IMAGE:", images.mean().item(), images_pred.mean().item()) # TEMP + preds = { + "rgb_huber": (rgb_loss * masks_crop).sum() / crop_mass, + "rgb_mse": (rgb_squared * masks_crop).sum() / crop_mass, + "rgb_psnr": utils.calc_psnr(images_pred, images, mask=masks_crop), + } + if masks is not None: + masks = masks_crop * masks + preds["rgb_psnr_fg"] = utils.calc_psnr(images_pred, images, mask=masks) + preds["rgb_mse_fg"] = (rgb_squared * masks).sum() / masks.sum().clamp(1.0) + return preds + + +def _get_eikonal_loss(grad_theta): + return ((grad_theta.norm(2, dim=1) - 1) ** 2).mean() + + +def _get_grid_tv_loss(grid, log_domain: bool = True, eps: float = 1e-5): + if log_domain: + if (grid <= -eps).any(): + warnings.warn("Grid has negative values; this will produce NaN loss") + grid = torch.log(grid + eps) + + # this is an isotropic version, note that it ignores last rows/cols + return torch.mean( + utils.safe_sqrt( + (grid[..., :-1, :-1, 1:] - grid[..., :-1, :-1, :-1]) ** 2 + + (grid[..., :-1, 1:, :-1] - grid[..., :-1, :-1, :-1]) ** 2 + + (grid[..., 1:, :-1, :-1] - grid[..., :-1, :-1, :-1]) ** 2, + eps=1e-5, + ) + ) + + +def _get_depth_neg_penalty_loss(depth): + neg_penalty = depth.clamp(min=None, max=0.0) ** 2 + return torch.mean(neg_penalty) + + +def _reshape_nongrid_var(x): + if x is None: + return None + + ba, *_, dim = x.shape + return x.reshape(ba, -1, 1, dim).permute(0, 3, 1, 2).contiguous() diff --git a/pytorch3d/implicitron/models/model_dbir.py b/pytorch3d/implicitron/models/model_dbir.py new file mode 100644 index 00000000..7f3031dc --- /dev/null +++ b/pytorch3d/implicitron/models/model_dbir.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Any, Dict, List + +import torch +from pytorch3d.implicitron.dataset.utils import is_known_frame +from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import ( + NewViewSynthesisPrediction, +) +from pytorch3d.implicitron.tools.point_cloud_utils import ( + get_rgbd_point_cloud, + render_point_cloud_pytorch3d, +) +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.structures import Pointclouds + + +class ModelDBIR(torch.nn.Module): + """ + A simple depth-based image rendering model. + """ + + def __init__( + self, + image_size: int = 256, + bg_color: float = 0.0, + max_points: int = -1, + ): + """ + Initializes a simple DBIR model. + + Args: + image_size: The size of the rendered rectangular images. + bg_color: The color of the background. + max_points: Maximum number of points in the point cloud + formed by unprojecting all source view depths. + If more points are present, they are randomly subsampled + to #max_size points without replacement. + """ + + super().__init__() + self.image_size = image_size + self.bg_color = bg_color + self.max_points = max_points + + def forward( + self, + camera: CamerasBase, + image_rgb: torch.Tensor, + depth_map: torch.Tensor, + fg_probability: torch.Tensor, + frame_type: List[str], + **kwargs, + ) -> Dict[str, Any]: # TODO: return a namedtuple or dataclass + """ + Given a set of input source cameras images and depth maps, unprojects + all RGBD maps to a colored point cloud and renders into the target views. + + Args: + camera: A batch of `N` PyTorch3D cameras. + image_rgb: A batch of `N` images of shape `(N, 3, H, W)`. + depth_map: A batch of `N` depth maps of shape `(N, 1, H, W)`. + fg_probability: A batch of `N` foreground probability maps + of shape `(N, 1, H, W)`. + frame_type: A list of `N` strings containing frame type indicators + which specify target and source views. + + Returns: + preds: A dict with the following fields: + nvs_prediction: The rendered colors, depth and mask + of the target views. + point_cloud: The point cloud of the scene. It's renders are + stored in `nvs_prediction`. + """ + + is_known = is_known_frame(frame_type) + is_known_idx = torch.where(is_known)[0] + + mask_fg = (fg_probability > 0.5).type_as(image_rgb) + + point_cloud = get_rgbd_point_cloud( + camera[is_known_idx], + image_rgb[is_known_idx], + depth_map[is_known_idx], + mask_fg[is_known_idx], + ) + + pcl_size = int(point_cloud.num_points_per_cloud()) + if (self.max_points > 0) and (pcl_size > self.max_points): + prm = torch.randperm(pcl_size)[: self.max_points] + point_cloud = Pointclouds( + point_cloud.points_padded()[:, prm, :], + # pyre-fixme[16]: Optional type has no attribute `__getitem__`. + features=point_cloud.features_padded()[:, prm, :], + ) + + is_target_idx = torch.where(~is_known)[0] + + depth_render, image_render, mask_render = [], [], [] + + # render into target frames in a for loop to save memory + for tgt_idx in is_target_idx: + _image_render, _mask_render, _depth_render = render_point_cloud_pytorch3d( + camera[int(tgt_idx)], + point_cloud, + render_size=(self.image_size, self.image_size), + point_radius=1e-2, + topk=10, + bg_color=self.bg_color, + ) + _image_render = _image_render.clamp(0.0, 1.0) + # the mask is the set of pixels with opacity bigger than eps + _mask_render = (_mask_render > 1e-4).float() + + depth_render.append(_depth_render) + image_render.append(_image_render) + mask_render.append(_mask_render) + + nvs_prediction = NewViewSynthesisPrediction( + **{ + k: torch.cat(v, dim=0) + for k, v in zip( + ["depth_render", "image_render", "mask_render"], + [depth_render, image_render, mask_render], + ) + } + ) + + preds = { + "nvs_prediction": nvs_prediction, + "point_cloud": point_cloud, + } + + return preds diff --git a/pytorch3d/implicitron/models/renderer/base.py b/pytorch3d/implicitron/models/renderer/base.py new file mode 100644 index 00000000..f14d7231 --- /dev/null +++ b/pytorch3d/implicitron/models/renderer/base.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional + +import torch +from pytorch3d.implicitron.tools.config import ReplaceableBase + + +class EvaluationMode(Enum): + TRAINING = "training" + EVALUATION = "evaluation" + + +class RenderSamplingMode(Enum): + MASK_SAMPLE = "mask_sample" + FULL_GRID = "full_grid" + + +@dataclass +class RendererOutput: + """ + A structure for storing the output of a renderer. + + Args: + features: rendered features (usually RGB colors), (B, ..., C) tensor. + depth: rendered ray-termination depth map, in NDC coordinates, (B, ..., 1) tensor. + mask: rendered object mask, values in [0, 1], (B, ..., 1) tensor. + prev_stage: for multi-pass renderers (e.g. in NeRF), + a reference to the output of the previous stage. + normals: surface normals, for renderers that estimate them; (B, ..., 3) tensor. + points: ray-termination points in the world coordinates, (B, ..., 3) tensor. + aux: dict for implementation-specific renderer outputs. + """ + + features: torch.Tensor + depths: torch.Tensor + masks: torch.Tensor + prev_stage: Optional[RendererOutput] = None + normals: Optional[torch.Tensor] = None + points: Optional[torch.Tensor] = None # TODO: redundant with depths + aux: Dict[str, Any] = field(default_factory=lambda: {}) + + +class ImplicitFunctionWrapper(torch.nn.Module): + def __init__(self, fn: torch.nn.Module): + super().__init__() + self._fn = fn + self.bound_args = {} + + def bind_args(self, **bound_args): + self.bound_args = bound_args + self._fn.on_bind_args() + + def unbind_args(self): + self.bound_args = {} + + def forward(self, *args, **kwargs): + return self._fn(*args, **{**kwargs, **self.bound_args}) + + +class BaseRenderer(ABC, ReplaceableBase): + """ + Base class for all Renderer implementations. + """ + + def __init__(self): + super().__init__() + + @abstractmethod + def forward( + self, + ray_bundle, + implicit_functions: List[ImplicitFunctionWrapper], + evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, + **kwargs + ) -> RendererOutput: + """ + Each Renderer should implement its own forward function + that returns an instance of RendererOutput. + + Args: + ray_bundle: A RayBundle object containing the following variables: + origins: A tensor of shape (minibatch, ..., 3) denoting + the origins of the rendering rays. + directions: A tensor of shape (minibatch, ..., 3) + containing the direction vectors of rendering rays. + lengths: A tensor of shape + (minibatch, ..., num_points_per_ray)containing the + lengths at which the ray points are sampled. + The coordinates of the points on the rays are thus computed + as `origins + lengths * directions`. + xys: A tensor of shape + (minibatch, ..., 2) containing the + xy locations of each ray's pixel in the NDC screen space. + implicit_functions: List of ImplicitFunctionWrappers which define the + implicit function methods to be used. Most Renderers only allow + a single implicit function. Currently, only the MultiPassEARenderer + allows specifying mulitple values in the list. + evaluation_mode: one of EvaluationMode.TRAINING or + EvaluationMode.EVALUATION which determines the settings used for + rendering. + **kwargs: In addition to the name args, custom keyword args can be specified. + For example in the SignedDistanceFunctionRenderer, an object_mask is + required which needs to be passed via the kwargs. + + Returns: + instance of RendererOutput + """ + pass diff --git a/pytorch3d/implicitron/models/renderer/lstm_renderer.py b/pytorch3d/implicitron/models/renderer/lstm_renderer.py new file mode 100644 index 00000000..17c667a8 --- /dev/null +++ b/pytorch3d/implicitron/models/renderer/lstm_renderer.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional, Tuple + +import torch +from pytorch3d.implicitron.tools.config import registry +from pytorch3d.renderer import RayBundle + +from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput + + +@registry.register +class LSTMRenderer(BaseRenderer, torch.nn.Module): + """ + Implements the learnable LSTM raymarching function from SRN [1]. + + Settings: + num_raymarch_steps: The number of LSTM raymarching steps. + init_depth: Initializes the bias of the last raymarching LSTM layer so that + the farthest point from the camera reaches a far z-plane that + lies `init_depth` units from the camera plane. + init_depth_noise_std: The standard deviation of the random normal noise + added to the initial depth of each marched ray. + hidden_size: The dimensionality of the LSTM's hidden state. + n_feature_channels: The number of feature channels returned by the + implicit_function evaluated at each raymarching step. + verbose: If `True`, prints raymarching debug info. + + References: + [1] Sitzmann, V. and Zollhöfer, M. and Wetzstein, G.. + "Scene representation networks: Continuous 3d-structure-aware + neural scene representations." NeurIPS 2019. + """ + + num_raymarch_steps: int = 10 + init_depth: float = 17.0 + init_depth_noise_std: float = 5e-4 + hidden_size: int = 16 + n_feature_channels: int = 256 + verbose: bool = False + + def __post_init__(self): + super().__init__() + self._lstm = torch.nn.LSTMCell( + input_size=self.n_feature_channels, + hidden_size=self.hidden_size, + ) + self._lstm.apply(_init_recurrent_weights) + _lstm_forget_gate_init(self._lstm) + self._out_layer = torch.nn.Linear(self.hidden_size, 1) + + one_step = self.init_depth / self.num_raymarch_steps + self._out_layer.bias.data.fill_(one_step) + self._out_layer.weight.data.normal_(mean=0.0, std=1e-3) + + def forward( + self, + ray_bundle: RayBundle, + implicit_functions: List[ImplicitFunctionWrapper], + evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, + **kwargs, + ) -> RendererOutput: + """ + + Args: + ray_bundle: A `RayBundle` object containing the parametrizations of the + sampled rendering rays. + implicit_functions: A single-element list of ImplicitFunctionWrappers which + defines the implicit function to be used. + evaluation_mode: one of EvaluationMode.TRAINING or + EvaluationMode.EVALUATION which determines the settings used for + rendering, specifically the RayPointRefiner and the density_noise_std. + + Returns: + instance of RendererOutput + """ + if len(implicit_functions) != 1: + raise ValueError("LSTM renderer expects a single implicit function.") + + implicit_function = implicit_functions[0] + + if ray_bundle.lengths.shape[-1] != 1: + raise ValueError( + "LSTM renderer requires a ray-bundle with a single point per ray" + + " which is the initial raymarching point." + ) + + # jitter the initial depths + ray_bundle_t = ray_bundle._replace( + lengths=ray_bundle.lengths + + torch.randn_like(ray_bundle.lengths) * self.init_depth_noise_std + ) + + states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [None] + signed_distance = torch.zeros_like(ray_bundle_t.lengths) + raymarch_features = None + for t in range(self.num_raymarch_steps + 1): + # move signed_distance along each ray + ray_bundle_t = ray_bundle_t._replace( + lengths=ray_bundle_t.lengths + signed_distance + ) + + # eval the raymarching function + raymarch_features, _ = implicit_function( + ray_bundle_t, + raymarch_features=None, + ) + if self.verbose: + # print some stats + print( + f"{t}: mu={float(signed_distance.mean()):1.2e};" + + f" std={float(signed_distance.std()):1.2e};" + # pyre-fixme[6]: Expected `Union[bytearray, bytes, str, + # typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st + # param but got `Tensor`. + + f" mu_d={float(ray_bundle_t.lengths.mean()):1.2e};" + # pyre-fixme[6]: Expected `Union[bytearray, bytes, str, + # typing.SupportsFloat, typing_extensions.SupportsIndex]` for 1st + # param but got `Tensor`. + + f" std_d={float(ray_bundle_t.lengths.std()):1.2e};" + ) + if t == self.num_raymarch_steps: + break + + # run the lstm marcher + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + state_h, state_c = self._lstm( + raymarch_features.view(-1, raymarch_features.shape[-1]), + states[-1], + ) + if state_h.requires_grad: + state_h.register_hook(lambda x: x.clamp(min=-10, max=10)) + # predict the next step size + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + signed_distance = self._out_layer(state_h).view(ray_bundle_t.lengths.shape) + # log the lstm states + states.append((state_h, state_c)) + + opacity_logits, features = implicit_function( + raymarch_features=raymarch_features, + ray_bundle=ray_bundle_t, + ) + mask = torch.sigmoid(opacity_logits) + depth = ray_bundle_t.lengths * ray_bundle_t.directions.norm( + dim=-1, keepdim=True + ) + + return RendererOutput( + features=features[..., 0, :], + depths=depth, + masks=mask[..., 0, :], + ) + + +def _init_recurrent_weights(self) -> None: + # copied from SRN codebase + for m in self.modules(): + if type(m) in [torch.nn.GRU, torch.nn.LSTM, torch.nn.RNN]: + for name, param in m.named_parameters(): + if "weight_ih" in name: + torch.nn.init.kaiming_normal_(param.data) + elif "weight_hh" in name: + torch.nn.init.orthogonal_(param.data) + elif "bias" in name: + param.data.fill_(0) + + +def _lstm_forget_gate_init(lstm_layer) -> None: + # copied from SRN codebase + for name, parameter in lstm_layer.named_parameters(): + if "bias" not in name: + continue + n = parameter.size(0) + start, end = n // 4, n // 2 + parameter.data[start:end].fill_(1.0) diff --git a/pytorch3d/implicitron/models/renderer/multipass_ea.py b/pytorch3d/implicitron/models/renderer/multipass_ea.py new file mode 100644 index 00000000..84872e56 --- /dev/null +++ b/pytorch3d/implicitron/models/renderer/multipass_ea.py @@ -0,0 +1,171 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Tuple + +import torch +from pytorch3d.implicitron.tools.config import registry + +from .base import BaseRenderer, EvaluationMode, RendererOutput +from .ray_point_refiner import RayPointRefiner +from .raymarcher import GenericRaymarcher + + +@registry.register +class MultiPassEmissionAbsorptionRenderer(BaseRenderer, torch.nn.Module): + """ + Implements the multi-pass rendering function, in particular, + with emission-absorption ray marching used in NeRF [1]. First, it evaluates + opacity-based ray-point weights and then optionally (in case more implicit + functions are given) resamples points using importance sampling and evaluates + new weights. + + During each ray marching pass, features, depth map, and masks + are integrated: Let o_i be the opacity estimated by the implicit function, + and d_i be the offset between points `i` and `i+1` along the respective ray. + Ray marching is performed using the following equations: + ``` + ray_opacity_n = cap_fn(sum_i=1^n cap_fn(d_i * o_i)), + weight_n = weight_fn(cap_fn(d_i * o_i), 1 - ray_opacity_{n-1}), + ``` + and the final rendered quantities are computed by a dot-product of ray values + with the weights, e.g. `features = sum_n(weight_n * ray_features_n)`. + See below for possible values of `cap_fn` and `weight_fn`. + + Settings: + n_pts_per_ray_fine_training: The number of points sampled per ray for the + fine rendering pass during training. + n_pts_per_ray_fine_evaluation: The number of points sampled per ray for the + fine rendering pass during evaluation. + stratified_sampling_coarse_training: Enable/disable stratified sampling during + training. + stratified_sampling_coarse_evaluation: Enable/disable stratified sampling during + evaluation. + append_coarse_samples_to_fine: Add the fine ray points to the coarse points + after sampling. + bg_color: The background color. A tuple of either 1 element or of D elements, + where D matches the feature dimensionality; it is broadcasted when necessary. + density_noise_std_train: Standard deviation of the noise added to the + opacity field. + capping_function: The capping function of the raymarcher. + Options: + - "exponential" (`cap_fn(x) = 1 - exp(-x)`) + - "cap1" (`cap_fn(x) = min(x, 1)`) + Set to "exponential" for the standard Emission Absorption raymarching. + weight_function: The weighting function of the raymarcher. + Options: + - "product" (`weight_fn(w, x) = w * x`) + - "minimum" (`weight_fn(w, x) = min(w, x)`) + Set to "product" for the standard Emission Absorption raymarching. + background_opacity: The raw opacity value (i.e. before exponentiation) + of the background. + blend_output: If `True`, alpha-blends the output renders with the + background color using the rendered opacity mask. + + References: + [1] Mildenhall, Ben, et al. "Nerf: Representing scenes as neural radiance + fields for view synthesis." ECCV 2020. + + """ + + n_pts_per_ray_fine_training: int = 64 + n_pts_per_ray_fine_evaluation: int = 64 + stratified_sampling_coarse_training: bool = True + stratified_sampling_coarse_evaluation: bool = False + append_coarse_samples_to_fine: bool = True + bg_color: Tuple[float, ...] = (0.0,) + density_noise_std_train: float = 0.0 + capping_function: str = "exponential" # exponential | cap1 + weight_function: str = "product" # product | minimum + background_opacity: float = 1e10 + blend_output: bool = False + + def __post_init__(self): + super().__init__() + self._refiners = { + EvaluationMode.TRAINING: RayPointRefiner( + n_pts_per_ray=self.n_pts_per_ray_fine_training, + random_sampling=self.stratified_sampling_coarse_training, + add_input_samples=self.append_coarse_samples_to_fine, + ), + EvaluationMode.EVALUATION: RayPointRefiner( + n_pts_per_ray=self.n_pts_per_ray_fine_evaluation, + random_sampling=self.stratified_sampling_coarse_evaluation, + add_input_samples=self.append_coarse_samples_to_fine, + ), + } + + self._raymarcher = GenericRaymarcher( + 1, + self.bg_color, + capping_function=self.capping_function, + weight_function=self.weight_function, + background_opacity=self.background_opacity, + blend_output=self.blend_output, + ) + + def forward( + self, + ray_bundle, + implicit_functions=[], + evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, + **kwargs + ) -> RendererOutput: + """ + Args: + ray_bundle: A `RayBundle` object containing the parametrizations of the + sampled rendering rays. + implicit_functions: List of ImplicitFunctionWrappers which + define the implicit functions to be used sequentially in + the raymarching step. The output of raymarching with + implicit_functions[n-1] is refined, and then used as + input for raymarching with implicit_functions[n]. + evaluation_mode: one of EvaluationMode.TRAINING or + EvaluationMode.EVALUATION which determines the settings used for + rendering + + Returns: + instance of RendererOutput + """ + if not implicit_functions: + raise ValueError("EA renderer expects implicit functions") + + return self._run_raymarcher( + ray_bundle, + implicit_functions, + None, + evaluation_mode, + ) + + def _run_raymarcher( + self, ray_bundle, implicit_functions, prev_stage, evaluation_mode + ): + density_noise_std = ( + self.density_noise_std_train + if evaluation_mode == EvaluationMode.TRAINING + else 0.0 + ) + + features, depth, mask, weights, aux = self._raymarcher( + *implicit_functions[0](ray_bundle), + ray_lengths=ray_bundle.lengths, + density_noise_std=density_noise_std, + ) + output = RendererOutput( + features=features, depths=depth, masks=mask, aux=aux, prev_stage=prev_stage + ) + + # we may need to make a recursive call + if len(implicit_functions) > 1: + fine_ray_bundle = self._refiners[evaluation_mode](ray_bundle, weights) + output = self._run_raymarcher( + fine_ray_bundle, + implicit_functions[1:], + output, + evaluation_mode, + ) + + return output diff --git a/pytorch3d/implicitron/models/renderer/ray_point_refiner.py b/pytorch3d/implicitron/models/renderer/ray_point_refiner.py new file mode 100644 index 00000000..55fbc8d6 --- /dev/null +++ b/pytorch3d/implicitron/models/renderer/ray_point_refiner.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from pytorch3d.implicitron.tools.config import Configurable, expand_args_fields +from pytorch3d.renderer import RayBundle +from pytorch3d.renderer.implicit.sample_pdf import sample_pdf + + +@expand_args_fields +# pyre-fixme[13]: Attribute `n_pts_per_ray` is never initialized. +# pyre-fixme[13]: Attribute `random_sampling` is never initialized. +class RayPointRefiner(Configurable, torch.nn.Module): + """ + Implements the importance sampling of points along rays. + The input is a `RayBundle` object with a `ray_weights` tensor + which specifies the probabilities of sampling a point along each ray. + + This raysampler is used for the fine rendering pass of NeRF. + As such, the forward pass accepts the RayBundle output by the + raysampling of the coarse rendering pass. Hence, it does not + take cameras as input. + + Args: + n_pts_per_ray: The number of points to sample along each ray. + random_sampling: If `False`, returns equispaced percentiles of the + distribution defined by the input weights, otherwise performs + sampling from that distribution. + add_input_samples: Concatenates and returns the sampled values + together with the input samples. + """ + + n_pts_per_ray: int + random_sampling: bool + add_input_samples: bool = True + + def __post_init__(self) -> None: + super().__init__() + + def forward( + self, + input_ray_bundle: RayBundle, + ray_weights: torch.Tensor, + **kwargs, + ) -> RayBundle: + """ + Args: + input_ray_bundle: An instance of `RayBundle` specifying the + source rays for sampling of the probability distribution. + ray_weights: A tensor of shape + `(..., input_ray_bundle.legths.shape[-1])` with non-negative + elements defining the probability distribution to sample + ray points from. + + Returns: + ray_bundle: A new `RayBundle` instance containing the input ray + points together with `n_pts_per_ray` additionally sampled + points per ray. For each ray, the lengths are sorted. + """ + + z_vals = input_ray_bundle.lengths + with torch.no_grad(): + z_vals_mid = torch.lerp(z_vals[..., 1:], z_vals[..., :-1], 0.5) + z_samples = sample_pdf( + z_vals_mid.view(-1, z_vals_mid.shape[-1]), + ray_weights.view(-1, ray_weights.shape[-1])[..., 1:-1], + self.n_pts_per_ray, + det=not self.random_sampling, + ).view(*z_vals.shape[:-1], self.n_pts_per_ray) + + if self.add_input_samples: + # Add the new samples to the input ones. + z_vals = torch.cat((z_vals, z_samples), dim=-1) + else: + z_vals = z_samples + # Resort by depth. + z_vals, _ = torch.sort(z_vals, dim=-1) + + return RayBundle( + origins=input_ray_bundle.origins, + directions=input_ray_bundle.directions, + lengths=z_vals, + xys=input_ray_bundle.xys, + ) diff --git a/pytorch3d/implicitron/models/renderer/ray_sampler.py b/pytorch3d/implicitron/models/renderer/ray_sampler.py new file mode 100644 index 00000000..4f5cee69 --- /dev/null +++ b/pytorch3d/implicitron/models/renderer/ray_sampler.py @@ -0,0 +1,190 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import field +from typing import Optional, Tuple + +import torch +from pytorch3d.implicitron.tools import camera_utils +from pytorch3d.implicitron.tools.config import Configurable +from pytorch3d.renderer import NDCMultinomialRaysampler, RayBundle +from pytorch3d.renderer.cameras import CamerasBase + +from .base import EvaluationMode, RenderSamplingMode + + +class RaySampler(Configurable, torch.nn.Module): + """ + + Samples a fixed number of points along rays which are in turn sampled for + each camera in a batch. + + This class utilizes `NDCMultinomialRaysampler` which allows to either + randomly sample rays from an input foreground saliency mask + (`RenderSamplingMode.MASK_SAMPLE`), or on a rectangular image grid + (`RenderSamplingMode.FULL_GRID`). The sampling mode can be set separately + for training and evaluation by setting `self.sampling_mode_training` + and `self.sampling_mode_training` accordingly. + + The class allows two modes of sampling points along the rays: + 1) Sampling between fixed near and far z-planes: + Active when `self.scene_extent <= 0`, samples points along each ray + with approximately uniform spacing of z-coordinates between + the minimum depth `self.min_depth` and the maximum depth `self.max_depth`. + This sampling is useful for rendering scenes where the camera is + in a constant distance from the focal point of the scene. + 2) Adaptive near/far plane estimation around the world scene center: + Active when `self.scene_extent > 0`. Samples points on each + ray between near and far planes whose depths are determined based on + the distance from the camera center to a predefined scene center. + More specifically, + `min_depth = max( + (self.scene_center-camera_center).norm() - self.scene_extent, eps + )` and + `max_depth = (self.scene_center-camera_center).norm() + self.scene_extent`. + This sampling is ideal for object-centric scenes whose contents are + centered around a known `self.scene_center` and fit into a bounding sphere + with a radius of `self.scene_extent`. + + Similar to the sampling mode, the sampling parameters can be set separately + for training and evaluation. + + Settings: + image_width: The horizontal size of the image grid. + image_height: The vertical size of the image grid. + scene_center: The xyz coordinates of the center of the scene used + along with `scene_extent` to compute the min and max depth planes + for sampling ray-points. + scene_extent: The radius of the scene bounding sphere centered at `scene_center`. + If `scene_extent <= 0`, the raysampler samples points between + `self.min_depth` and `self.max_depth` depths instead. + sampling_mode_training: The ray sampling mode for training. This should be a str + option from the RenderSamplingMode Enum + sampling_mode_evaluation: Same as above but for evaluation. + n_pts_per_ray_training: The number of points sampled along each ray during training. + n_pts_per_ray_evaluation: The number of points sampled along each ray during evaluation. + n_rays_per_image_sampled_from_mask: The amount of rays to be sampled from the image grid + min_depth: The minimum depth of a ray-point. Active when `self.scene_extent > 0`. + max_depth: The maximum depth of a ray-point. Active when `self.scene_extent > 0`. + stratified_point_sampling_training: if set, performs stratified random sampling + along the ray; otherwise takes ray points at deterministic offsets. + stratified_point_sampling_evaluation: Same as above but for evaluation. + + """ + + image_width: int = 400 + image_height: int = 400 + scene_center: Tuple[float, float, float] = field( + default_factory=lambda: (0.0, 0.0, 0.0) + ) + scene_extent: float = 0.0 + sampling_mode_training: str = "mask_sample" + sampling_mode_evaluation: str = "full_grid" + n_pts_per_ray_training: int = 64 + n_pts_per_ray_evaluation: int = 64 + n_rays_per_image_sampled_from_mask: int = 1024 + min_depth: float = 0.1 + max_depth: float = 8.0 + # stratified sampling vs taking points at deterministic offsets + stratified_point_sampling_training: bool = True + stratified_point_sampling_evaluation: bool = False + + def __post_init__(self): + super().__init__() + self.scene_center = torch.FloatTensor(self.scene_center) + + self._sampling_mode = { + EvaluationMode.TRAINING: RenderSamplingMode(self.sampling_mode_training), + EvaluationMode.EVALUATION: RenderSamplingMode( + self.sampling_mode_evaluation + ), + } + + self._raysamplers = { + EvaluationMode.TRAINING: NDCMultinomialRaysampler( + image_width=self.image_width, + image_height=self.image_height, + n_pts_per_ray=self.n_pts_per_ray_training, + min_depth=self.min_depth, + max_depth=self.max_depth, + n_rays_per_image=self.n_rays_per_image_sampled_from_mask + if self._sampling_mode[EvaluationMode.TRAINING] + == RenderSamplingMode.MASK_SAMPLE + else None, + unit_directions=True, + stratified_sampling=self.stratified_point_sampling_training, + ), + EvaluationMode.EVALUATION: NDCMultinomialRaysampler( + image_width=self.image_width, + image_height=self.image_height, + n_pts_per_ray=self.n_pts_per_ray_evaluation, + min_depth=self.min_depth, + max_depth=self.max_depth, + n_rays_per_image=self.n_rays_per_image_sampled_from_mask + if self._sampling_mode[EvaluationMode.EVALUATION] + == RenderSamplingMode.MASK_SAMPLE + else None, + unit_directions=True, + stratified_sampling=self.stratified_point_sampling_evaluation, + ), + } + + def forward( + self, + cameras: CamerasBase, + evaluation_mode: EvaluationMode, + mask: Optional[torch.Tensor] = None, + ) -> RayBundle: + """ + + Args: + cameras: A batch of `batch_size` cameras from which the rays are emitted. + evaluation_mode: one of `EvaluationMode.TRAINING` or + `EvaluationMode.EVALUATION` which determines the sampling mode + that is used. + mask: Active for the `RenderSamplingMode.MASK_SAMPLE` sampling mode. + Defines a non-negative mask of shape + `(batch_size, image_height, image_width)` where each per-pixel + value is proportional to the probability of sampling the + corresponding pixel's ray. + + Returns: + ray_bundle: A `RayBundle` object containing the parametrizations of the + sampled rendering rays. + """ + sample_mask = None + if ( + # pyre-fixme[29] + self._sampling_mode[evaluation_mode] == RenderSamplingMode.MASK_SAMPLE + and mask is not None + ): + sample_mask = torch.nn.functional.interpolate( + mask, + # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got + # `List[int]`. + size=[self.image_height, self.image_width], + mode="nearest", + )[:, 0] + + if self.scene_extent > 0.0: + # Override the min/max depth set in initialization based on the + # input cameras. + min_depth, max_depth = camera_utils.get_min_max_depth_bounds( + cameras, self.scene_center, self.scene_extent + ) + + # pyre-fixme[29]: + # `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self, + # torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor], + # torch.Tensor, torch.nn.Module]` is not a function. + ray_bundle = self._raysamplers[evaluation_mode]( + cameras=cameras, + mask=sample_mask, + min_depth=float(min_depth[0]) if self.scene_extent > 0.0 else None, + max_depth=float(max_depth[0]) if self.scene_extent > 0.0 else None, + ) + + return ray_bundle diff --git a/pytorch3d/implicitron/models/renderer/ray_tracing.py b/pytorch3d/implicitron/models/renderer/ray_tracing.py new file mode 100644 index 00000000..4dbc64c4 --- /dev/null +++ b/pytorch3d/implicitron/models/renderer/ray_tracing.py @@ -0,0 +1,573 @@ +# @lint-ignore-every LICENSELINT +# Adapted from https://github.com/lioryariv/idr +# Copyright (c) 2020 Lior Yariv + +from typing import Any, Callable, Tuple + +import torch +import torch.nn as nn +from pytorch3d.implicitron.tools.config import Configurable + + +class RayTracing(Configurable, nn.Module): + """ + Finds the intersection points of rays with the implicit surface defined + by a signed distance function (SDF). The algorithm follows the pipeline: + 1. Initialise start and end points on rays by the intersections with + the circumscribing sphere. + 2. Run sphere tracing from both ends. + 3. Divide the untraced segments of non-convergent rays into uniform + intervals and find the one with the sign transition. + 4. Run the secant method to estimate the point of the sign transition. + + Args: + object_bounding_sphere: The radius of the initial sphere circumscribing + the object. + sdf_threshold: Absolute SDF value small enough for the sphere tracer + to consider it a surface. + line_search_step: Length of the backward correction on sphere tracing + iterations. + line_step_iters: Number of backward correction iterations. + sphere_tracing_iters: Maximum number of sphere tracing iterations + (the actual number of iterations may be smaller if all ray + intersections are found). + n_steps: Number of intervals sampled for unconvergent rays. + n_secant_steps: Number of iterations in the secant algorithm. + """ + + object_bounding_sphere: float = 1.0 + sdf_threshold: float = 5.0e-5 + line_search_step: float = 0.5 + line_step_iters: int = 1 + sphere_tracing_iters: int = 10 + n_steps: int = 100 + n_secant_steps: int = 8 + + def __post_init__(self): + super().__init__() + + def forward( + self, + sdf: Callable[[torch.Tensor], torch.Tensor], + cam_loc: torch.Tensor, + object_mask: torch.BoolTensor, + ray_directions: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + sdf: A callable that takes a (N, 3) tensor of points and returns + a tensor of (N,) SDF values. + cam_loc: A tensor of (B, N, 3) ray origins. + object_mask: A (N, 3) tensor of indicators whether a sampled pixel + corresponds to the rendered object or background. + ray_directions: A tensor of (B, N, 3) ray directions. + + Returns: + curr_start_points: A tensor of (B*N, 3) found intersection points + with the implicit surface. + network_object_mask: A tensor of (B*N,) indicators denoting whether + intersections were found. + acc_start_dis: A tensor of (B*N,) distances from the ray origins + to intersrection points. + """ + batch_size, num_pixels, _ = ray_directions.shape + device = cam_loc.device + + sphere_intersections, mask_intersect = _get_sphere_intersection( + cam_loc, ray_directions, r=self.object_bounding_sphere + ) + + ( + curr_start_points, + unfinished_mask_start, + acc_start_dis, + acc_end_dis, + min_dis, + max_dis, + ) = self.sphere_tracing( + batch_size, + num_pixels, + sdf, + cam_loc, + ray_directions, + mask_intersect, + sphere_intersections, + ) + + network_object_mask = acc_start_dis < acc_end_dis + + # The non convergent rays should be handled by the sampler + sampler_mask = unfinished_mask_start + sampler_net_obj_mask = torch.zeros_like( + sampler_mask, dtype=torch.bool, device=device + ) + if sampler_mask.sum() > 0: + sampler_min_max = torch.zeros((batch_size, num_pixels, 2), device=device) + sampler_min_max.reshape(-1, 2)[sampler_mask, 0] = acc_start_dis[ + sampler_mask + ] + sampler_min_max.reshape(-1, 2)[sampler_mask, 1] = acc_end_dis[sampler_mask] + + sampler_pts, sampler_net_obj_mask, sampler_dists = self.ray_sampler( + sdf, cam_loc, object_mask, ray_directions, sampler_min_max, sampler_mask + ) + + curr_start_points[sampler_mask] = sampler_pts[sampler_mask] + acc_start_dis[sampler_mask] = sampler_dists[sampler_mask] + network_object_mask[sampler_mask] = sampler_net_obj_mask[sampler_mask] + + if not self.training: + return curr_start_points, network_object_mask, acc_start_dis + + # in case we are training, we are updating curr_start_points and acc_start_dis for + + ray_directions = ray_directions.reshape(-1, 3) + mask_intersect = mask_intersect.reshape(-1) + object_mask = object_mask.reshape(-1) + + in_mask = ~network_object_mask & object_mask & ~sampler_mask + out_mask = ~object_mask & ~sampler_mask + + # pyre-fixme[16]: `Tensor` has no attribute `__invert__`. + mask_left_out = (in_mask | out_mask) & ~mask_intersect + if ( + mask_left_out.sum() > 0 + ): # project the origin to the not intersect points on the sphere + cam_left_out = cam_loc.reshape(-1, 3)[mask_left_out] + rays_left_out = ray_directions[mask_left_out] + acc_start_dis[mask_left_out] = -torch.bmm( + rays_left_out.view(-1, 1, 3), cam_left_out.view(-1, 3, 1) + ).squeeze() + curr_start_points[mask_left_out] = ( + cam_left_out + acc_start_dis[mask_left_out].unsqueeze(1) * rays_left_out + ) + + mask = (in_mask | out_mask) & mask_intersect + + if mask.sum() > 0: + min_dis[network_object_mask & out_mask] = acc_start_dis[ + network_object_mask & out_mask + ] + + min_mask_points, min_mask_dist = self.minimal_sdf_points( + sdf, cam_loc, ray_directions, mask, min_dis, max_dis + ) + + curr_start_points[mask] = min_mask_points + acc_start_dis[mask] = min_mask_dist + + return curr_start_points, network_object_mask, acc_start_dis + + def sphere_tracing( + self, + batch_size: int, + num_pixels: int, + sdf: Callable[[torch.Tensor], torch.Tensor], + cam_loc: torch.Tensor, + ray_directions: torch.Tensor, + mask_intersect: torch.Tensor, + sphere_intersections: torch.Tensor, + ) -> Tuple[Any, Any, Any, Any, Any, Any]: + """ + Run sphere tracing algorithm for max iterations + from both sides of unit sphere intersection + + Args: + batch_size: + num_pixels: + sdf: + cam_loc: + ray_directions: + mask_intersect: + sphere_intersections: + + Returns: + curr_start_points: + unfinished_mask_start: + acc_start_dis: + acc_end_dis: + min_dis: + max_dis: + """ + + device = cam_loc.device + sphere_intersections_points = ( + cam_loc[..., None, :] + + sphere_intersections[..., None] * ray_directions[..., None, :] + ) + unfinished_mask_start = mask_intersect.reshape(-1).clone() + unfinished_mask_end = mask_intersect.reshape(-1).clone() + + # Initialize start current points + curr_start_points = torch.zeros(batch_size * num_pixels, 3, device=device) + curr_start_points[unfinished_mask_start] = sphere_intersections_points[ + :, :, 0, : + ].reshape(-1, 3)[unfinished_mask_start] + acc_start_dis = torch.zeros(batch_size * num_pixels, device=device) + acc_start_dis[unfinished_mask_start] = sphere_intersections.reshape(-1, 2)[ + unfinished_mask_start, 0 + ] + + # Initialize end current points + curr_end_points = torch.zeros(batch_size * num_pixels, 3, device=device) + curr_end_points[unfinished_mask_end] = sphere_intersections_points[ + :, :, 1, : + ].reshape(-1, 3)[unfinished_mask_end] + acc_end_dis = torch.zeros(batch_size * num_pixels, device=device) + acc_end_dis[unfinished_mask_end] = sphere_intersections.reshape(-1, 2)[ + unfinished_mask_end, 1 + ] + + # Initialise min and max depth + min_dis = acc_start_dis.clone() + max_dis = acc_end_dis.clone() + + # Iterate on the rays (from both sides) till finding a surface + iters = 0 + + # TODO: sdf should also pass info about batches + + next_sdf_start = torch.zeros_like(acc_start_dis) + next_sdf_start[unfinished_mask_start] = sdf( + curr_start_points[unfinished_mask_start] + ) + + next_sdf_end = torch.zeros_like(acc_end_dis) + next_sdf_end[unfinished_mask_end] = sdf(curr_end_points[unfinished_mask_end]) + + while True: + # Update sdf + curr_sdf_start = torch.zeros_like(acc_start_dis) + curr_sdf_start[unfinished_mask_start] = next_sdf_start[ + unfinished_mask_start + ] + curr_sdf_start[curr_sdf_start <= self.sdf_threshold] = 0 + + curr_sdf_end = torch.zeros_like(acc_end_dis) + curr_sdf_end[unfinished_mask_end] = next_sdf_end[unfinished_mask_end] + curr_sdf_end[curr_sdf_end <= self.sdf_threshold] = 0 + + # Update masks + unfinished_mask_start = unfinished_mask_start & ( + curr_sdf_start > self.sdf_threshold + ) + unfinished_mask_end = unfinished_mask_end & ( + curr_sdf_end > self.sdf_threshold + ) + + if ( + unfinished_mask_start.sum() == 0 and unfinished_mask_end.sum() == 0 + ) or iters == self.sphere_tracing_iters: + break + iters += 1 + + # Make step + # Update distance + acc_start_dis = acc_start_dis + curr_sdf_start + acc_end_dis = acc_end_dis - curr_sdf_end + + # Update points + curr_start_points = ( + cam_loc + + acc_start_dis.reshape(batch_size, num_pixels, 1) * ray_directions + ).reshape(-1, 3) + curr_end_points = ( + cam_loc + + acc_end_dis.reshape(batch_size, num_pixels, 1) * ray_directions + ).reshape(-1, 3) + + # Fix points which wrongly crossed the surface + next_sdf_start = torch.zeros_like(acc_start_dis) + next_sdf_start[unfinished_mask_start] = sdf( + curr_start_points[unfinished_mask_start] + ) + + next_sdf_end = torch.zeros_like(acc_end_dis) + next_sdf_end[unfinished_mask_end] = sdf( + curr_end_points[unfinished_mask_end] + ) + + not_projected_start = next_sdf_start < 0 + not_projected_end = next_sdf_end < 0 + not_proj_iters = 0 + while ( + not_projected_start.sum() > 0 or not_projected_end.sum() > 0 + ) and not_proj_iters < self.line_step_iters: + # Step backwards + acc_start_dis[not_projected_start] -= ( + (1 - self.line_search_step) / (2 ** not_proj_iters) + ) * curr_sdf_start[not_projected_start] + curr_start_points[not_projected_start] = ( + cam_loc + + acc_start_dis.reshape(batch_size, num_pixels, 1) * ray_directions + ).reshape(-1, 3)[not_projected_start] + + acc_end_dis[not_projected_end] += ( + (1 - self.line_search_step) / (2 ** not_proj_iters) + ) * curr_sdf_end[not_projected_end] + curr_end_points[not_projected_end] = ( + cam_loc + + acc_end_dis.reshape(batch_size, num_pixels, 1) * ray_directions + ).reshape(-1, 3)[not_projected_end] + + # Calc sdf + next_sdf_start[not_projected_start] = sdf( + curr_start_points[not_projected_start] + ) + next_sdf_end[not_projected_end] = sdf( + curr_end_points[not_projected_end] + ) + + # Update mask + not_projected_start = next_sdf_start < 0 + not_projected_end = next_sdf_end < 0 + not_proj_iters += 1 + + unfinished_mask_start = unfinished_mask_start & ( + acc_start_dis < acc_end_dis + ) + unfinished_mask_end = unfinished_mask_end & (acc_start_dis < acc_end_dis) + + return ( + curr_start_points, + unfinished_mask_start, + acc_start_dis, + acc_end_dis, + min_dis, + max_dis, + ) + + def ray_sampler( + self, + sdf: Callable[[torch.Tensor], torch.Tensor], + cam_loc: torch.Tensor, + object_mask: torch.Tensor, + ray_directions: torch.Tensor, + sampler_min_max: torch.Tensor, + sampler_mask: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sample the ray in a given range and run secant on rays which have sign transition. + + Args: + sdf: + cam_loc: + object_mask: + ray_directions: + sampler_min_max: + sampler_mask: + + Returns: + + """ + + batch_size, num_pixels, _ = ray_directions.shape + device = cam_loc.device + n_total_pxl = batch_size * num_pixels + sampler_pts = torch.zeros(n_total_pxl, 3, device=device) + sampler_dists = torch.zeros(n_total_pxl, device=device) + + intervals_dist = torch.linspace(0, 1, steps=self.n_steps, device=device).view( + 1, 1, -1 + ) + + pts_intervals = sampler_min_max[:, :, 0].unsqueeze(-1) + intervals_dist * ( + sampler_min_max[:, :, 1] - sampler_min_max[:, :, 0] + ).unsqueeze(-1) + points = ( + cam_loc[..., None, :] + + pts_intervals[..., None] * ray_directions[..., None, :] + ) + + # Get the non convergent rays + mask_intersect_idx = torch.nonzero(sampler_mask).flatten() + points = points.reshape((-1, self.n_steps, 3))[sampler_mask, :, :] + pts_intervals = pts_intervals.reshape((-1, self.n_steps))[sampler_mask] + + sdf_val_all = [] + for pnts in torch.split(points.reshape(-1, 3), 100000, dim=0): + sdf_val_all.append(sdf(pnts)) + sdf_val = torch.cat(sdf_val_all).reshape(-1, self.n_steps) + + tmp = torch.sign(sdf_val) * torch.arange( + self.n_steps, 0, -1, device=device, dtype=torch.float32 + ).reshape(1, self.n_steps) + # Force argmin to return the first min value + sampler_pts_ind = torch.argmin(tmp, -1) + sampler_pts[mask_intersect_idx] = points[ + torch.arange(points.shape[0]), sampler_pts_ind, : + ] + sampler_dists[mask_intersect_idx] = pts_intervals[ + torch.arange(pts_intervals.shape[0]), sampler_pts_ind + ] + + true_surface_pts = object_mask.reshape(-1)[sampler_mask] + net_surface_pts = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind] < 0 + + # take points with minimal SDF value for P_out pixels + p_out_mask = ~(true_surface_pts & net_surface_pts) + n_p_out = p_out_mask.sum() + if n_p_out > 0: + out_pts_idx = torch.argmin(sdf_val[p_out_mask, :], -1) + sampler_pts[mask_intersect_idx[p_out_mask]] = points[p_out_mask, :, :][ + torch.arange(n_p_out), out_pts_idx, : + ] + sampler_dists[mask_intersect_idx[p_out_mask]] = pts_intervals[ + p_out_mask, : + ][torch.arange(n_p_out), out_pts_idx] + + # Get Network object mask + sampler_net_obj_mask = sampler_mask.clone() + sampler_net_obj_mask[mask_intersect_idx[~net_surface_pts]] = False + + # Run Secant method + secant_pts = ( + net_surface_pts & true_surface_pts if self.training else net_surface_pts + ) + n_secant_pts = secant_pts.sum() + if n_secant_pts > 0: + # Get secant z predictions + z_high = pts_intervals[ + torch.arange(pts_intervals.shape[0]), sampler_pts_ind + ][secant_pts] + sdf_high = sdf_val[torch.arange(sdf_val.shape[0]), sampler_pts_ind][ + secant_pts + ] + z_low = pts_intervals[secant_pts][ + torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1 + ] + sdf_low = sdf_val[secant_pts][ + torch.arange(n_secant_pts), sampler_pts_ind[secant_pts] - 1 + ] + cam_loc_secant = cam_loc.reshape(-1, 3)[mask_intersect_idx[secant_pts]] + ray_directions_secant = ray_directions.reshape((-1, 3))[ + mask_intersect_idx[secant_pts] + ] + z_pred_secant = self.secant( + sdf_low, + sdf_high, + z_low, + z_high, + cam_loc_secant, + ray_directions_secant, + # pyre-fixme[6]: For 7th param expected `Module` but got `(Tensor) + # -> Tensor`. + sdf, + ) + + # Get points + sampler_pts[mask_intersect_idx[secant_pts]] = ( + cam_loc_secant + z_pred_secant.unsqueeze(-1) * ray_directions_secant + ) + sampler_dists[mask_intersect_idx[secant_pts]] = z_pred_secant + + return sampler_pts, sampler_net_obj_mask, sampler_dists + + def secant( + self, + sdf_low: torch.Tensor, + sdf_high: torch.Tensor, + z_low: torch.Tensor, + z_high: torch.Tensor, + cam_loc: torch.Tensor, + ray_directions: torch.Tensor, + sdf: nn.Module, + ) -> torch.Tensor: + """ + Runs the secant method for interval [z_low, z_high] for n_secant_steps + """ + + z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low + for _ in range(self.n_secant_steps): + p_mid = cam_loc + z_pred.unsqueeze(-1) * ray_directions + sdf_mid = sdf(p_mid) + ind_low = sdf_mid > 0 + if ind_low.sum() > 0: + z_low[ind_low] = z_pred[ind_low] + sdf_low[ind_low] = sdf_mid[ind_low] + ind_high = sdf_mid < 0 + if ind_high.sum() > 0: + z_high[ind_high] = z_pred[ind_high] + sdf_high[ind_high] = sdf_mid[ind_high] + + z_pred = -sdf_low * (z_high - z_low) / (sdf_high - sdf_low) + z_low + + return z_pred + + def minimal_sdf_points( + self, + sdf: Callable[[torch.Tensor], torch.Tensor], + cam_loc: torch.Tensor, + ray_directions: torch.Tensor, + mask: torch.Tensor, + min_dis: torch.Tensor, + max_dis: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Find points with minimal SDF value on rays for P_out pixels + """ + + n_mask_points = mask.sum() + + n = self.n_steps + steps = torch.empty(n, device=cam_loc.device).uniform_(0.0, 1.0) + mask_max_dis = max_dis[mask].unsqueeze(-1) + mask_min_dis = min_dis[mask].unsqueeze(-1) + steps = ( + steps.unsqueeze(0).repeat(n_mask_points, 1) * (mask_max_dis - mask_min_dis) + + mask_min_dis + ) + + mask_points = cam_loc.reshape(-1, 3)[mask] + mask_rays = ray_directions[mask, :] + + mask_points_all = mask_points.unsqueeze(1).repeat(1, n, 1) + steps.unsqueeze( + -1 + ) * mask_rays.unsqueeze(1).repeat(1, n, 1) + points = mask_points_all.reshape(-1, 3) + + mask_sdf_all = [] + for pnts in torch.split(points, 100000, dim=0): + mask_sdf_all.append(sdf(pnts)) + + mask_sdf_all = torch.cat(mask_sdf_all).reshape(-1, n) + min_vals, min_idx = mask_sdf_all.min(-1) + min_mask_points = mask_points_all.reshape(-1, n, 3)[ + torch.arange(0, n_mask_points), min_idx + ] + min_mask_dist = steps.reshape(-1, n)[torch.arange(0, n_mask_points), min_idx] + + return min_mask_points, min_mask_dist + + +# TODO: support variable origins +def _get_sphere_intersection( + cam_loc: torch.Tensor, ray_directions: torch.Tensor, r: float = 1.0 +) -> Tuple[torch.Tensor, torch.Tensor]: + # Input: n_images x 3 ; n_images x n_rays x 3 + # Output: n_images * n_rays x 2 (close and far) ; n_images * n_rays + + n_imgs, n_pix, _ = ray_directions.shape + device = cam_loc.device + + # cam_loc = cam_loc.unsqueeze(-1) + # ray_cam_dot = torch.bmm(ray_directions, cam_loc).squeeze() + ray_cam_dot = (ray_directions * cam_loc).sum(-1) # n_images x n_rays + under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, dim=-1) ** 2 - r ** 2) + + under_sqrt = under_sqrt.reshape(-1) + mask_intersect = under_sqrt > 0 + + sphere_intersections = torch.zeros(n_imgs * n_pix, 2, device=device) + sphere_intersections[mask_intersect] = torch.sqrt( + under_sqrt[mask_intersect] + ).unsqueeze(-1) * torch.tensor([-1.0, 1.0], device=device) + sphere_intersections[mask_intersect] -= ray_cam_dot.reshape(-1)[ + mask_intersect + ].unsqueeze(-1) + + sphere_intersections = sphere_intersections.reshape(n_imgs, n_pix, 2) + sphere_intersections = sphere_intersections.clamp_min(0.0) + mask_intersect = mask_intersect.reshape(n_imgs, n_pix) + + return sphere_intersections, mask_intersect diff --git a/pytorch3d/implicitron/models/renderer/raymarcher.py b/pytorch3d/implicitron/models/renderer/raymarcher.py new file mode 100644 index 00000000..87e52911 --- /dev/null +++ b/pytorch3d/implicitron/models/renderer/raymarcher.py @@ -0,0 +1,143 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Callable, Dict, Tuple, Union + +import torch +from pytorch3d.renderer.implicit.raymarching import _check_raymarcher_inputs + + +_TTensor = torch.Tensor + + +class GenericRaymarcher(torch.nn.Module): + """ + This generalizes the `pytorch3d.renderer.EmissionAbsorptionRaymarcher` + and NeuralVolumes' Accumulative ray marcher. It additionally returns + the rendering weights that can be used in the NVS pipeline to carry out + the importance ray-sampling in the refining pass. + Different from `EmissionAbsorptionRaymarcher`, it takes raw + (non-exponentiated) densities. + + Args: + bg_color: background_color. Must be of shape (1,) or (feature_dim,) + """ + + def __init__( + self, + surface_thickness: int = 1, + bg_color: Union[Tuple[float, ...], _TTensor] = (0.0,), + capping_function: str = "exponential", # exponential | cap1 + weight_function: str = "product", # product | minimum + background_opacity: float = 0.0, + density_relu: bool = True, + blend_output: bool = True, + ): + """ + Args: + surface_thickness: Denotes the overlap between the absorption + function and the density function. + """ + super().__init__() + self.surface_thickness = surface_thickness + self.density_relu = density_relu + self.background_opacity = background_opacity + self.blend_output = blend_output + if not isinstance(bg_color, torch.Tensor): + bg_color = torch.tensor(bg_color) + + if bg_color.ndim != 1: + raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor") + + self.register_buffer("_bg_color", bg_color, persistent=False) + + self._capping_function: Callable[[_TTensor], _TTensor] = { + "exponential": lambda x: 1.0 - torch.exp(-x), + "cap1": lambda x: x.clamp(max=1.0), + }[capping_function] + + self._weight_function: Callable[[_TTensor, _TTensor], _TTensor] = { + "product": lambda curr, acc: curr * acc, + "minimum": lambda curr, acc: torch.minimum(curr, acc), + }[weight_function] + + def forward( + self, + rays_densities: torch.Tensor, + rays_features: torch.Tensor, + aux: Dict[str, Any], + ray_lengths: torch.Tensor, + density_noise_std: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, Any]]: + """ + Args: + rays_densities: Per-ray density values represented with a tensor + of shape `(..., n_points_per_ray, 1)`. + rays_features: Per-ray feature values represented with a tensor + of shape `(..., n_points_per_ray, feature_dim)`. + aux: a dictionary with extra information. + ray_lengths: Per-ray depth values represented with a tensor + of shape `(..., n_points_per_ray, feature_dim)`. + density_noise_std: the magnitude of the noise added to densities. + + Returns: + features: A tensor of shape `(..., feature_dim)` containing + the rendered features for each ray. + depth: A tensor of shape `(..., 1)` containing estimated depth. + opacities: A tensor of shape `(..., 1)` containing rendered opacsities. + weights: A tensor of shape `(..., n_points_per_ray)` containing + the ray-specific non-negative opacity weights. In general, they + don't sum to 1 but do not overcome it, i.e. + `(weights.sum(dim=-1) <= 1.0).all()` holds. + """ + _check_raymarcher_inputs( + rays_densities, + rays_features, + ray_lengths, + z_can_be_none=True, + features_can_be_none=False, + density_1d=True, + ) + + deltas = torch.cat( + ( + ray_lengths[..., 1:] - ray_lengths[..., :-1], + self.background_opacity * torch.ones_like(ray_lengths[..., :1]), + ), + dim=-1, + ) + + rays_densities = rays_densities[..., 0] + + if density_noise_std > 0.0: + rays_densities = ( + rays_densities + torch.randn_like(rays_densities) * density_noise_std + ) + if self.density_relu: + rays_densities = torch.relu(rays_densities) + + weighted_densities = deltas * rays_densities + capped_densities = self._capping_function(weighted_densities) + + rays_opacities = self._capping_function( + torch.cumsum(weighted_densities, dim=-1) + ) + opacities = rays_opacities[..., -1:] + absorption_shifted = (-rays_opacities + 1.0).roll( + self.surface_thickness, dims=-1 + ) + absorption_shifted[..., : self.surface_thickness] = 1.0 + + weights = self._weight_function(capped_densities, absorption_shifted) + features = (weights[..., None] * rays_features).sum(dim=-2) + depth = (weights * ray_lengths)[..., None].sum(dim=-2) + + alpha = opacities if self.blend_output else 1 + if self._bg_color.shape[-1] not in [1, features.shape[-1]]: + raise ValueError("Wrong number of background color channels.") + features = alpha * features + (1 - opacities) * self._bg_color + + return features, depth, opacities, weights, aux diff --git a/pytorch3d/implicitron/models/renderer/rgb_net.py b/pytorch3d/implicitron/models/renderer/rgb_net.py new file mode 100644 index 00000000..0e444a43 --- /dev/null +++ b/pytorch3d/implicitron/models/renderer/rgb_net.py @@ -0,0 +1,101 @@ +# @lint-ignore-every LICENSELINT +# Adapted from RenderingNetwork from IDR +# https://github.com/lioryariv/idr/ +# Copyright (c) 2020 Lior Yariv +import torch +from pytorch3d.renderer.implicit import HarmonicEmbedding, RayBundle +from torch import nn + + +class RayNormalColoringNetwork(torch.nn.Module): + def __init__( + self, + feature_vector_size=3, + mode="idr", + d_in=9, + d_out=3, + dims=(512, 512, 512, 512), + weight_norm=True, + n_harmonic_functions_dir=0, + pooled_feature_dim=0, + ): + super().__init__() + + self.mode = mode + self.output_dimensions = d_out + dims = [d_in + feature_vector_size] + list(dims) + [d_out] + + self.embedview_fn = None + if n_harmonic_functions_dir > 0: + self.embedview_fn = HarmonicEmbedding( + n_harmonic_functions_dir, append_input=True + ) + dims[0] += self.embedview_fn.get_output_dim() - 3 + + if pooled_feature_dim > 0: + print("Pooled features in rendering network.") + dims[0] += pooled_feature_dim + + self.num_layers = len(dims) + + layers = [] + for layer_idx in range(self.num_layers - 1): + out_dim = dims[layer_idx + 1] + lin = nn.Linear(dims[layer_idx], out_dim) + + if weight_norm: + lin = nn.utils.weight_norm(lin) + + layers.append(lin) + self.linear_layers = torch.nn.ModuleList(layers) + + self.relu = nn.ReLU() + self.tanh = nn.Tanh() + + def forward( + self, + feature_vectors: torch.Tensor, + points, + normals, + ray_bundle: RayBundle, + masks=None, + pooling_fn=None, + ): + if masks is not None and not masks.any(): + return torch.zeros_like(normals) + + view_dirs = ray_bundle.directions + if masks is not None: + # in case of IDR, other outputs are passed here after applying the mask + view_dirs = view_dirs.reshape(view_dirs.shape[0], -1, 3)[ + :, masks.reshape(-1) + ] + + if self.embedview_fn is not None: + view_dirs = self.embedview_fn(view_dirs) + + if self.mode == "idr": + rendering_input = torch.cat( + [points, view_dirs, normals, feature_vectors], dim=-1 + ) + elif self.mode == "no_view_dir": + rendering_input = torch.cat([points, normals, feature_vectors], dim=-1) + elif self.mode == "no_normal": + rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1) + else: + raise ValueError(f"Unsupported rendering mode: {self.mode}") + + if pooling_fn is not None: + featspool = pooling_fn(points[None])[0] + rendering_input = torch.cat((rendering_input, featspool), dim=-1) + + x = rendering_input + + for layer_idx in range(self.num_layers - 1): + x = self.linear_layers[layer_idx](x) + + if layer_idx < self.num_layers - 2: + x = self.relu(x) + + x = self.tanh(x) + return x diff --git a/pytorch3d/implicitron/models/renderer/sdf_renderer.py b/pytorch3d/implicitron/models/renderer/sdf_renderer.py new file mode 100644 index 00000000..735a8a8c --- /dev/null +++ b/pytorch3d/implicitron/models/renderer/sdf_renderer.py @@ -0,0 +1,253 @@ +# @lint-ignore-every LICENSELINT +# Adapted from https://github.com/lioryariv/idr/blob/main/code/model/ +# implicit_differentiable_renderer.py +# Copyright (c) 2020 Lior Yariv +import functools +import math +from typing import List, Optional, Tuple + +import torch +from omegaconf import DictConfig +from pytorch3d.implicitron.tools.config import get_default_args_field, registry +from pytorch3d.implicitron.tools.utils import evaluating +from pytorch3d.renderer import RayBundle + +from .base import BaseRenderer, EvaluationMode, ImplicitFunctionWrapper, RendererOutput +from .ray_tracing import RayTracing +from .rgb_net import RayNormalColoringNetwork + + +@registry.register +class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): + render_features_dimensions: int = 3 + ray_tracer_args: DictConfig = get_default_args_field(RayTracing) + ray_normal_coloring_network_args: DictConfig = get_default_args_field( + RayNormalColoringNetwork + ) + bg_color: Tuple[float, ...] = (0.0,) + soft_mask_alpha: float = 50.0 + + def __post_init__( + self, + ): + super().__init__() + render_features_dimensions = self.render_features_dimensions + if len(self.bg_color) not in [1, render_features_dimensions]: + raise ValueError( + f"Background color should have {render_features_dimensions} entries." + ) + + self.ray_tracer = RayTracing(**self.ray_tracer_args) + self.object_bounding_sphere = self.ray_tracer_args.get("object_bounding_sphere") + + self.ray_normal_coloring_network_args[ + "feature_vector_size" + ] = render_features_dimensions + self._rgb_network = RayNormalColoringNetwork( + **self.ray_normal_coloring_network_args + ) + + self.register_buffer("_bg_color", torch.tensor(self.bg_color), persistent=False) + + def forward( + self, + ray_bundle: RayBundle, + implicit_functions: List[ImplicitFunctionWrapper], + evaluation_mode: EvaluationMode = EvaluationMode.EVALUATION, + object_mask: Optional[torch.Tensor] = None, + **kwargs, + ) -> RendererOutput: + """ + Args: + ray_bundle: A `RayBundle` object containing the parametrizations of the + sampled rendering rays. + implicit_functions: single element list of ImplicitFunctionWrappers which + defines the implicit function to be used. + evaluation_mode: one of EvaluationMode.TRAINING or + EvaluationMode.EVALUATION which determines the settings used for + rendering. + kwargs: + object_mask: BoolTensor, denoting the silhouette of the object. + This is a required keyword argument for SignedDistanceFunctionRenderer + + Returns: + instance of RendererOutput + """ + if len(implicit_functions) != 1: + raise ValueError( + "SignedDistanceFunctionRenderer supports only single pass." + ) + + if object_mask is None: + raise ValueError("Expected object_mask to be provided in the kwargs") + object_mask = object_mask.bool() + + implicit_function = implicit_functions[0] + implicit_function_gradient = functools.partial(gradient, implicit_function) + + # object_mask: silhouette of the object + batch_size, *spatial_size, _ = ray_bundle.lengths.shape + num_pixels = math.prod(spatial_size) + + cam_loc = ray_bundle.origins.reshape(batch_size, -1, 3) + ray_dirs = ray_bundle.directions.reshape(batch_size, -1, 3) + object_mask = object_mask.reshape(batch_size, -1) + + with torch.no_grad(), evaluating(implicit_function): + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function. + points, network_object_mask, dists = self.ray_tracer( + sdf=lambda x: implicit_function(x)[ + :, 0 + ], # TODO: get rid of this wrapper + cam_loc=cam_loc, + object_mask=object_mask, + ray_directions=ray_dirs, + ) + + # TODO: below, cam_loc might as well be different + depth = dists.reshape(batch_size, num_pixels, 1) + points = (cam_loc + depth * ray_dirs).reshape(-1, 3) + + sdf_output = implicit_function(points)[:, 0:1] + # NOTE most of the intermediate variables are flattened for + # no apparent reason (here and in the ray tracer) + ray_dirs = ray_dirs.reshape(-1, 3) + object_mask = object_mask.reshape(-1) + + # TODO: move it to loss computation + if evaluation_mode == EvaluationMode.TRAINING: + surface_mask = network_object_mask & object_mask + surface_points = points[surface_mask] + surface_dists = dists[surface_mask].unsqueeze(-1) + surface_ray_dirs = ray_dirs[surface_mask] + surface_cam_loc = cam_loc.reshape(-1, 3)[surface_mask] + surface_output = sdf_output[surface_mask] + N = surface_points.shape[0] + + # Sample points for the eikonal loss + # pyre-fixme[9] + eik_bounding_box: float = self.object_bounding_sphere + n_eik_points = batch_size * num_pixels // 2 + eikonal_points = torch.empty( + n_eik_points, 3, device=self._bg_color.device + ).uniform_(-eik_bounding_box, eik_bounding_box) + eikonal_pixel_points = points.clone() + eikonal_pixel_points = eikonal_pixel_points.detach() + eikonal_points = torch.cat([eikonal_points, eikonal_pixel_points], 0) + + points_all = torch.cat([surface_points, eikonal_points], dim=0) + + output = implicit_function(surface_points) + surface_sdf_values = output[ + :N, 0:1 + ].detach() # how is it different from sdf_output? + + g = implicit_function_gradient(points_all) + surface_points_grad = g[:N, 0, :].clone().detach() + grad_theta = g[N:, 0, :] + + differentiable_surface_points = _sample_network( + surface_output, + surface_sdf_values, + surface_points_grad, + surface_dists, + surface_cam_loc, + surface_ray_dirs, + ) + + else: + surface_mask = network_object_mask + differentiable_surface_points = points[surface_mask] + grad_theta = None + + empty_render = differentiable_surface_points.shape[0] == 0 + features = implicit_function(differentiable_surface_points)[None, :, 1:] + normals_full = features.new_zeros( + batch_size, *spatial_size, 3, requires_grad=empty_render + ) + render_full = ( + features.new_ones( + batch_size, + *spatial_size, + self.render_features_dimensions, + requires_grad=empty_render, + ) + * self._bg_color + ) + mask_full = features.new_ones( + batch_size, *spatial_size, 1, requires_grad=empty_render + ) + if not empty_render: + normals = implicit_function_gradient(differentiable_surface_points)[ + None, :, 0, : + ] + normals_full.view(-1, 3)[surface_mask] = normals + render_full.view(-1, self.render_features_dimensions)[ + surface_mask + ] = self._rgb_network( # pyre-fixme[29]: + features, + differentiable_surface_points[None], + normals, + ray_bundle, + surface_mask[None, :, None], + pooling_fn=None, # TODO + ) + mask_full.view(-1, 1)[~surface_mask] = torch.sigmoid( + -self.soft_mask_alpha * sdf_output[~surface_mask] + ) + + # scatter points with surface_mask + points_full = ray_bundle.origins.detach().clone() + points_full.view(-1, 3)[surface_mask] = differentiable_surface_points + + # TODO: it is sparse here but otherwise dense + return RendererOutput( + features=render_full, + normals=normals_full, + depths=depth.reshape(batch_size, *spatial_size, 1), + masks=mask_full, # this is a differentiable approximation, see (7) in the paper + points=points_full, + aux={"grad_theta": grad_theta}, # TODO: will be moved to eikonal loss + # TODO: do we need sdf_output, grad_theta? Only for loss probably + ) + + +def _sample_network( + surface_output, + surface_sdf_values, + surface_points_grad, + surface_dists, + surface_cam_loc, + surface_ray_dirs, + eps=1e-4, +): + # t -> t(theta) + surface_ray_dirs_0 = surface_ray_dirs.detach() + surface_points_dot = torch.bmm( + surface_points_grad.view(-1, 1, 3), surface_ray_dirs_0.view(-1, 3, 1) + ).squeeze(-1) + dot_sign = (surface_points_dot >= 0).to(surface_points_dot) * 2 - 1 + surface_dists_theta = surface_dists - (surface_output - surface_sdf_values) / ( + surface_points_dot.abs().clip(eps) * dot_sign + ) + + # t(theta) -> x(theta,c,v) + surface_points_theta_c_v = surface_cam_loc + surface_dists_theta * surface_ray_dirs + + return surface_points_theta_c_v + + +@torch.enable_grad() +def gradient(module, x): + x.requires_grad_(True) + y = module.forward(x)[:, :1] + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + gradients = torch.autograd.grad( + outputs=y, + inputs=x, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + return gradients.unsqueeze(1) diff --git a/pytorch3d/implicitron/models/resnet_feature_extractor.py b/pytorch3d/implicitron/models/resnet_feature_extractor.py new file mode 100644 index 00000000..31b8664a --- /dev/null +++ b/pytorch3d/implicitron/models/resnet_feature_extractor.py @@ -0,0 +1,218 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import logging +import math +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as Fu +import torchvision +from pytorch3d.implicitron.tools.config import Configurable + + +logger = logging.getLogger(__name__) + +MASK_FEATURE_NAME = "mask" +IMAGE_FEATURE_NAME = "image" + +_FEAT_DIMS = { + "resnet18": (64, 128, 256, 512), + "resnet34": (64, 128, 256, 512), + "resnet50": (256, 512, 1024, 2048), + "resnet101": (256, 512, 1024, 2048), + "resnet152": (256, 512, 1024, 2048), +} + +_RESNET_MEAN = [0.485, 0.456, 0.406] +_RESNET_STD = [0.229, 0.224, 0.225] + + +class ResNetFeatureExtractor(Configurable, torch.nn.Module): + """ + Implements an image feature extractor. Depending on the settings allows + to extract: + - deep features: A CNN ResNet backbone from torchvision (with/without + pretrained weights) which extracts deep features. + - masks: Segmentation masks. + - images: Raw input RGB images. + + Settings: + name: name of the resnet backbone (from torchvision) + pretrained: If true, will load the pretrained weights + stages: List of stages from which to extract features. + Features from each stage are returned as key value + pairs in the forward function + normalize_image: If set will normalize the RGB values of + the image based on the Resnet mean/std + image_rescale: If not 1.0, this rescale factor will be + used to resize the image + first_max_pool: If set, a max pool layer is added after the first + convolutional layer + proj_dim: The number of output channels for the convolutional layers + l2_norm: If set, l2 normalization is applied to the extracted features + add_masks: If set, the masks will be saved in the output dictionary + add_images: If set, the images will be saved in the output dictionary + global_average_pool: If set, global average pooling step is performed + feature_rescale: If not 1.0, this rescale factor will be used to + rescale the output features + """ + + name: str = "resnet34" + pretrained: bool = True + stages: Tuple[int, ...] = (1, 2, 3, 4) + normalize_image: bool = True + image_rescale: float = 128 / 800.0 + first_max_pool: bool = True + proj_dim: int = 32 + l2_norm: bool = True + add_masks: bool = True + add_images: bool = True + global_average_pool: bool = False # this can simulate global/non-spacial features + feature_rescale: float = 1.0 + + def __post_init__(self): + super().__init__() + if self.normalize_image: + # register buffers needed to normalize the image + for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)): + self.register_buffer( + k, + torch.FloatTensor(v).view(1, 3, 1, 1), + persistent=False, + ) + + self._feat_dim = {} + + if len(self.stages) == 0: + # do not extract any resnet features + pass + else: + net = getattr(torchvision.models, self.name)(pretrained=self.pretrained) + if self.first_max_pool: + self.stem = torch.nn.Sequential( + net.conv1, net.bn1, net.relu, net.maxpool + ) + else: + self.stem = torch.nn.Sequential(net.conv1, net.bn1, net.relu) + self.max_stage = max(self.stages) + self.layers = torch.nn.ModuleList() + self.proj_layers = torch.nn.ModuleList() + for stage in range(self.max_stage): + stage_name = f"layer{stage+1}" + feature_name = self._get_resnet_stage_feature_name(stage) + if (stage + 1) in self.stages: + if ( + self.proj_dim > 0 + and _FEAT_DIMS[self.name][stage] > self.proj_dim + ): + proj = torch.nn.Conv2d( + _FEAT_DIMS[self.name][stage], + self.proj_dim, + 1, + 1, + bias=True, + ) + self._feat_dim[feature_name] = self.proj_dim + else: + proj = torch.nn.Identity() + self._feat_dim[feature_name] = _FEAT_DIMS[self.name][stage] + else: + proj = torch.nn.Identity() + self.proj_layers.append(proj) + self.layers.append(getattr(net, stage_name)) + + if self.add_masks: + self._feat_dim[MASK_FEATURE_NAME] = 1 + + if self.add_images: + self._feat_dim[IMAGE_FEATURE_NAME] = 3 + + logger.info(f"Feat extractor total dim = {self.get_feat_dims()}") + self.stages = set(self.stages) # convert to set for faster "in" + + def _get_resnet_stage_feature_name(self, stage) -> str: + return f"res_layer_{stage+1}" + + def _resnet_normalize_image(self, img: torch.Tensor) -> torch.Tensor: + return (img - self._resnet_mean) / self._resnet_std + + def get_feat_dims(self, size_dict: bool = False): + if size_dict: + return copy.deepcopy(self._feat_dim) + # pyre-fixme[16]: Item `Tensor` of `Union[Tensor, Module]` has no attribute + # `values`. + return sum(self._feat_dim.values()) + + def forward( + self, imgs: torch.Tensor, masks: Optional[torch.Tensor] = None + ) -> Dict[Any, torch.Tensor]: + """ + Args: + imgs: A batch of input images of shape `(B, 3, H, W)`. + masks: A batch of input masks of shape `(B, 3, H, W)`. + + Returns: + out_feats: A dict `{f_i: t_i}` keyed by predicted feature names `f_i` + and their corresponding tensors `t_i` of shape `(B, dim_i, H_i, W_i)`. + """ + + out_feats = {} + + imgs_input = imgs + if self.image_rescale != 1.0: + imgs_resized = Fu.interpolate( + imgs_input, + # pyre-fixme[6]: For 2nd param expected `Optional[List[float]]` but + # got `float`. + scale_factor=self.image_rescale, + mode="bilinear", + ) + else: + imgs_resized = imgs_input + + if self.normalize_image: + imgs_normed = self._resnet_normalize_image(imgs_resized) + else: + imgs_normed = imgs_resized + + if len(self.stages) > 0: + # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]` + # is not a function. + feats = self.stem(imgs_normed) + # pyre-fixme[6]: For 1st param expected `Iterable[Variable[_T1]]` but + # got `Union[Tensor, Module]`. + # pyre-fixme[6]: For 2nd param expected `Iterable[Variable[_T2]]` but + # got `Union[Tensor, Module]`. + for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)): + feats = layer(feats) + # just a sanity check below + assert feats.shape[1] == _FEAT_DIMS[self.name][stage] + if (stage + 1) in self.stages: + f = proj(feats) + if self.global_average_pool: + f = f.mean(dims=(2, 3)) + if self.l2_norm: + normfac = 1.0 / math.sqrt(len(self.stages)) + f = Fu.normalize(f, dim=1) * normfac + feature_name = self._get_resnet_stage_feature_name(stage) + out_feats[feature_name] = f + + if self.add_masks: + assert masks is not None + out_feats[MASK_FEATURE_NAME] = masks + + if self.add_images: + assert imgs_input is not None + out_feats[IMAGE_FEATURE_NAME] = imgs_resized + + if self.feature_rescale != 1.0: + out_feats = {k: self.feature_rescale * f for k, f in out_feats.items()} + + # pyre-fixme[7]: Incompatible return type, expected `Dict[typing.Any, Tensor]` + # but got `Dict[typing.Any, float]` + return out_feats diff --git a/pytorch3d/implicitron/models/view_pooling/feature_aggregation.py b/pytorch3d/implicitron/models/view_pooling/feature_aggregation.py new file mode 100644 index 00000000..01022607 --- /dev/null +++ b/pytorch3d/implicitron/models/view_pooling/feature_aggregation.py @@ -0,0 +1,666 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from enum import Enum +from typing import Dict, Optional, Sequence, Union + +import torch +import torch.nn.functional as F +from pytorch3d.implicitron.models.view_pooling.view_sampling import ( + cameras_points_cartesian_product, +) +from pytorch3d.implicitron.tools.config import ReplaceableBase, registry +from pytorch3d.ops import wmean +from pytorch3d.renderer.cameras import CamerasBase + + +class ReductionFunction(Enum): + AVG = "avg" # simple average + MAX = "max" # maximum + STD = "std" # standard deviation + STD_AVG = "std_avg" # average of per-dimension standard deviations + + +class FeatureAggregatorBase(ABC, ReplaceableBase): + """ + Base class for aggregating features. + + Typically, the aggregated features and their masks are output by `ViewSampler` + which samples feature tensors extracted from a set of source images. + + Settings: + exclude_target_view: If `True`/`False`, enables/disables pooling + from target view to itself. + exclude_target_view_mask_features: If `True`, + mask the features from the target view before aggregation + concatenate_output: If `True`, + concatenate the aggregated features into a single tensor, + otherwise return a dictionary mapping feature names to tensors. + """ + + exclude_target_view: bool = True + exclude_target_view_mask_features: bool = True + concatenate_output: bool = True + + @abstractmethod + def forward( + self, + feats_sampled: Dict[str, torch.Tensor], + masks_sampled: torch.Tensor, + camera: Optional[CamerasBase] = None, + pts: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Args: + feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`, + where each `t_i` is a tensor of shape + `(minibatch, n_source_views, n_samples, dim_i)`. + masks_sampled: A binary mask represented as a tensor of shape + `(minibatch, n_source_views, n_samples, 1)` denoting valid + sampled features. + camera: A batch of `n_source_views` `CamerasBase` objects corresponding + to the source view cameras. + pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the + 3D points whose 2D projections to source views were sampled in + order to generate `feats_sampled` and `masks_sampled`. + + Returns: + feats_aggregated: If `concatenate_output==True`, a tensor + of shape `(minibatch, reduce_dim, n_samples, sum(dim_1, ... dim_N))` + containing the concatenation of the aggregated features `feats_sampled`. + `reduce_dim` depends on the specific feature aggregator + implementation and typically equals 1 or `n_source_views`. + If `concatenate_output==False`, the aggregator does not concatenate + the aggregated features and returns a dictionary of per-feature + aggregations `{f_i: t_i_aggregated}` instead. Each `t_i_aggregated` + is of shape `(minibatch, reduce_dim, n_samples, aggr_dim_i)`. + """ + raise NotImplementedError() + + +@registry.register +class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): + """ + This aggregator does not perform any feature aggregation. Depending on the + settings the aggregator allows to mask target view features and concatenate + the outputs. + """ + + def __post_init__(self): + super().__init__() + + def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): + return _get_reduction_aggregator_feature_dim(feats, []) + + def forward( + self, + feats_sampled: Dict[str, torch.Tensor], + masks_sampled: torch.Tensor, + camera: Optional[CamerasBase] = None, + pts: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Args: + feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`, + where each `t_i` is a tensor of shape + `(minibatch, n_source_views, n_samples, dim_i)`. + masks_sampled: A binary mask represented as a tensor of shape + `(minibatch, n_source_views, n_samples, 1)` denoting valid + sampled features. + camera: A batch of `n_source_views` `CamerasBase` objects + corresponding to the source view cameras. + pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the + 3D points whose 2D projections to source views were sampled in + order to generate `feats_sampled` and `masks_sampled`. + + Returns: + feats_aggregated: If `concatenate_output==True`, a tensor + of shape `(minibatch, 1, n_samples, sum(dim_1, ... dim_N))`. + If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}` + with each `t_i_aggregated` of shape + `(minibatch, n_source_views, n_samples, dim_i)`. + """ + if self.exclude_target_view_mask_features: + feats_sampled = _mask_target_view_features(feats_sampled) + feats_aggregated = feats_sampled + if self.concatenate_output: + feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1) + return feats_aggregated + + +@registry.register +class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): + """ + Aggregates using a set of predefined `reduction_functions` and concatenates + the results of each aggregation function along the + channel dimension. The reduction functions singularize the second dimension + of the sampled features which stacks the source views. + + Settings: + reduction_functions: A list of `ReductionFunction`s` that reduce the + the stack of source-view-specific features to a single feature. + """ + + reduction_functions: Sequence[ReductionFunction] = ( + ReductionFunction.AVG, + ReductionFunction.STD, + ) + + def __post_init__(self): + super().__init__() + + def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): + return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions) + + def forward( + self, + feats_sampled: Dict[str, torch.Tensor], + masks_sampled: torch.Tensor, + camera: Optional[CamerasBase] = None, + pts: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Args: + feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`, + where each `t_i` is a tensor of shape + `(minibatch, n_source_views, n_samples, dim_i)`. + masks_sampled: A binary mask represented as a tensor of shape + `(minibatch, n_source_views, n_samples, 1)` denoting valid + sampled features. + camera: A batch of `n_source_views` `CamerasBase` objects corresponding + to the source view cameras. + pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the + 3D points whose 2D projections to source views were sampled in + order to generate `feats_sampled` and `masks_sampled`. + + Returns: + feats_aggregated: If `concatenate_output==True`, a tensor + of shape `(minibatch, 1, n_samples, sum(dim_1, ... dim_N))`. + If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}` + with each `t_i_aggregated` of shape `(minibatch, 1, n_samples, aggr_dim_i)`. + """ + + pts_batch, n_cameras = masks_sampled.shape[:2] + if self.exclude_target_view_mask_features: + feats_sampled = _mask_target_view_features(feats_sampled) + sampling_mask = _get_view_sampling_mask( + n_cameras, + pts_batch, + masks_sampled.device, + self.exclude_target_view, + ) + aggr_weigths = masks_sampled * sampling_mask + feats_aggregated = { + k: _avgmaxstd_reduction_function( + f, + aggr_weigths, + dim=1, + reduction_functions=self.reduction_functions, + ) + for k, f in feats_sampled.items() + } + if self.concatenate_output: + feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1) + return feats_aggregated + + +@registry.register +class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): + """ + Performs a weighted aggregation using a set of predefined `reduction_functions` + and concatenates the results of each aggregation function along the + channel dimension. The weights are proportional to the cosine of the + angle between the target ray and the source ray: + ``` + weight = ( + dot(target_ray, source_ray) * 0.5 + 0.5 + self.min_ray_angle_weight + )**self.weight_by_ray_angle_gamma + ``` + + The reduction functions singularize the second dimension + of the sampled features which stacks the source views. + + Settings: + reduction_functions: A list of `ReductionFunction`s that reduce the + the stack of source-view-specific features to a single feature. + min_ray_angle_weight: The minimum possible aggregation weight + before rasising to the power of `self.weight_by_ray_angle_gamma`. + weight_by_ray_angle_gamma: The exponent of the cosine of the ray angles + used when calculating the angle-based aggregation weights. + """ + + reduction_functions: Sequence[ReductionFunction] = ( + ReductionFunction.AVG, + ReductionFunction.STD, + ) + weight_by_ray_angle_gamma: float = 1.0 + min_ray_angle_weight: float = 0.1 + + def __post_init__(self): + super().__init__() + + def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): + return _get_reduction_aggregator_feature_dim(feats, self.reduction_functions) + + def forward( + self, + feats_sampled: Dict[str, torch.Tensor], + masks_sampled: torch.Tensor, + camera: Optional[CamerasBase] = None, + pts: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Args: + feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`, + where each `t_i` is a tensor of shape + `(minibatch, n_source_views, n_samples, dim_i)`. + masks_sampled: A binary mask represented as a tensor of shape + `(minibatch, n_source_views, n_samples, 1)` denoting valid + sampled features. + camera: A batch of `n_source_views` `CamerasBase` objects + corresponding to the source view cameras. + pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the + 3D points whose 2D projections to source views were sampled in + order to generate `feats_sampled` and `masks_sampled`. + + Returns: + feats_aggregated: If `concatenate_output==True`, a tensor + of shape `(minibatch, 1, n_samples, sum(dim_1, ... dim_N))`. + If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}` + with each `t_i_aggregated` of shape + `(minibatch, n_source_views, n_samples, dim_i)`. + """ + + if camera is None: + raise ValueError("camera cannot be None for angle weighted aggregation") + + if pts is None: + raise ValueError("Points cannot be None for angle weighted aggregation") + + pts_batch, n_cameras = masks_sampled.shape[:2] + if self.exclude_target_view_mask_features: + feats_sampled = _mask_target_view_features(feats_sampled) + view_sampling_mask = _get_view_sampling_mask( + n_cameras, + pts_batch, + masks_sampled.device, + self.exclude_target_view, + ) + aggr_weights = _get_angular_reduction_weights( + view_sampling_mask, + masks_sampled, + camera, + pts, + self.min_ray_angle_weight, + self.weight_by_ray_angle_gamma, + ) + assert torch.isfinite(aggr_weights).all() + feats_aggregated = { + k: _avgmaxstd_reduction_function( + f, + aggr_weights, + dim=1, + reduction_functions=self.reduction_functions, + ) + for k, f in feats_sampled.items() + } + if self.concatenate_output: + feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1) + return feats_aggregated + + +@registry.register +class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): + """ + This aggregator does not perform any feature aggregation. It only weights + the features by the weights proportional to the cosine of the + angle between the target ray and the source ray: + ``` + weight = ( + dot(target_ray, source_ray) * 0.5 + 0.5 + self.min_ray_angle_weight + )**self.weight_by_ray_angle_gamma + ``` + + Settings: + min_ray_angle_weight: The minimum possible aggregation weight + before rasising to the power of `self.weight_by_ray_angle_gamma`. + weight_by_ray_angle_gamma: The exponent of the cosine of the ray angles + used when calculating the angle-based aggregation weights. + + Additionally the aggregator allows to mask target view features and to concatenate + the outputs. + """ + + weight_by_ray_angle_gamma: float = 1.0 + min_ray_angle_weight: float = 0.1 + + def __post_init__(self): + super().__init__() + + def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): + return _get_reduction_aggregator_feature_dim(feats, []) + + def forward( + self, + feats_sampled: Dict[str, torch.Tensor], + masks_sampled: torch.Tensor, + camera: Optional[CamerasBase] = None, + pts: Optional[torch.Tensor] = None, + **kwargs, + ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Args: + feats_sampled: A `dict` of sampled feature tensors `{f_i: t_i}`, + where each `t_i` is a tensor of shape + `(minibatch, n_source_views, n_samples, dim_i)`. + masks_sampled: A binary mask represented as a tensor of shape + `(minibatch, n_source_views, n_samples, 1)` denoting valid + sampled features. + camera: A batch of `n_source_views` `CamerasBase` objects corresponding + to the source view cameras. + pts: A tensor of shape `(minibatch, n_samples, 3)` denoting the + 3D points whose 2D projections to source views were sampled in + order to generate `feats_sampled` and `masks_sampled`. + + Returns: + feats_aggregated: If `concatenate_output==True`, a tensor + of shape `(minibatch, n_source_views, n_samples, sum(dim_1, ... dim_N))`. + If `concatenate_output==False`, a dictionary `{f_i: t_i_aggregated}` + with each `t_i_aggregated` of shape + `(minibatch, n_source_views, n_samples, dim_i)`. + """ + + if camera is None: + raise ValueError("camera cannot be None for angle weighted aggregation") + + if pts is None: + raise ValueError("Points cannot be None for angle weighted aggregation") + + pts_batch, n_cameras = masks_sampled.shape[:2] + if self.exclude_target_view_mask_features: + feats_sampled = _mask_target_view_features(feats_sampled) + view_sampling_mask = _get_view_sampling_mask( + n_cameras, + pts_batch, + masks_sampled.device, + self.exclude_target_view, + ) + aggr_weights = _get_angular_reduction_weights( + view_sampling_mask, + masks_sampled, + camera, + pts, + self.min_ray_angle_weight, + self.weight_by_ray_angle_gamma, + ) + feats_aggregated = { + k: f * aggr_weights[..., None] for k, f in feats_sampled.items() + } + if self.concatenate_output: + feats_aggregated = torch.cat(tuple(feats_aggregated.values()), dim=-1) + return feats_aggregated + + +def _get_reduction_aggregator_feature_dim( + feats_or_feats_dim: Union[Dict[str, torch.Tensor], int], + reduction_functions: Sequence[ReductionFunction], +): + if isinstance(feats_or_feats_dim, int): + feat_dim = feats_or_feats_dim + else: + feat_dim = int(sum(f.shape[1] for f in feats_or_feats_dim.values())) + if len(reduction_functions) == 0: + return feat_dim + return sum( + _get_reduction_function_output_dim( + reduction_function, + feat_dim, + ) + for reduction_function in reduction_functions + ) + + +def _get_reduction_function_output_dim( + reduction_function: ReductionFunction, + feat_dim: int, +) -> int: + if reduction_function == ReductionFunction.STD_AVG: + return 1 + else: + return feat_dim + + +def _get_view_sampling_mask( + n_cameras: int, + pts_batch: int, + device: Union[str, torch.device], + exclude_target_view: bool, +): + return ( + -torch.eye(n_cameras, device=device, dtype=torch.float32) + * float(exclude_target_view) + + 1.0 + )[:pts_batch] + + +def _mask_target_view_features( + feats_sampled: Dict[str, torch.Tensor], +): + # mask out the sampled features to be sure we dont use them + # anywhere later + one_feature_sampled = next(iter(feats_sampled.values())) + pts_batch, n_cameras = one_feature_sampled.shape[:2] + view_sampling_mask = _get_view_sampling_mask( + n_cameras, + pts_batch, + one_feature_sampled.device, + True, + ) + view_sampling_mask = view_sampling_mask.view( + pts_batch, n_cameras, *([1] * (one_feature_sampled.ndim - 2)) + ) + return {k: f * view_sampling_mask for k, f in feats_sampled.items()} + + +def _get_angular_reduction_weights( + view_sampling_mask: torch.Tensor, + masks_sampled: torch.Tensor, + camera: CamerasBase, + pts: torch.Tensor, + min_ray_angle_weight: float, + weight_by_ray_angle_gamma: float, +): + aggr_weights = masks_sampled.clone()[..., 0] + assert not any(v is None for v in [camera, pts]) + angle_weight = _get_ray_angle_weights( + camera, + pts, + min_ray_angle_weight, + weight_by_ray_angle_gamma, + ) + assert torch.isfinite(angle_weight).all() + # multiply the final aggr weights with ray angles + view_sampling_mask = view_sampling_mask.view( + *view_sampling_mask.shape[:2], *([1] * (aggr_weights.ndim - 2)) + ) + aggr_weights = ( + aggr_weights * angle_weight.reshape_as(aggr_weights) * view_sampling_mask + ) + return aggr_weights + + +def _get_ray_dir_dot_prods(camera: CamerasBase, pts: torch.Tensor): + n_cameras = camera.R.shape[0] + pts_batch = pts.shape[0] + + camera_rep, pts_rep = cameras_points_cartesian_product(camera, pts) + + # does not produce nans randomly unlike get_camera_center() below + cam_centers_rep = -torch.bmm( + # pyre-fixme[29]: + # `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self, + # torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor], + # torch.Tensor, torch.nn.modules.module.Module]` is not a function. + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N... + camera_rep.T[:, None], + # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N... + camera_rep.R.permute(0, 2, 1), + ).reshape(-1, *([1] * (pts.ndim - 2)), 3) + # cam_centers_rep = camera_rep.get_camera_center().reshape( + # -1, *([1]*(pts.ndim - 2)), 3 + # ) + + ray_dirs = F.normalize(pts_rep - cam_centers_rep, dim=-1) + # camera_rep = [ pts_rep = [ + # camera[0] pts[0], + # camera[0] pts[1], + # camera[0] ..., + # ... pts[batch_pts-1], + # camera[1] pts[0], + # camera[1] pts[1], + # camera[1] ..., + # ... pts[batch_pts-1], + # ... ..., + # camera[n_cameras-1] pts[0], + # camera[n_cameras-1] pts[1], + # camera[n_cameras-1] ..., + # ... pts[batch_pts-1], + # ] ] + + ray_dirs_reshape = ray_dirs.view(n_cameras, pts_batch, -1, 3) + # [ + # [pts_0 in cam_0, pts_1 in cam_0, ..., pts_m in cam_0], + # [pts_0 in cam_1, pts_1 in cam_1, ..., pts_m in cam_1], + # ... + # [pts_0 in cam_n, pts_1 in cam_n, ..., pts_m in cam_n], + # ] + + ray_dirs_pts = torch.stack([ray_dirs_reshape[i, i] for i in range(pts_batch)]) + ray_dir_dot_prods = (ray_dirs_pts[None] * ray_dirs_reshape).sum( + dim=-1 + ) # pts_batch x n_cameras x n_pts + + return ray_dir_dot_prods.transpose(0, 1) + + +def _get_ray_angle_weights( + camera: CamerasBase, + pts: torch.Tensor, + min_ray_angle_weight: float, + weight_by_ray_angle_gamma: float, +): + ray_dir_dot_prods = _get_ray_dir_dot_prods( + camera, pts + ) # pts_batch x n_cameras x ... x 3 + angle_weight_01 = ray_dir_dot_prods * 0.5 + 0.5 # [-1, 1] to [0, 1] + angle_weight = (angle_weight_01 + min_ray_angle_weight) ** weight_by_ray_angle_gamma + return angle_weight + + +def _avgmaxstd_reduction_function( + x: torch.Tensor, + w: torch.Tensor, + reduction_functions: Sequence[ReductionFunction], + dim: int = 1, +): + """ + Args: + x: Features to aggreagate. Tensor of shape `(batch, n_views, ..., dim)`. + w: Aggregation weights. Tensor of shape `(batch, n_views, ...,)`. + dim: the dimension along which to aggregate. + reduction_functions: The set of reduction functions. + + Returns: + x_aggr: Aggregation of `x` to a tensor of shape `(batch, 1, ..., dim_aggregate)`. + """ + + pooled_features = [] + + mu = None + std = None + + if ReductionFunction.AVG in reduction_functions: + # average pool + mu = _avg_reduction_function(x, w, dim=dim) + pooled_features.append(mu) + + if ReductionFunction.STD in reduction_functions: + # standard-dev pool + std = _std_reduction_function(x, w, dim=dim, mu=mu) + pooled_features.append(std) + + if ReductionFunction.STD_AVG in reduction_functions: + # average-of-standard-dev pool + stdavg = _std_avg_reduction_function(x, w, dim=dim, mu=mu, std=std) + pooled_features.append(stdavg) + + if ReductionFunction.MAX in reduction_functions: + max_ = _max_reduction_function(x, w, dim=dim) + pooled_features.append(max_) + + # cat all results along the feature dimension (the last dim) + x_aggr = torch.cat(pooled_features, dim=-1) + + # zero out features that were all masked out + any_active = (w.max(dim=dim, keepdim=True).values > 1e-4).type_as(x_aggr) + x_aggr = x_aggr * any_active[..., None] + + # some asserts to check that everything was done right + assert torch.isfinite(x_aggr).all() + assert x_aggr.shape[1] == 1 + + return x_aggr + + +def _avg_reduction_function( + x: torch.Tensor, + w: torch.Tensor, + dim: int = 1, +): + mu = wmean(x, w, dim=dim, eps=1e-2) + return mu + + +def _std_reduction_function( + x: torch.Tensor, + w: torch.Tensor, + dim: int = 1, + mu: Optional[torch.Tensor] = None, # pre-computed mean +): + if mu is None: + mu = _avg_reduction_function(x, w, dim=dim) + std = wmean((x - mu) ** 2, w, dim=dim, eps=1e-2).clamp(1e-4).sqrt() + # FIXME: somehow this is extremely heavy in mem? + return std + + +def _std_avg_reduction_function( + x: torch.Tensor, + w: torch.Tensor, + dim: int = 1, + mu: Optional[torch.Tensor] = None, # pre-computed mean + std: Optional[torch.Tensor] = None, # pre-computed std +): + if std is None: + std = _std_reduction_function(x, w, dim=dim, mu=mu) + stdmean = std.mean(dim=-1, keepdim=True) + return stdmean + + +def _max_reduction_function( + x: torch.Tensor, + w: torch.Tensor, + dim: int = 1, + big_M_factor: float = 10.0, +): + big_M = x.max(dim=dim, keepdim=True).values.abs() * big_M_factor + max_ = (x * w - ((1 - w) * big_M)).max(dim=dim, keepdim=True).values + return max_ diff --git a/pytorch3d/implicitron/models/view_pooling/view_sampling.py b/pytorch3d/implicitron/models/view_pooling/view_sampling.py new file mode 100644 index 00000000..eb8413b5 --- /dev/null +++ b/pytorch3d/implicitron/models/view_pooling/view_sampling.py @@ -0,0 +1,291 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional, Tuple, Union + +import torch +from pytorch3d.implicitron.tools.config import Configurable +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.renderer.utils import ndc_grid_sample + + +class ViewSampler(Configurable, torch.nn.Module): + """ + Implements sampling of image-based features at the 2d projections of a set + of 3D points. + + Args: + masked_sampling: If `True`, the `sampled_masks` output of `self.forward` + contains the input `masks` sampled at the 2d projections. Otherwise, + all entries of `sampled_masks` are set to 1. + sampling_mode: Controls the mode of the `torch.nn.functional.grid_sample` + function used to interpolate the sampled feature tensors at the + locations of the 2d projections. + """ + + masked_sampling: bool = False + sampling_mode: str = "bilinear" + + def __post_init__(self): + super().__init__() + + def forward( + self, + *, # force kw args + pts: torch.Tensor, + seq_id_pts: Union[List[int], List[str], torch.LongTensor], + camera: CamerasBase, + seq_id_camera: Union[List[int], List[str], torch.LongTensor], + feats: Dict[str, torch.Tensor], + masks: Optional[torch.Tensor], + **kwargs, + ) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: + """ + Project each point cloud from a batch of point clouds to corresponding + input cameras and sample features at the 2D projection locations. + + Args: + pts: A tensor of shape `[pts_batch x n_pts x 3]` in world coords. + seq_id_pts: LongTensor of shape `[pts_batch]` denoting the ids of the scenes + from which `pts` were extracted, or a list of string names. + camera: 'n_cameras' cameras, each coresponding to a batch element of `feats`. + seq_id_camera: LongTensor of shape `[n_cameras]` denoting the ids of the scenes + corresponding to cameras in `camera`, or a list of string names. + feats: a dict of tensors of per-image features `{feat_i: T_i}`. + Each tensor `T_i` is of shape `[n_cameras x dim_i x H_i x W_i]`. + masks: `[n_cameras x 1 x H x W]`, define valid image regions + for sampling `feats`. + Returns: + sampled_feats: Dict of sampled features `{feat_i: sampled_T_i}`. + Each `sampled_T_i` of shape `[pts_batch, n_cameras, n_pts, dim_i]`. + sampled_masks: A tensor with mask of the sampled features + of shape `(pts_batch, n_cameras, n_pts, 1)`. + """ + + # convert sequence ids to long tensors + seq_id_pts, seq_id_camera = [ + handle_seq_id(seq_id, pts.device) for seq_id in [seq_id_pts, seq_id_camera] + ] + + if self.masked_sampling and masks is None: + raise ValueError( + "Masks have to be provided for `self.masked_sampling==True`" + ) + + # project pts to all cameras and sample feats from the locations of + # the 2D projections + sampled_feats_all_cams, sampled_masks_all_cams = project_points_and_sample( + pts, + feats, + camera, + masks if self.masked_sampling else None, + sampling_mode=self.sampling_mode, + ) + + # generate the mask that invalidates features sampled from + # non-corresponding cameras + camera_pts_mask = (seq_id_camera[None] == seq_id_pts[:, None])[ + ..., None, None + ].to(pts) + + # mask the sampled features and masks + sampled_feats = { + k: f * camera_pts_mask for k, f in sampled_feats_all_cams.items() + } + sampled_masks = sampled_masks_all_cams * camera_pts_mask + + return sampled_feats, sampled_masks + + +def project_points_and_sample( + pts: torch.Tensor, + feats: Dict[str, torch.Tensor], + camera: CamerasBase, + masks: Optional[torch.Tensor], + eps: float = 1e-2, + sampling_mode: str = "bilinear", +) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: + """ + Project each point cloud from a batch of point clouds to all input cameras + and sample features at the 2D projection locations. + + Args: + pts: `(pts_batch, n_pts, 3)` tensor containing a batch of 3D point clouds. + feats: A dict `{feat_i: feat_T_i}` of features to sample, + where each `feat_T_i` is a tensor of shape + `(n_cameras, feat_i_dim, feat_i_H, feat_i_W)` + of `feat_i_dim`-dimensional features extracted from `n_cameras` + source views. + camera: A batch of `n_cameras` cameras corresponding to their feature + tensors `feat_T_i` from `feats`. + masks: A tensor of shape `(n_cameras, 1, mask_H, mask_W)` denoting + valid locations for sampling. + eps: A small constant controlling the minimum depth of projections + of `pts` to avoid divisons by zero in the projection operation. + sampling_mode: Sampling mode of the grid sampler. + + Returns: + sampled_feats: Dict of sampled features `{feat_i: sampled_T_i}`. + Each `sampled_T_i` is of shape + `(pts_batch, n_cameras, n_pts, feat_i_dim)`. + sampled_masks: A tensor with the mask of the sampled features + of shape `(pts_batch, n_cameras, n_pts, 1)`. + If `masks` is `None`, the returned `sampled_masks` will be + filled with 1s. + """ + + n_cameras = camera.R.shape[0] + pts_batch = pts.shape[0] + n_pts = pts.shape[1:-1] + + camera_rep, pts_rep = cameras_points_cartesian_product(camera, pts) + + # The eps here is super-important to avoid NaNs in backprop! + proj_rep = camera_rep.transform_points( + pts_rep.reshape(n_cameras * pts_batch, -1, 3), eps=eps + )[..., :2] + # [ pts1 in cam1, pts2 in cam1, pts3 in cam1, + # pts1 in cam2, pts2 in cam2, pts3 in cam2, + # pts1 in cam3, pts2 in cam3, pts3 in cam3 ] + + # reshape for the grid sampler + sampling_grid_ndc = proj_rep.view(n_cameras, pts_batch, -1, 2) + # [ [pts1 in cam1, pts2 in cam1, pts3 in cam1], + # [pts1 in cam2, pts2 in cam2, pts3 in cam2], + # [pts1 in cam3, pts2 in cam3, pts3 in cam3] ] + # n_cameras x pts_batch x n_pts x 2 + + # sample both feats + feats_sampled = { + k: ndc_grid_sample( + f, + sampling_grid_ndc, + mode=sampling_mode, + align_corners=False, + ) + .permute(2, 0, 3, 1) + .reshape(pts_batch, n_cameras, *n_pts, -1) + for k, f in feats.items() + } # {k: pts_batch x n_cameras x *n_pts x dim} for each feat type "k" + + if masks is not None: + # sample masks + masks_sampled = ( + ndc_grid_sample( + masks, + sampling_grid_ndc, + mode=sampling_mode, + align_corners=False, + ) + .permute(2, 0, 3, 1) + .reshape(pts_batch, n_cameras, *n_pts, 1) + ) + else: + masks_sampled = sampling_grid_ndc.new_ones(pts_batch, n_cameras, *n_pts, 1) + + return feats_sampled, masks_sampled + + +def handle_seq_id( + seq_id: Union[torch.LongTensor, List[str], List[int]], + device, +) -> torch.LongTensor: + """ + Converts the input sequence id to a LongTensor. + + Args: + seq_id: A sequence of sequence ids. + device: The target device of the output. + Returns + long_seq_id: `seq_id` converted to a `LongTensor` and moved to `device`. + """ + if not torch.is_tensor(seq_id): + if isinstance(seq_id[0], str): + seq_id = [hash(s) for s in seq_id] + seq_id = torch.tensor(seq_id, dtype=torch.long, device=device) + return seq_id.to(device) + + +def cameras_points_cartesian_product( + camera: CamerasBase, pts: torch.Tensor +) -> Tuple[CamerasBase, torch.Tensor]: + """ + Generates all pairs of pairs of elements from 'camera' and 'pts' and returns + `camera_rep` and `pts_rep` such that: + ``` + camera_rep = [ pts_rep = [ + camera[0] pts[0], + camera[0] pts[1], + camera[0] ..., + ... pts[batch_pts-1], + camera[1] pts[0], + camera[1] pts[1], + camera[1] ..., + ... pts[batch_pts-1], + ... ..., + camera[n_cameras-1] pts[0], + camera[n_cameras-1] pts[1], + camera[n_cameras-1] ..., + ... pts[batch_pts-1], + ] ] + ``` + + Args: + camera: A batch of `n_cameras` cameras. + pts: A batch of `batch_pts` points of shape `(batch_pts, ..., dim)` + + Returns: + camera_rep: A batch of batch_pts*n_cameras cameras such that: + ``` + camera_rep = [ + camera[0] + camera[0] + camera[0] + ... + camera[1] + camera[1] + camera[1] + ... + ... + camera[n_cameras-1] + camera[n_cameras-1] + camera[n_cameras-1] + ] + ``` + + pts_rep: Repeated `pts` of shape `(batch_pts*n_cameras, ..., dim)`, + such that: + ``` + pts_rep = [ + pts[0], + pts[1], + ..., + pts[batch_pts-1], + pts[0], + pts[1], + ..., + pts[batch_pts-1], + ..., + pts[0], + pts[1], + ..., + pts[batch_pts-1], + ] + ``` + """ + n_cameras = camera.R.shape[0] + batch_pts = pts.shape[0] + pts_rep = pts.repeat(n_cameras, *[1 for _ in pts.shape[1:]]) + idx_cams = ( + torch.arange(n_cameras)[:, None] + .expand( + n_cameras, + batch_pts, + ) + .reshape(batch_pts * n_cameras) + ) + camera_rep = camera[idx_cams] + return camera_rep, pts_rep diff --git a/pytorch3d/implicitron/third_party/hyperlayers.py b/pytorch3d/implicitron/third_party/hyperlayers.py new file mode 100644 index 00000000..640a14ed --- /dev/null +++ b/pytorch3d/implicitron/third_party/hyperlayers.py @@ -0,0 +1,254 @@ +# a copy-paste from https://github.com/vsitzmann/scene-representation-networks/blob/master/hyperlayers.py +# fmt: off +# flake8: noqa +'''Pytorch implementations of hyper-network modules. +isort:skip_file +''' +import functools +import torch +import torch.nn as nn + +from . import pytorch_prototyping + + +def partialclass(cls, *args, **kwds): + class NewCls(cls): + __init__ = functools.partialmethod(cls.__init__, *args, **kwds) + + return NewCls + + +class LookupLayer(nn.Module): + def __init__(self, in_ch, out_ch, num_objects): + super().__init__() + + self.out_ch = out_ch + self.lookup_lin = LookupLinear(in_ch, out_ch, num_objects=num_objects) + self.norm_nl = nn.Sequential( + nn.LayerNorm([self.out_ch], elementwise_affine=False), nn.ReLU(inplace=True) + ) + + def forward(self, obj_idx): + net = nn.Sequential(self.lookup_lin(obj_idx), self.norm_nl) + return net + + +class LookupFC(nn.Module): + def __init__( + self, + hidden_ch, + num_hidden_layers, + num_objects, + in_ch, + out_ch, + outermost_linear=False, + ): + super().__init__() + self.layers = nn.ModuleList() + self.layers.append( + LookupLayer(in_ch=in_ch, out_ch=hidden_ch, num_objects=num_objects) + ) + + for i in range(num_hidden_layers): + self.layers.append( + LookupLayer(in_ch=hidden_ch, out_ch=hidden_ch, num_objects=num_objects) + ) + + if outermost_linear: + self.layers.append( + LookupLinear(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects) + ) + else: + self.layers.append( + LookupLayer(in_ch=hidden_ch, out_ch=out_ch, num_objects=num_objects) + ) + + def forward(self, obj_idx): + net = [] + for i in range(len(self.layers)): + net.append(self.layers[i](obj_idx)) + + return nn.Sequential(*net) + + +class LookupLinear(nn.Module): + def __init__(self, in_ch, out_ch, num_objects): + super().__init__() + self.in_ch = in_ch + self.out_ch = out_ch + + self.hypo_params = nn.Embedding(num_objects, in_ch * out_ch + out_ch) + + for i in range(num_objects): + nn.init.kaiming_normal_( + self.hypo_params.weight.data[i, : self.in_ch * self.out_ch].view( + self.out_ch, self.in_ch + ), + a=0.0, + nonlinearity="relu", + mode="fan_in", + ) + self.hypo_params.weight.data[i, self.in_ch * self.out_ch :].fill_(0.0) + + def forward(self, obj_idx): + hypo_params = self.hypo_params(obj_idx) + + # Indices explicit to catch erros in shape of output layer + weights = hypo_params[..., : self.in_ch * self.out_ch] + biases = hypo_params[ + ..., self.in_ch * self.out_ch : (self.in_ch * self.out_ch) + self.out_ch + ] + + biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch) + weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch) + + return BatchLinear(weights=weights, biases=biases) + + +class HyperLayer(nn.Module): + """A hypernetwork that predicts a single Dense Layer, including LayerNorm and a ReLU.""" + + def __init__( + self, in_ch, out_ch, hyper_in_ch, hyper_num_hidden_layers, hyper_hidden_ch + ): + super().__init__() + + self.hyper_linear = HyperLinear( + in_ch=in_ch, + out_ch=out_ch, + hyper_in_ch=hyper_in_ch, + hyper_num_hidden_layers=hyper_num_hidden_layers, + hyper_hidden_ch=hyper_hidden_ch, + ) + self.norm_nl = nn.Sequential( + nn.LayerNorm([out_ch], elementwise_affine=False), nn.ReLU(inplace=True) + ) + + def forward(self, hyper_input): + """ + :param hyper_input: input to hypernetwork. + :return: nn.Module; predicted fully connected network. + """ + return nn.Sequential(self.hyper_linear(hyper_input), self.norm_nl) + + +class HyperFC(nn.Module): + """Builds a hypernetwork that predicts a fully connected neural network.""" + + def __init__( + self, + hyper_in_ch, + hyper_num_hidden_layers, + hyper_hidden_ch, + hidden_ch, + num_hidden_layers, + in_ch, + out_ch, + outermost_linear=False, + ): + super().__init__() + + PreconfHyperLinear = partialclass( + HyperLinear, + hyper_in_ch=hyper_in_ch, + hyper_num_hidden_layers=hyper_num_hidden_layers, + hyper_hidden_ch=hyper_hidden_ch, + ) + PreconfHyperLayer = partialclass( + HyperLayer, + hyper_in_ch=hyper_in_ch, + hyper_num_hidden_layers=hyper_num_hidden_layers, + hyper_hidden_ch=hyper_hidden_ch, + ) + + self.layers = nn.ModuleList() + self.layers.append(PreconfHyperLayer(in_ch=in_ch, out_ch=hidden_ch)) + + for i in range(num_hidden_layers): + self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=hidden_ch)) + + if outermost_linear: + self.layers.append(PreconfHyperLinear(in_ch=hidden_ch, out_ch=out_ch)) + else: + self.layers.append(PreconfHyperLayer(in_ch=hidden_ch, out_ch=out_ch)) + + def forward(self, hyper_input): + """ + :param hyper_input: Input to hypernetwork. + :return: nn.Module; Predicted fully connected neural network. + """ + net = [] + for i in range(len(self.layers)): + net.append(self.layers[i](hyper_input)) + + return nn.Sequential(*net) + + +class BatchLinear(nn.Module): + def __init__(self, weights, biases): + """Implements a batch linear layer. + + :param weights: Shape: (batch, out_ch, in_ch) + :param biases: Shape: (batch, 1, out_ch) + """ + super().__init__() + + self.weights = weights + self.biases = biases + + def __repr__(self): + return "BatchLinear(in_ch=%d, out_ch=%d)" % ( + self.weights.shape[-1], + self.weights.shape[-2], + ) + + def forward(self, input): + output = input.matmul( + self.weights.permute( + *[i for i in range(len(self.weights.shape) - 2)], -1, -2 + ) + ) + output += self.biases + return output + + +def last_hyper_layer_init(m) -> None: + if type(m) == nn.Linear: + nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity="relu", mode="fan_in") + # pyre-fixme[41]: `data` cannot be reassigned. It is a read-only property. + m.weight.data *= 1e-1 + + +class HyperLinear(nn.Module): + """A hypernetwork that predicts a single linear layer (weights & biases).""" + + def __init__( + self, in_ch, out_ch, hyper_in_ch, hyper_num_hidden_layers, hyper_hidden_ch + ): + + super().__init__() + self.in_ch = in_ch + self.out_ch = out_ch + + self.hypo_params = pytorch_prototyping.FCBlock( + in_features=hyper_in_ch, + hidden_ch=hyper_hidden_ch, + num_hidden_layers=hyper_num_hidden_layers, + out_features=(in_ch * out_ch) + out_ch, + outermost_linear=True, + ) + self.hypo_params[-1].apply(last_hyper_layer_init) + + def forward(self, hyper_input): + hypo_params = self.hypo_params(hyper_input) + + # Indices explicit to catch erros in shape of output layer + weights = hypo_params[..., : self.in_ch * self.out_ch] + biases = hypo_params[ + ..., self.in_ch * self.out_ch : (self.in_ch * self.out_ch) + self.out_ch + ] + + biases = biases.view(*(biases.size()[:-1]), 1, self.out_ch) + weights = weights.view(*(weights.size()[:-1]), self.out_ch, self.in_ch) + + return BatchLinear(weights=weights, biases=biases) diff --git a/pytorch3d/implicitron/third_party/pytorch_prototyping.py b/pytorch3d/implicitron/third_party/pytorch_prototyping.py new file mode 100644 index 00000000..0823b953 --- /dev/null +++ b/pytorch3d/implicitron/third_party/pytorch_prototyping.py @@ -0,0 +1,772 @@ +# a copy-paste from https://raw.githubusercontent.com/vsitzmann/pytorch_prototyping/10f49b1e7df38a58fd78451eac91d7ac1a21df64/pytorch_prototyping.py +# fmt: off +# flake8: noqa +'''A number of custom pytorch modules with sane defaults that I find useful for model prototyping. +isort:skip_file +''' +import torch +import torch.nn as nn +import torchvision.utils +from torch.nn import functional as F + + +class FCLayer(nn.Module): + def __init__(self, in_features, out_features): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_features, out_features), + nn.LayerNorm([out_features]), + nn.ReLU(inplace=True), + ) + + def forward(self, input): + return self.net(input) + + +# From https://gist.github.com/wassname/ecd2dac6fc8f9918149853d17e3abf02 +class LayerNormConv2d(nn.Module): + def __init__(self, num_features, eps=1e-5, affine=True): + super().__init__() + self.num_features = num_features + self.affine = affine + self.eps = eps + + if self.affine: + self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_()) + self.beta = nn.Parameter(torch.zeros(num_features)) + + def forward(self, x): + shape = [-1] + [1] * (x.dim() - 1) + mean = x.view(x.size(0), -1).mean(1).view(*shape) + std = x.view(x.size(0), -1).std(1).view(*shape) + + y = (x - mean) / (std + self.eps) + if self.affine: + shape = [1, -1] + [1] * (x.dim() - 2) + y = self.gamma.view(*shape) * y + self.beta.view(*shape) + return y + + +class FCBlock(nn.Module): + def __init__( + self, + hidden_ch, + num_hidden_layers, + in_features, + out_features, + outermost_linear=False, + ): + super().__init__() + + self.net = [] + self.net.append(FCLayer(in_features=in_features, out_features=hidden_ch)) + + for i in range(num_hidden_layers): + self.net.append(FCLayer(in_features=hidden_ch, out_features=hidden_ch)) + + if outermost_linear: + self.net.append(nn.Linear(in_features=hidden_ch, out_features=out_features)) + else: + self.net.append(FCLayer(in_features=hidden_ch, out_features=out_features)) + + self.net = nn.Sequential(*self.net) + self.net.apply(self.init_weights) + + def __getitem__(self, item): + return self.net[item] + + def init_weights(self, m): + if type(m) == nn.Linear: + nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity="relu", mode="fan_in") + + def forward(self, input): + return self.net(input) + + +class DownBlock3D(nn.Module): + """A 3D convolutional downsampling block.""" + + def __init__(self, in_channels, out_channels, norm=nn.BatchNorm3d): + super().__init__() + + self.net = [ + nn.ReplicationPad3d(1), + nn.Conv3d( + in_channels, + out_channels, + kernel_size=4, + padding=0, + stride=2, + bias=False if norm is not None else True, + ), + ] + + if norm is not None: + self.net += [norm(out_channels, affine=True)] + + self.net += [nn.LeakyReLU(0.2, True)] + self.net = nn.Sequential(*self.net) + + def forward(self, x): + return self.net(x) + + +class UpBlock3D(nn.Module): + """A 3D convolutional upsampling block.""" + + def __init__(self, in_channels, out_channels, norm=nn.BatchNorm3d): + super().__init__() + + self.net = [ + nn.ConvTranspose3d( + in_channels, + out_channels, + kernel_size=4, + stride=2, + padding=1, + bias=False if norm is not None else True, + ), + ] + + if norm is not None: + self.net += [norm(out_channels, affine=True)] + + self.net += [nn.ReLU(True)] + self.net = nn.Sequential(*self.net) + + def forward(self, x, skipped=None): + if skipped is not None: + input = torch.cat([skipped, x], dim=1) + else: + input = x + return self.net(input) + + +class Conv3dSame(torch.nn.Module): + """3D convolution that pads to keep spatial dimensions equal. + Cannot deal with stride. Only quadratic kernels (=scalar kernel_size). + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + bias=True, + padding_layer=nn.ReplicationPad3d, + ): + """ + :param in_channels: Number of input channels + :param out_channels: Number of output channels + :param kernel_size: Scalar. Spatial dimensions of kernel (only quadratic kernels supported). + :param bias: Whether or not to use bias. + :param padding_layer: Which padding to use. Default is reflection padding. + """ + super().__init__() + ka = kernel_size // 2 + kb = ka - 1 if kernel_size % 2 == 0 else ka + self.net = nn.Sequential( + padding_layer((ka, kb, ka, kb, ka, kb)), + nn.Conv3d(in_channels, out_channels, kernel_size, bias=bias, stride=1), + ) + + def forward(self, x): + return self.net(x) + + +class Conv2dSame(torch.nn.Module): + """2D convolution that pads to keep spatial dimensions equal. + Cannot deal with stride. Only quadratic kernels (=scalar kernel_size). + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size, + bias=True, + padding_layer=nn.ReflectionPad2d, + ): + """ + :param in_channels: Number of input channels + :param out_channels: Number of output channels + :param kernel_size: Scalar. Spatial dimensions of kernel (only quadratic kernels supported). + :param bias: Whether or not to use bias. + :param padding_layer: Which padding to use. Default is reflection padding. + """ + super().__init__() + ka = kernel_size // 2 + kb = ka - 1 if kernel_size % 2 == 0 else ka + self.net = nn.Sequential( + padding_layer((ka, kb, ka, kb)), + nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias, stride=1), + ) + + self.weight = self.net[1].weight + self.bias = self.net[1].bias + + def forward(self, x): + return self.net(x) + + +class UpBlock(nn.Module): + """A 2d-conv upsampling block with a variety of options for upsampling, and following best practices / with + reasonable defaults. (LeakyReLU, kernel size multiple of stride) + """ + + def __init__( + self, + in_channels, + out_channels, + post_conv=True, + use_dropout=False, + dropout_prob=0.1, + norm=nn.BatchNorm2d, + upsampling_mode="transpose", + ): + """ + :param in_channels: Number of input channels + :param out_channels: Number of output channels + :param post_conv: Whether to have another convolutional layer after the upsampling layer. + :param use_dropout: bool. Whether to use dropout or not. + :param dropout_prob: Float. The dropout probability (if use_dropout is True) + :param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity. + :param upsampling_mode: Which upsampling mode: + transpose: Upsampling with stride-2, kernel size 4 transpose convolutions. + bilinear: Feature map is upsampled with bilinear upsampling, then a conv layer. + nearest: Feature map is upsampled with nearest neighbor upsampling, then a conv layer. + shuffle: Feature map is upsampled with pixel shuffling, then a conv layer. + """ + super().__init__() + + net = list() + + if upsampling_mode == "transpose": + net += [ + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=4, + stride=2, + padding=1, + bias=True if norm is None else False, + ) + ] + elif upsampling_mode == "bilinear": + net += [nn.UpsamplingBilinear2d(scale_factor=2)] + net += [ + Conv2dSame( + in_channels, + out_channels, + kernel_size=3, + bias=True if norm is None else False, + ) + ] + elif upsampling_mode == "nearest": + net += [nn.UpsamplingNearest2d(scale_factor=2)] + net += [ + Conv2dSame( + in_channels, + out_channels, + kernel_size=3, + bias=True if norm is None else False, + ) + ] + elif upsampling_mode == "shuffle": + net += [nn.PixelShuffle(upscale_factor=2)] + net += [ + Conv2dSame( + in_channels // 4, + out_channels, + kernel_size=3, + bias=True if norm is None else False, + ) + ] + else: + raise ValueError("Unknown upsampling mode!") + + if norm is not None: + net += [norm(out_channels, affine=True)] + + net += [nn.ReLU(True)] + + if use_dropout: + net += [nn.Dropout2d(dropout_prob, False)] + + if post_conv: + net += [ + Conv2dSame( + out_channels, + out_channels, + kernel_size=3, + bias=True if norm is None else False, + ) + ] + + if norm is not None: + net += [norm(out_channels, affine=True)] + + net += [nn.ReLU(True)] + + if use_dropout: + net += [nn.Dropout2d(0.1, False)] + + self.net = nn.Sequential(*net) + + def forward(self, x, skipped=None): + if skipped is not None: + input = torch.cat([skipped, x], dim=1) + else: + input = x + return self.net(input) + + +class DownBlock(nn.Module): + """A 2D-conv downsampling block following best practices / with reasonable defaults + (LeakyReLU, kernel size multiple of stride) + """ + + def __init__( + self, + in_channels, + out_channels, + prep_conv=True, + middle_channels=None, + use_dropout=False, + dropout_prob=0.1, + norm=nn.BatchNorm2d, + ): + """ + :param in_channels: Number of input channels + :param out_channels: Number of output channels + :param prep_conv: Whether to have another convolutional layer before the downsampling layer. + :param middle_channels: If prep_conv is true, this sets the number of channels between the prep and downsampling + convs. + :param use_dropout: bool. Whether to use dropout or not. + :param dropout_prob: Float. The dropout probability (if use_dropout is True) + :param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity. + """ + super().__init__() + + if middle_channels is None: + middle_channels = in_channels + + net = list() + + if prep_conv: + net += [ + nn.ReflectionPad2d(1), + nn.Conv2d( + in_channels, + middle_channels, + kernel_size=3, + padding=0, + stride=1, + bias=True if norm is None else False, + ), + ] + + if norm is not None: + net += [norm(middle_channels, affine=True)] + + net += [nn.LeakyReLU(0.2, True)] + + if use_dropout: + net += [nn.Dropout2d(dropout_prob, False)] + + net += [ + nn.ReflectionPad2d(1), + nn.Conv2d( + middle_channels, + out_channels, + kernel_size=4, + padding=0, + stride=2, + bias=True if norm is None else False, + ), + ] + + if norm is not None: + net += [norm(out_channels, affine=True)] + + net += [nn.LeakyReLU(0.2, True)] + + if use_dropout: + net += [nn.Dropout2d(dropout_prob, False)] + + self.net = nn.Sequential(*net) + + def forward(self, x): + return self.net(x) + + +class Unet3d(nn.Module): + """A 3d-Unet implementation with sane defaults.""" + + def __init__( + self, + in_channels, + out_channels, + nf0, + num_down, + max_channels, + norm=nn.BatchNorm3d, + outermost_linear=False, + ): + """ + :param in_channels: Number of input channels + :param out_channels: Number of output channels + :param nf0: Number of features at highest level of U-Net + :param num_down: Number of downsampling stages. + :param max_channels: Maximum number of channels (channels multiply by 2 with every downsampling stage) + :param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity. + :param outermost_linear: Whether the output layer should be a linear layer or a nonlinear one. + """ + super().__init__() + + assert num_down > 0, "Need at least one downsampling layer in UNet3d." + + # Define the in block + self.in_layer = [Conv3dSame(in_channels, nf0, kernel_size=3, bias=False)] + + if norm is not None: + self.in_layer += [norm(nf0, affine=True)] + + self.in_layer += [nn.LeakyReLU(0.2, True)] + self.in_layer = nn.Sequential(*self.in_layer) + + # Define the center UNet block. The feature map has height and width 1 --> no batchnorm. + self.unet_block = UnetSkipConnectionBlock3d( + int(min(2 ** (num_down - 1) * nf0, max_channels)), + int(min(2 ** (num_down - 1) * nf0, max_channels)), + norm=None, + ) + for i in list(range(0, num_down - 1))[::-1]: + self.unet_block = UnetSkipConnectionBlock3d( + int(min(2 ** i * nf0, max_channels)), + int(min(2 ** (i + 1) * nf0, max_channels)), + submodule=self.unet_block, + norm=norm, + ) + + # Define the out layer. Each unet block concatenates its inputs with its outputs - so the output layer + # automatically receives the output of the in_layer and the output of the last unet layer. + self.out_layer = [ + Conv3dSame(2 * nf0, out_channels, kernel_size=3, bias=outermost_linear) + ] + + if not outermost_linear: + if norm is not None: + self.out_layer += [norm(out_channels, affine=True)] + self.out_layer += [nn.ReLU(True)] + self.out_layer = nn.Sequential(*self.out_layer) + + def forward(self, x): + in_layer = self.in_layer(x) + unet = self.unet_block(in_layer) + out_layer = self.out_layer(unet) + return out_layer + + +class UnetSkipConnectionBlock3d(nn.Module): + """Helper class for building a 3D unet.""" + + def __init__(self, outer_nc, inner_nc, norm=nn.BatchNorm3d, submodule=None): + super().__init__() + + if submodule is None: + model = [ + DownBlock3D(outer_nc, inner_nc, norm=norm), + UpBlock3D(inner_nc, outer_nc, norm=norm), + ] + else: + model = [ + DownBlock3D(outer_nc, inner_nc, norm=norm), + submodule, + UpBlock3D(2 * inner_nc, outer_nc, norm=norm), + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + forward_passed = self.model(x) + return torch.cat([x, forward_passed], 1) + + +class UnetSkipConnectionBlock(nn.Module): + """Helper class for building a 2D unet.""" + + def __init__( + self, + outer_nc, + inner_nc, + upsampling_mode, + norm=nn.BatchNorm2d, + submodule=None, + use_dropout=False, + dropout_prob=0.1, + ): + super().__init__() + + if submodule is None: + model = [ + DownBlock( + outer_nc, + inner_nc, + use_dropout=use_dropout, + dropout_prob=dropout_prob, + norm=norm, + ), + UpBlock( + inner_nc, + outer_nc, + use_dropout=use_dropout, + dropout_prob=dropout_prob, + norm=norm, + upsampling_mode=upsampling_mode, + ), + ] + else: + model = [ + DownBlock( + outer_nc, + inner_nc, + use_dropout=use_dropout, + dropout_prob=dropout_prob, + norm=norm, + ), + submodule, + UpBlock( + 2 * inner_nc, + outer_nc, + use_dropout=use_dropout, + dropout_prob=dropout_prob, + norm=norm, + upsampling_mode=upsampling_mode, + ), + ] + + self.model = nn.Sequential(*model) + + def forward(self, x): + forward_passed = self.model(x) + return torch.cat([x, forward_passed], 1) + + +class Unet(nn.Module): + """A 2d-Unet implementation with sane defaults.""" + + def __init__( + self, + in_channels, + out_channels, + nf0, + num_down, + max_channels, + use_dropout, + upsampling_mode="transpose", + dropout_prob=0.1, + norm=nn.BatchNorm2d, + outermost_linear=False, + ): + """ + :param in_channels: Number of input channels + :param out_channels: Number of output channels + :param nf0: Number of features at highest level of U-Net + :param num_down: Number of downsampling stages. + :param max_channels: Maximum number of channels (channels multiply by 2 with every downsampling stage) + :param use_dropout: Whether to use dropout or no. + :param dropout_prob: Dropout probability if use_dropout=True. + :param upsampling_mode: Which type of upsampling should be used. See "UpBlock" for documentation. + :param norm: Which norm to use. If None, no norm is used. Default is Batchnorm with affinity. + :param outermost_linear: Whether the output layer should be a linear layer or a nonlinear one. + """ + super().__init__() + + assert num_down > 0, "Need at least one downsampling layer in UNet." + + # Define the in block + self.in_layer = [ + Conv2dSame( + in_channels, nf0, kernel_size=3, bias=True if norm is None else False + ) + ] + if norm is not None: + self.in_layer += [norm(nf0, affine=True)] + self.in_layer += [nn.LeakyReLU(0.2, True)] + + if use_dropout: + self.in_layer += [nn.Dropout2d(dropout_prob)] + self.in_layer = nn.Sequential(*self.in_layer) + + # Define the center UNet block + self.unet_block = UnetSkipConnectionBlock( + min(2 ** (num_down - 1) * nf0, max_channels), + min(2 ** (num_down - 1) * nf0, max_channels), + use_dropout=use_dropout, + dropout_prob=dropout_prob, + norm=None, # Innermost has no norm (spatial dimension 1) + upsampling_mode=upsampling_mode, + ) + + for i in list(range(0, num_down - 1))[::-1]: + self.unet_block = UnetSkipConnectionBlock( + min(2 ** i * nf0, max_channels), + min(2 ** (i + 1) * nf0, max_channels), + use_dropout=use_dropout, + dropout_prob=dropout_prob, + submodule=self.unet_block, + norm=norm, + upsampling_mode=upsampling_mode, + ) + + # Define the out layer. Each unet block concatenates its inputs with its outputs - so the output layer + # automatically receives the output of the in_layer and the output of the last unet layer. + self.out_layer = [ + Conv2dSame( + 2 * nf0, + out_channels, + kernel_size=3, + bias=outermost_linear or (norm is None), + ) + ] + + if not outermost_linear: + if norm is not None: + self.out_layer += [norm(out_channels, affine=True)] + self.out_layer += [nn.ReLU(True)] + + if use_dropout: + self.out_layer += [nn.Dropout2d(dropout_prob)] + self.out_layer = nn.Sequential(*self.out_layer) + + self.out_layer_weight = self.out_layer[0].weight + + def forward(self, x): + in_layer = self.in_layer(x) + unet = self.unet_block(in_layer) + out_layer = self.out_layer(unet) + return out_layer + + +class Identity(nn.Module): + """Helper module to allow Downsampling and Upsampling nets to default to identity if they receive an empty list.""" + + def __init__(self): + super().__init__() + + def forward(self, input): + return input + + +class DownsamplingNet(nn.Module): + """A subnetwork that downsamples a 2D feature map with strided convolutions.""" + + def __init__( + self, + per_layer_out_ch, + in_channels, + use_dropout, + dropout_prob=0.1, + last_layer_one=False, + norm=nn.BatchNorm2d, + ): + """ + :param per_layer_out_ch: python list of integers. Defines the number of output channels per layer. Length of + list defines number of downsampling steps (each step dowsamples by factor of 2.) + :param in_channels: Number of input channels. + :param use_dropout: Whether or not to use dropout. + :param dropout_prob: Dropout probability. + :param last_layer_one: Whether the output of the last layer will have a spatial size of 1. In that case, + the last layer will not have batchnorm, else, it will. + :param norm: Which norm to use. Defaults to BatchNorm. + """ + super().__init__() + + if not len(per_layer_out_ch): + self.downs = Identity() + else: + self.downs = list() + self.downs.append( + DownBlock( + in_channels, + per_layer_out_ch[0], + use_dropout=use_dropout, + dropout_prob=dropout_prob, + middle_channels=per_layer_out_ch[0], + norm=norm, + ) + ) + for i in range(0, len(per_layer_out_ch) - 1): + if last_layer_one and (i == len(per_layer_out_ch) - 2): + norm = None + self.downs.append( + DownBlock( + per_layer_out_ch[i], + per_layer_out_ch[i + 1], + dropout_prob=dropout_prob, + use_dropout=use_dropout, + norm=norm, + ) + ) + self.downs = nn.Sequential(*self.downs) + + def forward(self, input): + return self.downs(input) + + +class UpsamplingNet(nn.Module): + """A subnetwork that upsamples a 2D feature map with a variety of upsampling options.""" + + def __init__( + self, + per_layer_out_ch, + in_channels, + upsampling_mode, + use_dropout, + dropout_prob=0.1, + first_layer_one=False, + norm=nn.BatchNorm2d, + ): + """ + :param per_layer_out_ch: python list of integers. Defines the number of output channels per layer. Length of + list defines number of upsampling steps (each step upsamples by factor of 2.) + :param in_channels: Number of input channels. + :param upsampling_mode: Mode of upsampling. For documentation, see class "UpBlock" + :param use_dropout: Whether or not to use dropout. + :param dropout_prob: Dropout probability. + :param first_layer_one: Whether the input to the last layer will have a spatial size of 1. In that case, + the first layer will not have a norm, else, it will. + :param norm: Which norm to use. Defaults to BatchNorm. + """ + super().__init__() + + if not len(per_layer_out_ch): + self.ups = Identity() + else: + self.ups = list() + self.ups.append( + UpBlock( + in_channels, + per_layer_out_ch[0], + use_dropout=use_dropout, + dropout_prob=dropout_prob, + norm=None if first_layer_one else norm, + upsampling_mode=upsampling_mode, + ) + ) + for i in range(0, len(per_layer_out_ch) - 1): + self.ups.append( + UpBlock( + per_layer_out_ch[i], + per_layer_out_ch[i + 1], + use_dropout=use_dropout, + dropout_prob=dropout_prob, + norm=norm, + upsampling_mode=upsampling_mode, + ) + ) + self.ups = nn.Sequential(*self.ups) + + def forward(self, input): + return self.ups(input) diff --git a/pytorch3d/implicitron/tools/__init__.py b/pytorch3d/implicitron/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pytorch3d/implicitron/tools/camera_utils.py b/pytorch3d/implicitron/tools/camera_utils.py new file mode 100644 index 00000000..3148adf9 --- /dev/null +++ b/pytorch3d/implicitron/tools/camera_utils.py @@ -0,0 +1,142 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +# TODO: all this potentially goes to PyTorch3D + +import math +from typing import Tuple + +import pytorch3d as pt3d +import torch +from pytorch3d.renderer.cameras import CamerasBase + + +def jitter_extrinsics( + R: torch.Tensor, + T: torch.Tensor, + max_angle: float = (math.pi * 2.0), + translation_std: float = 1.0, + scale_std: float = 0.3, +): + """ + Jitter the extrinsic camera parameters `R` and `T` with a random similarity + transformation. The transformation rotates by a random angle between [0, max_angle]; + scales by a random factor exp(N(0, scale_std)), where N(0, scale_std) is + a random sample from a normal distrubtion with zero mean and variance scale_std; + and translates by a 3D offset sampled from N(0, translation_std). + """ + assert all(x >= 0.0 for x in (max_angle, translation_std, scale_std)) + N = R.shape[0] + R_jit = pt3d.transforms.random_rotations(1, device=R.device) + R_jit = pt3d.transforms.so3_exponential_map( + pt3d.transforms.so3_log_map(R_jit) * max_angle + ) + T_jit = torch.randn_like(R_jit[:1, :, 0]) * translation_std + rigid_transform = pt3d.ops.eyes(dim=4, N=N, device=R.device) + rigid_transform[:, :3, :3] = R_jit.expand(N, 3, 3) + rigid_transform[:, 3, :3] = T_jit.expand(N, 3) + scale_jit = torch.exp(torch.randn_like(T_jit[:, 0]) * scale_std).expand(N) + return apply_camera_alignment(R, T, rigid_transform, scale_jit) + + +def apply_camera_alignment( + R: torch.Tensor, + T: torch.Tensor, + rigid_transform: torch.Tensor, + scale: torch.Tensor, +): + """ + Args: + R: Camera rotation matrix of shape (N, 3, 3). + T: Camera translation of shape (N, 3). + rigid_transform: A tensor of shape (N, 4, 4) representing a batch of + N 4x4 tensors that map the scene pointcloud from misaligned coords + to the aligned space. + scale: A list of N scaling factors. A tensor of shape (N,) + + Returns: + R_aligned: The aligned rotations R. + T_aligned: The aligned translations T. + """ + R_rigid = rigid_transform[:, :3, :3] + T_rigid = rigid_transform[:, 3:, :3] + R_aligned = R_rigid.permute(0, 2, 1).bmm(R) + T_aligned = scale[:, None] * (T - (T_rigid @ R_aligned)[:, 0]) + return R_aligned, T_aligned + + +def get_min_max_depth_bounds(cameras, scene_center, scene_extent): + """ + Estimate near/far depth plane as: + near = dist(cam_center, self.scene_center) - self.scene_extent + far = dist(cam_center, self.scene_center) + self.scene_extent + """ + cam_center = cameras.get_camera_center() + center_dist = ( + ((cam_center - scene_center.to(cameras.R)[None]) ** 2) + .sum(dim=-1) + .clamp(0.001) + .sqrt() + ) + center_dist = center_dist.clamp(scene_extent + 1e-3) + min_depth = center_dist - scene_extent + max_depth = center_dist + scene_extent + return min_depth, max_depth + + +def volumetric_camera_overlaps( + cameras: CamerasBase, + scene_extent: float = 8.0, + scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0), + resol: int = 16, + weigh_by_ray_angle: bool = True, +): + """ + Compute the overlaps between viewing frustrums of all pairs of cameras + in `cameras`. + """ + device = cameras.device + ba = cameras.R.shape[0] + n_vox = int(resol ** 3) + grid = pt3d.structures.Volumes( + densities=torch.zeros([1, 1, resol, resol, resol], device=device), + volume_translation=-torch.FloatTensor(scene_center)[None].to(device), + voxel_size=2.0 * scene_extent / resol, + ).get_coord_grid(world_coordinates=True) + + grid = grid.view(1, n_vox, 3).expand(ba, n_vox, 3) + gridp = cameras.transform_points(grid, eps=1e-2) + proj_in_camera = ( + torch.prod((gridp[..., :2].abs() <= 1.0), dim=-1) + * (gridp[..., 2] > 0.0).float() + ) # ba x n_vox + + if weigh_by_ray_angle: + rays = torch.nn.functional.normalize( + grid - cameras.get_camera_center()[:, None], dim=-1 + ) + rays_masked = rays * proj_in_camera[..., None] + + # - slow and readable: + # inter = torch.zeros(ba, ba) + # for i1 in range(ba): + # for i2 in range(ba): + # inter[i1, i2] = ( + # 1 + (rays_masked[i1] * rays_masked[i2] + # ).sum(dim=-1)).sum() + + # - fast: + rays_masked = rays_masked.view(ba, n_vox * 3) + inter = n_vox + (rays_masked @ rays_masked.t()) + + else: + inter = proj_in_camera @ proj_in_camera.t() + + mass = torch.diag(inter) + iou = inter / (mass[:, None] + mass[None, :] - inter).clamp(0.1) + + return iou diff --git a/pytorch3d/implicitron/tools/circle_fitting.py b/pytorch3d/implicitron/tools/circle_fitting.py new file mode 100644 index 00000000..092a8ee2 --- /dev/null +++ b/pytorch3d/implicitron/tools/circle_fitting.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import warnings +from dataclasses import dataclass +from math import pi +from typing import Optional + +import torch +from pytorch3d.common.compat import eigh, lstsq + + +def _get_rotation_to_best_fit_xy( + points: torch.Tensor, centroid: torch.Tensor +) -> torch.Tensor: + """ + Returns a rotation r such that points @ r has a best fit plane + parallel to the xy plane + + Args: + points: (N, 3) tensor of points in 3D + centroid: (3,) their centroid + + Returns: + (3,3) tensor rotation matrix + """ + points_centered = points - centroid[None] + return eigh(points_centered.t() @ points_centered)[1][:, [1, 2, 0]] + + +def _signed_area(path: torch.Tensor) -> torch.Tensor: + """ + Calculates the signed area / Lévy area of a 2D path. If the path is closed, + i.e. ends where it starts, this is the integral of the winding number over + the whole plane. If not, consider a closed path made by adding a straight + line from the end to the start; the signed area is the integral of the + winding number (also over the plane) with respect to that closed path. + + If this number is positive, it indicates in some sense that the path + turns anticlockwise more than clockwise, and vice versa. + + Args: + path: N x 2 tensor of points. + + Returns: + signed area, shape () + """ + # This calculation is a sum of areas of triangles of the form + # (path[0], path[i], path[i+1]), where each triangle is half a + # parallelogram. + x, y = (path[1:] - path[:1]).unbind(1) + return (y[1:] * x[:-1] - x[1:] * y[:-1]).sum() * 0.5 + + +@dataclass(frozen=True) +class Circle2D: + """ + Contains details of a circle in a plane. + Members + center: tensor shape (2,) + radius: tensor shape () + generated_points: points around the circle, shape (n_points, 2) + """ + + center: torch.Tensor + radius: torch.Tensor + generated_points: torch.Tensor + + +def fit_circle_in_2d( + points2d, *, n_points: int = 0, angles: Optional[torch.Tensor] = None +) -> Circle2D: + """ + Simple best fitting of a circle to 2D points. In particular, the circle which + minimizes the sum of the squares of the squared-distances to the circle. + + Finds (a,b) and r to minimize the sum of squares (over the x,y pairs) of + r**2 - [(x-a)**2+(y-b)**2] + i.e. + (2*a)*x + (2*b)*y + (r**2 - a**2 - b**2)*1 - (x**2 + y**2) + + In addition, generates points along the circle. If angles is None (default) + then n_points around the circle equally spaced are given. These begin at the + point closest to the first input point. They continue in the direction which + seems to match the movement of points in points2d, as judged by its + signed area. If `angles` are provided, then n_points is ignored, and points + along the circle at the given angles are returned, with the starting point + and direction as before. + + (Note that `generated_points` is affected by the order of the points in + points2d, but the other outputs are not.) + + Args: + points2d: N x 2 tensor of 2D points + n_points: number of points to generate on the circle, if angles not given + angles: optional angles in radians of points to generate. + + Returns: + Circle2D object + """ + design = torch.cat([points2d, torch.ones_like(points2d[:, :1])], dim=1) + rhs = (points2d ** 2).sum(1) + n_provided = points2d.shape[0] + if n_provided < 3: + raise ValueError(f"{n_provided} points are not enough to determine a circle") + solution = lstsq(design, rhs) + center = solution[:2] / 2 + radius = torch.sqrt(solution[2] + (center ** 2).sum()) + if n_points > 0: + if angles is not None: + warnings.warn("n_points ignored because angles provided") + else: + angles = torch.linspace(0, 2 * pi, n_points, device=points2d.device) + + if angles is not None: + initial_direction_xy = (points2d[0] - center).unbind() + initial_angle = torch.atan2(initial_direction_xy[1], initial_direction_xy[0]) + with torch.no_grad(): + anticlockwise = _signed_area(points2d) > 0 + if anticlockwise: + use_angles = initial_angle + angles + else: + use_angles = initial_angle - angles + generated_points = center[None] + radius * torch.stack( + [torch.cos(use_angles), torch.sin(use_angles)], dim=-1 + ) + else: + generated_points = points2d.new_zeros(0, 2) + return Circle2D(center=center, radius=radius, generated_points=generated_points) + + +@dataclass(frozen=True) +class Circle3D: + """ + Contains details of a circle in 3D. + Members + center: tensor shape (3,) + radius: tensor shape () + normal: tensor shape (3,) + generated_points: points around the circle, shape (n_points, 3) + """ + + center: torch.Tensor + radius: torch.Tensor + normal: torch.Tensor + generated_points: torch.Tensor + + +def fit_circle_in_3d( + points, + *, + n_points: int = 0, + angles: Optional[torch.Tensor] = None, + offset: Optional[torch.Tensor] = None, + up: Optional[torch.Tensor] = None, +) -> Circle3D: + """ + Simple best fit circle to 3D points. Uses circle_2d in the + least-squares best fit plane. + + In addition, generates points along the circle. If angles is None (default) + then n_points around the circle equally spaced are given. These begin at the + point closest to the first input point. They continue in the direction which + seems to be match the movement of points. If angles is provided, then n_points + is ignored, and points along the circle at the given angles are returned, + with the starting point and direction as before. + + Further, an offset can be given to add to the generated points; this is + interpreted in a rotated coordinate system where (0, 0, 1) is normal to the + circle, specifically the normal which is approximately in the direction of a + given `up` vector. The remaining rotation is disambiguated in an unspecified + but deterministic way. + + (Note that `generated_points` is affected by the order of the points in + points, but the other outputs are not.) + + Args: + points2d: N x 3 tensor of 3D points + n_points: number of points to generate on the circle + angles: optional angles in radians of points to generate. + offset: optional tensor (3,), a displacement expressed in a "canonical" + coordinate system to add to the generated points. + up: optional tensor (3,), a vector which helps define the + "canonical" coordinate system for interpretting `offset`. + Required if offset is used. + + + Returns: + Circle3D object + """ + centroid = points.mean(0) + r = _get_rotation_to_best_fit_xy(points, centroid) + normal = r[:, 2] + rotated_points = (points - centroid) @ r + result_2d = fit_circle_in_2d( + rotated_points[:, :2], n_points=n_points, angles=angles + ) + center_3d = result_2d.center @ r[:, :2].t() + centroid + n_generated_points = result_2d.generated_points.shape[0] + if n_generated_points > 0: + generated_points_in_plane = torch.cat( + [ + result_2d.generated_points, + torch.zeros_like(result_2d.generated_points[:, :1]), + ], + dim=1, + ) + if offset is not None: + if up is None: + raise ValueError("Missing `up` input for interpreting offset") + with torch.no_grad(): + swap = torch.dot(up, normal) < 0 + if swap: + # We need some rotation which takes +z to -z. Here's one. + generated_points_in_plane += offset * offset.new_tensor([1, -1, -1]) + else: + generated_points_in_plane += offset + + generated_points = generated_points_in_plane @ r.t() + centroid + else: + generated_points = points.new_zeros(0, 3) + + return Circle3D( + radius=result_2d.radius, + center=center_3d, + normal=normal, + generated_points=generated_points, + ) diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py new file mode 100644 index 00000000..edbec7d3 --- /dev/null +++ b/pytorch3d/implicitron/tools/config.py @@ -0,0 +1,714 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import copy +import dataclasses +import inspect +import warnings +from collections import Counter, defaultdict +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, cast + +from omegaconf import DictConfig, OmegaConf, open_dict + + +""" +This functionality allows a configurable system to be determined in a dataclass-type +way. It is a generalization of omegaconf's "structured", in the dataclass case. +Core functionality: + +- Configurable -- A base class used to label a class as being one which uses this + system. Uses class members and __post_init__ like a dataclass. + +- expand_args_fields -- Expands a class like `dataclasses.dataclass`. Runs automatically. + +- get_default_args -- gets an omegaconf.DictConfig for initializing + a given class or calling a given function. + +- run_auto_creation -- Initialises nested members. To be called in __post_init__. + + +In addition, a Configurable may contain members whose type is decided at runtime. + +- ReplaceableBase -- As a base instead of Configurable, labels a class to say that + any child class can be used instead. + +- registry -- A global store of named child classes of ReplaceableBase classes. + Used as `@registry.register` decorator on class definition. + + +Additional utility functions: + +- remove_unused_components -- used for simplifying a DictConfig instance. +- get_default_args_field -- default for DictConfig member of another configurable. + + +1. The simplest usage of this functionality is as follows. First a schema is defined +in dataclass style. + + class A(Configurable): + n: int = 9 + + class B(Configurable): + a: A + + def __post_init__(self): + run_auto_creation(self) + +It can be used like + + b_args = get_default_args(B) + b = B(**b_args) + +In this case, get_default_args(B) returns an omegaconf.DictConfig with the right +members {"a_args": {"n": 9}}. It also modifies the definitions of the classes to +something like the following. (The modification itself is done by the function +`expand_args_fields`, which is called inside `get_default_args`.) + + @dataclasses.dataclass + class A: + n: int = 9 + + @dataclasses.dataclass + class B: + a_args: DictConfig = dataclasses.field(default_factory=lambda: DictConfig({"n": 9})) + + def __post_init__(self): + self.a = A(**self.a_args) + +2. Pluggability. Instead of a dataclass-style member being given a concrete class, +you can give a base class and the implementation is looked up by name in the global +`registry` in this module. E.g. + + class A(ReplaceableBase): + k: int = 1 + + @registry.register + class A1(A): + m: int = 3 + + @registry.register + class A2(A): + n: str = "2" + + class B(Configurable): + a: A + a_class_type: str = "A2" + + def __post_init__(self): + run_auto_creation(self) + +will expand to + + @dataclasses.dataclass + class A: + k: int = 1 + + @dataclasses.dataclass + class A1(A): + m: int = 3 + + @dataclasses.dataclass + class A2(A): + n: str = "2" + + @dataclasses.dataclass + class B: + a_class_type: str = "A2" + a_A1_args: DictConfig = dataclasses.field( + default_factory=lambda: DictConfig({"k": 1, "m": 3} + ) + a_A2_args: DictConfig = dataclasses.field( + default_factory=lambda: DictConfig({"k": 1, "m": 3} + ) + + def __post_init__(self): + if self.a_class_type == "A1": + self.a = A1(**self.a_A1_args) + elif self.a_class_type == "A2": + self.a = A2(**self.a_A2_args) + else: + raise ValueError(...) + +3. Aside from these classes, the members of these classes should be things +which DictConfig is happy with: e.g. (bool, int, str, None, float) and what +can be built from them with DictConfigs and lists of them. + +In addition, you can call get_default_args on a function or class to get +the DictConfig of its defaulted arguments, assuming those are all things +which DictConfig is happy with. If you want to use such a thing as a member +of another configured class, `get_default_args_field` is a helper. +""" + + +_unprocessed_warning: str = ( + " must be processed before it can be used." + + " This is done by calling expand_args_fields " + + "or get_default_args on it." +) + +TYPE_SUFFIX: str = "_class_type" +ARGS_SUFFIX: str = "_args" + + +class ReplaceableBase: + """ + Base class for dataclass-style classes which + can be stored in the registry. + """ + + def __new__(cls, *args, **kwargs): + """ + This function only exists to raise a + warning if class construction is attempted + without processing. + """ + obj = super().__new__(cls) + if cls is not ReplaceableBase and not _is_actually_dataclass(cls): + warnings.warn(cls.__name__ + _unprocessed_warning) + return obj + + +class Configurable: + """ + This indicates a class which is not ReplaceableBase + but still needs to be + expanded into a dataclass with expand_args_fields. + This expansion is delayed. + """ + + def __new__(cls, *args, **kwargs): + """ + This function only exists to raise a + warning if class construction is attempted + without processing. + """ + obj = super().__new__(cls) + if cls is not Configurable and not _is_actually_dataclass(cls): + warnings.warn(cls.__name__ + _unprocessed_warning) + return obj + + +_X = TypeVar("X", bound=ReplaceableBase) + + +class _Registry: + """ + Register from names to classes. In particular, we say that direct subclasses of + ReplaceableBase are "base classes" and we register subclasses of each base class + in a separate namespace. + """ + + def __init__(self) -> None: + self._mapping: Dict[ + Type[ReplaceableBase], Dict[str, Type[ReplaceableBase]] + ] = defaultdict(dict) + + def register(self, some_class: Type[_X]) -> Type[_X]: + """ + A class decorator, to register a class in self. + """ + name = some_class.__name__ + self._register(some_class, name=name) + return some_class + + def _register( + self, + some_class: Type[ReplaceableBase], + *, + base_class: Optional[Type[ReplaceableBase]] = None, + name: str, + ) -> None: + """ + Register a new member. + + Args: + cls: the new member + base_class: (optional) what the new member is a type for + name: name for the new member + """ + if base_class is None: + base_class = self._base_class_from_class(some_class) + if base_class is None: + raise ValueError( + f"Cannot register {some_class}. Cannot tell what it is." + ) + if some_class is base_class: + raise ValueError(f"Attempted to register the base class {some_class}") + self._mapping[base_class][name] = some_class + + def get( + self, base_class_wanted: Type[ReplaceableBase], name: str + ) -> Type[ReplaceableBase]: + """ + Retrieve a class from the registry by name + + Args: + base_class_wanted: parent type of type we are looking for. + It determines the namespace. + This will typically be a direct subclass of ReplaceableBase. + name: what to look for + + Returns: + class type + """ + if self._is_base_class(base_class_wanted): + base_class = base_class_wanted + else: + base_class = self._base_class_from_class(base_class_wanted) + if base_class is None: + raise ValueError( + f"Cannot look up {base_class_wanted}. Cannot tell what it is." + ) + result = self._mapping[base_class].get(name) + if result is None: + raise ValueError(f"{name} has not been registered.") + if not issubclass(result, base_class_wanted): + raise ValueError( + f"{name} resolves to {result} which does not subclass {base_class_wanted}" + ) + return result + + def get_all( + self, base_class_wanted: Type[ReplaceableBase] + ) -> List[Type[ReplaceableBase]]: + """ + Retrieve all registered implementations from the registry + + Args: + base_class_wanted: parent type of type we are looking for. + It determines the namespace. + This will typically be a direct subclass of ReplaceableBase. + Returns: + list of class types + """ + if self._is_base_class(base_class_wanted): + return list(self._mapping[base_class_wanted].values()) + + base_class = self._base_class_from_class(base_class_wanted) + if base_class is None: + raise ValueError( + f"Cannot look up {base_class_wanted}. Cannot tell what it is." + ) + return [ + class_ + for class_ in self._mapping[base_class].values() + if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted + ] + + @staticmethod + def _is_base_class(some_class: Type[ReplaceableBase]) -> bool: + """ + Return whether the given type is a direct subclass of ReplaceableBase + and so gets used as a namespace. + """ + return ReplaceableBase in some_class.__bases__ + + @staticmethod + def _base_class_from_class( + some_class: Type[ReplaceableBase], + ) -> Optional[Type[ReplaceableBase]]: + """ + Find the parent class of some_class which inherits ReplaceableBase, or None + """ + for base in some_class.mro()[-3::-1]: + if base is not ReplaceableBase and issubclass(base, ReplaceableBase): + return base + return None + + +# Global instance of the registry +registry = _Registry() + + +def _default_create(name: str, type_: Type, pluggable: bool) -> Callable[[Any], None]: + """ + Return the default creation function for a member. This is a function which + could be called in __post_init__ to initialise the member, and will be called + from run_auto_creation. + + Args: + name: name of the member + type_: declared type of the member + pluggable: True if the member's declared type inherits ReplaceableBase, + in which case the actual type to be created is decided at + runtime. + + Returns: + Function taking one argument, the object whose member should be + initialized. + """ + + def inner(self): + expand_args_fields(type_) + args = getattr(self, name + ARGS_SUFFIX) + setattr(self, name, type_(**args)) + + def inner_pluggable(self): + type_name = getattr(self, name + TYPE_SUFFIX) + chosen_class = registry.get(type_, type_name) + if self._known_implementations.get(type_name, chosen_class) is not chosen_class: + # If this warning is raised, it means that a new definition of + # the chosen class has been registered since our class was processed + # (i.e. expanded). A DictConfig which comes from our get_default_args + # (which might have triggered the processing) will contain the old default + # values for the members of the chosen class. Changes to those defaults which + # were made in the redefinition will not be reflected here. + warnings.warn(f"New implementation of {type_name} is being chosen.") + expand_args_fields(chosen_class) + args = getattr(self, f"{name}_{type_name}{ARGS_SUFFIX}") + setattr(self, name, chosen_class(**args)) + + return inner_pluggable if pluggable else inner + + +def run_auto_creation(self: Any) -> None: + """ + Run all the functions named in self._creation_functions. + """ + for create_function in self._creation_functions: + getattr(self, create_function)() + + +def _is_configurable_class(C) -> bool: + return isinstance(C, type) and issubclass(C, (Configurable, ReplaceableBase)) + + +def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig: + """ + Get the DictConfig of args to call C - which might be a type or a function. + + If C is a subclass of Configurable or ReplaceableBase, we make sure + it has been processed with expand_args_fields. If C is a dataclass, + including a subclass of Configurable or ReplaceableBase, the output + will be a typed DictConfig. + + Args: + C: the class or function to be processed + _do_not_process: (internal use) When this function is called from + expand_args_fields, we specify any class currently being + processed, to make sure we don't try to process a class + while it is already being processed. + + Returns: + new DictConfig object + """ + if C is None: + return DictConfig({}) + + if _is_configurable_class(C): + if C in _do_not_process: + raise ValueError( + f"Internal recursion error. Need processed {C}," + f" but cannot get it. _do_not_process={_do_not_process}" + ) + # This is safe to run multiple times. It will return + # straight away if C has already been processed. + expand_args_fields(C, _do_not_process=_do_not_process) + + kwargs = {} + if dataclasses.is_dataclass(C): + # Note that if get_default_args_field is used somewhere in C, + # this call is recursive. No special care is needed, + # because in practice get_default_args_field is used for + # separate types than the outer type. + out = OmegaConf.structured(C) + exclude = getattr(C, "_processed_members", ()) + with open_dict(out): + for field in exclude: + out.pop(field, None) + return out + + if _is_configurable_class(C): + raise ValueError(f"Failed to process {C}") + + # returns dict of keyword args of a callable C + sig = inspect.signature(C) + for pname, defval in dict(sig.parameters).items(): + if defval.default == inspect.Parameter.empty: + # print('skipping %s' % pname) + continue + else: + kwargs[pname] = copy.deepcopy(defval.default) + + return DictConfig(kwargs) + + +def _is_actually_dataclass(some_class) -> bool: + # Return whether the class some_class has been processed with + # the dataclass annotation. This is more specific than + # dataclasses.is_dataclass which returns True on anything + # deriving from a dataclass. + + # Checking for __init__ would also work for our purpose. + return "__dataclass_fields__" in some_class.__dict__ + + +def expand_args_fields( + some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = () +) -> Type[_X]: + """ + This expands a class which inherits Configurable or ReplaceableBase classes, + including dataclass processing. some_class is modified in place by this function. + For classes of type ReplaceableBase, you can add some_class to the registry before + or after calling this function. But potential inner classes need to be registered + before this function is run on the outer class. + + The transformations this function makes, before the concluding + dataclasses.dataclass, are as follows. if X is a base class with registered + subclasses Y and Z, replace + + x: X + + and optionally + + x_class_type: str = "Y" + def create_x(self):... + + with + + x_Y_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) + x_Z_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) + def create_x(self): + self.x = registry.get(X, self.x_class_type)( + **self.getattr(f"x_{self.x_class_type}_args) + ) + x_class_type: str = "UNDEFAULTED" + + without adding the optional things if they are already there. + + Similarly, if X is a subclass of Configurable, + + x: X + + and optionally + + def create_x(self):... + + will be replaced with + + x_args : DictConfig = dataclasses.field(default_factory=lambda: DictConfig()) + def create_x(self): + self.x = X(self.x_args) + + Also adds the following class members, unannotated so that dataclass + ignores them. + - _creation_functions: Tuple[str] of all the create_ functions, + including those from base classes. + - _known_implementations: Dict[str, Type] containing the classes which + have been found from the registry. + (used only to raise a warning if it one has been overwritten) + - _processed_members: a Set[str] of all the members which have been transformed. + + Args: + some_class: the class to be processed + _do_not_process: Internal use for get_default_args: Because get_default_args calls + and is called by this function, we let it specify any class currently + being processed, to make sure we don't try to process a class while + it is already being processed. + + + Returns: + some_class itself, which has been modified in place. This + allows this function to be used as a class decorator. + """ + if _is_actually_dataclass(some_class): + return some_class + + # The functions this class's run_auto_creation will run. + creation_functions: List[str] = [] + # The classes which this type knows about from the registry + # We could use a weakref.WeakValueDictionary here which would mean + # that we don't warn if the class we should have expected is elsewhere + # unused. + known_implementations: Dict[str, Type] = {} + # Names of members which have been processed. + processed_members: Set[str] = set() + + # For all bases except ReplaceableBase and Configurable and object, + # we need to process them before our own processing. This is + # because dataclasses expect to inherit dataclasses and not unprocessed + # dataclasses. + for base in some_class.mro()[-3:0:-1]: + if base is ReplaceableBase: + continue + if base is Configurable: + continue + if not issubclass(base, (Configurable, ReplaceableBase)): + continue + expand_args_fields(base, _do_not_process=_do_not_process) + if "_creation_functions" in base.__dict__: + creation_functions.extend(base._creation_functions) + if "_known_implementations" in base.__dict__: + known_implementations.update(base._known_implementations) + if "_processed_members" in base.__dict__: + processed_members.update(base._processed_members) + + to_process: List[Tuple[str, Type, bool]] = [] + if "__annotations__" in some_class.__dict__: + for name, type_ in some_class.__annotations__.items(): + if not isinstance(type_, type): + # type_ could be something like typing.Tuple + continue + if ( + issubclass(type_, ReplaceableBase) + and ReplaceableBase in type_.__bases__ + ): + to_process.append((name, type_, True)) + elif issubclass(type_, Configurable): + to_process.append((name, type_, False)) + + for name, type_, pluggable in to_process: + _process_member( + name=name, + type_=type_, + pluggable=pluggable, + some_class=cast(type, some_class), + creation_functions=creation_functions, + _do_not_process=_do_not_process, + known_implementations=known_implementations, + ) + processed_members.add(name) + + for key, count in Counter(creation_functions).items(): + if count > 1: + warnings.warn(f"Clash with {key} in a base class.") + some_class._creation_functions = tuple(creation_functions) + some_class._processed_members = processed_members + some_class._known_implementations = known_implementations + + dataclasses.dataclass(eq=False)(some_class) + return some_class + + +def get_default_args_field(C, *, _do_not_process: Tuple[type, ...] = ()): + """ + Get a dataclass field which defaults to get_default_args(...) + + Args: + As for get_default_args. + + Returns: + function to return new DictConfig object + """ + + def create(): + return get_default_args(C, _do_not_process=_do_not_process) + + return dataclasses.field(default_factory=create) + + +def _process_member( + *, + name: str, + type_: Type, + pluggable: bool, + some_class: Type, + creation_functions: List[str], + _do_not_process: Tuple[type, ...], + known_implementations: Dict[str, Type], +) -> None: + """ + Make the modification (of expand_args_fields) to some_class for a single member. + + Args: + name: member name + type_: member declared type + plugglable: whether member has dynamic type + some_class: (MODIFIED IN PLACE) the class being processed + creation_functions: (MODIFIED IN PLACE) the names of the create functions + _do_not_process: as for expand_args_fields. + known_implementations: (MODIFIED IN PLACE) known types from the registry + """ + # Because we are adding defaultable members, make + # sure they go at the end of __annotations__ in case + # there are non-defaulted standard class members. + del some_class.__annotations__[name] + + if pluggable: + type_name = name + TYPE_SUFFIX + if type_name not in some_class.__annotations__: + some_class.__annotations__[type_name] = str + setattr(some_class, type_name, "UNDEFAULTED") + + for derived_type in registry.get_all(type_): + if derived_type in _do_not_process: + continue + if issubclass(derived_type, some_class): + # When derived_type is some_class we have a simple + # recursion to avoid. When it's a strict subclass the + # situation is even worse. + continue + known_implementations[derived_type.__name__] = derived_type + args_name = f"{name}_{derived_type.__name__}{ARGS_SUFFIX}" + if args_name in some_class.__annotations__: + raise ValueError( + f"Cannot generate {args_name} because it is already present." + ) + some_class.__annotations__[args_name] = DictConfig + setattr( + some_class, + args_name, + get_default_args_field( + derived_type, _do_not_process=_do_not_process + (some_class,) + ), + ) + else: + args_name = name + ARGS_SUFFIX + if args_name in some_class.__annotations__: + raise ValueError( + f"Cannot generate {args_name} because it is already present." + ) + if issubclass(type_, some_class) or type_ in _do_not_process: + raise ValueError(f"Cannot process {type_} inside {some_class}") + + some_class.__annotations__[args_name] = DictConfig + setattr( + some_class, + args_name, + get_default_args_field( + type_, + _do_not_process=_do_not_process + (some_class,), + ), + ) + + creation_function_name = f"create_{name}" + if not hasattr(some_class, creation_function_name): + setattr( + some_class, + creation_function_name, + _default_create(name, type_, pluggable), + ) + creation_functions.append(creation_function_name) + + +def remove_unused_components(dict_: DictConfig) -> None: + """ + Assuming dict_ represents the state of a configurable, + modify it to remove all the portions corresponding to + pluggable parts which are not in use. + For example, if renderer_class_type is SignedDistanceFunctionRenderer, + the renderer_MultiPassEmissionAbsorptionRenderer_args will be + removed. + + Args: + dict_: (MODIFIED IN PLACE) a DictConfig instance + """ + keys = [key for key in dict_ if isinstance(key, str)] + suffix_length = len(TYPE_SUFFIX) + replaceables = [key[:-suffix_length] for key in keys if key.endswith(TYPE_SUFFIX)] + args_keys = [key for key in keys if key.endswith(ARGS_SUFFIX)] + for replaceable in replaceables: + selected_type = dict_[replaceable + TYPE_SUFFIX] + expect = replaceable + "_" + selected_type + ARGS_SUFFIX + with open_dict(dict_): + for key in args_keys: + if key.startswith(replaceable + "_") and key != expect: + del dict_[key] + + for key in dict_: + if isinstance(dict_.get(key), DictConfig): + remove_unused_components(dict_[key]) diff --git a/pytorch3d/implicitron/tools/depth_cleanup.py b/pytorch3d/implicitron/tools/depth_cleanup.py new file mode 100644 index 00000000..a4a8bab4 --- /dev/null +++ b/pytorch3d/implicitron/tools/depth_cleanup.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as Fu +from pytorch3d.ops import wmean +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.structures import Pointclouds + + +def cleanup_eval_depth( + point_cloud: Pointclouds, + camera: CamerasBase, + depth: torch.Tensor, + mask: torch.Tensor, + sigma: float = 0.01, + image=None, +): + + ba, _, H, W = depth.shape + + pcl = point_cloud.points_padded() + n_pts = point_cloud.num_points_per_cloud() + pcl_mask = ( + torch.arange(pcl.shape[1], dtype=torch.int64, device=pcl.device)[None] + < n_pts[:, None] + ).type_as(pcl) + + pcl_proj = camera.transform_points(pcl, eps=1e-2)[..., :-1] + pcl_depth = camera.get_world_to_view_transform().transform_points(pcl)[..., -1] + + depth_and_idx = torch.cat( + ( + depth, + torch.arange(H * W).view(1, 1, H, W).expand(ba, 1, H, W).type_as(depth), + ), + dim=1, + ) + + depth_and_idx_sampled = Fu.grid_sample( + depth_and_idx, -pcl_proj[:, None], mode="nearest" + )[:, :, 0].view(ba, 2, -1) + + depth_sampled, idx_sampled = depth_and_idx_sampled.split([1, 1], dim=1) + df = (depth_sampled[:, 0] - pcl_depth).abs() + + # the threshold is a sigma-multiple of the standard deviation of the depth + mu = wmean(depth.view(ba, -1, 1), mask.view(ba, -1)).view(ba, 1) + std = ( + wmean((depth.view(ba, -1) - mu).view(ba, -1, 1) ** 2, mask.view(ba, -1)) + .clamp(1e-4) + .sqrt() + .view(ba, -1) + ) + good_df_thr = std * sigma + good_depth = (df <= good_df_thr).float() * pcl_mask + + perc_kept = good_depth.sum(dim=1) / pcl_mask.sum(dim=1).clamp(1) + # print(f'Kept {100.0 * perc_kept.mean():1.3f} % points') + + good_depth_raster = torch.zeros_like(depth).view(ba, -1) + # pyre-ignore[16]: scatter_add_ + good_depth_raster.scatter_add_(1, torch.round(idx_sampled[:, 0]).long(), good_depth) + + good_depth_mask = (good_depth_raster.view(ba, 1, H, W) > 0).float() + + # if float(torch.rand(1)) > 0.95: + # depth_ok = depth * good_depth_mask + + # # visualize + # visdom_env = 'depth_cleanup_dbg' + # from visdom import Visdom + # # from tools.vis_utils import make_depth_image + # from pytorch3d.vis.plotly_vis import plot_scene + # viz = Visdom() + + # show_pcls = { + # 'pointclouds': point_cloud, + # } + # for d, nm in zip( + # (depth, depth_ok), + # ('pointclouds_unproj', 'pointclouds_unproj_ok'), + # ): + # pointclouds_unproj = get_rgbd_point_cloud( + # camera, image, d, + # ) + # if int(pointclouds_unproj.num_points_per_cloud()) > 0: + # show_pcls[nm] = pointclouds_unproj + + # scene_dict = {'1': { + # **show_pcls, + # 'cameras': camera, + # }} + # scene = plot_scene( + # scene_dict, + # pointcloud_max_points=5000, + # pointcloud_marker_size=1.5, + # camera_scale=1.0, + # ) + # viz.plotlyplot(scene, env=visdom_env, win='scene') + + # # depth_image_ok = make_depth_image(depths_ok, masks) + # # viz.images(depth_image_ok, env=visdom_env, win='depth_ok') + # # depth_image = make_depth_image(depths, masks) + # # viz.images(depth_image, env=visdom_env, win='depth') + # # # viz.images(rgb_rendered, env=visdom_env, win='images_render') + # # viz.images(images, env=visdom_env, win='images') + # import pdb; pdb.set_trace() + + return good_depth_mask diff --git a/pytorch3d/implicitron/tools/eval_video_trajectory.py b/pytorch3d/implicitron/tools/eval_video_trajectory.py new file mode 100644 index 00000000..d9afb873 --- /dev/null +++ b/pytorch3d/implicitron/tools/eval_video_trajectory.py @@ -0,0 +1,226 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Optional, Tuple + +import torch +from pytorch3d.common.compat import eigh +from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d +from pytorch3d.renderer import PerspectiveCameras, look_at_view_transform +from pytorch3d.transforms import Scale + + +def generate_eval_video_cameras( + train_cameras, + n_eval_cams: int = 100, + trajectory_type: str = "figure_eight", + trajectory_scale: float = 0.2, + scene_center: Tuple[float, float, float] = (0.0, 0.0, 0.0), + up: Tuple[float, float, float] = (0.0, 0.0, 1.0), + focal_length: Optional[torch.FloatTensor] = None, + principal_point: Optional[torch.FloatTensor] = None, + time: Optional[torch.FloatTensor] = None, + infer_up_as_plane_normal: bool = True, + traj_offset: Optional[Tuple[float, float, float]] = None, + traj_offset_canonical: Optional[Tuple[float, float, float]] = None, +) -> PerspectiveCameras: + """ + Generate a camera trajectory rendering a scene from multiple viewpoints. + + Args: + train_dataset: The training dataset object. + n_eval_cams: Number of cameras in the trajectory. + trajectory_type: The type of the camera trajectory. Can be one of: + circular_lsq_fit: Camera centers follow a trajectory obtained + by fitting a 3D circle to train_cameras centers. + All cameras are looking towards scene_center. + figure_eight: Figure-of-8 trajectory around the center of the + central camera of the training dataset. + trefoil_knot: Same as 'figure_eight', but the trajectory has a shape + of a trefoil knot (https://en.wikipedia.org/wiki/Trefoil_knot). + figure_eight_knot: Same as 'figure_eight', but the trajectory has a shape + of a figure-eight knot + (https://en.wikipedia.org/wiki/Figure-eight_knot_(mathematics)). + trajectory_scale: The extent of the trajectory. + up: The "up" vector of the scene (=the normal of the scene floor). + Active for the `trajectory_type="circular"`. + scene_center: The center of the scene in world coordinates which all + the cameras from the generated trajectory look at. + Returns: + Dictionary of camera instances which can be used as the test dataset + """ + if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"): + cam_centers = train_cameras.get_camera_center() + # get the nearest camera center to the mean of centers + mean_camera_idx = ( + ((cam_centers - cam_centers.mean(dim=0)[None]) ** 2) + .sum(dim=1) + .min(dim=0) + .indices + ) + # generate the knot trajectory in canonical coords + if time is None: + time = torch.linspace(0, 2 * math.pi, n_eval_cams + 1)[:n_eval_cams] + else: + assert time.numel() == n_eval_cams + if trajectory_type == "trefoil_knot": + traj = _trefoil_knot(time) + elif trajectory_type == "figure_eight_knot": + traj = _figure_eight_knot(time) + elif trajectory_type == "figure_eight": + traj = _figure_eight(time) + else: + raise ValueError(f"bad trajectory type: {trajectory_type}") + traj[:, 2] -= traj[:, 2].max() + + # transform the canonical knot to the coord frame of the mean camera + mean_camera = PerspectiveCameras( + **{ + k: getattr(train_cameras, k)[[int(mean_camera_idx)]] + for k in ("focal_length", "principal_point", "R", "T") + } + ) + traj_trans = Scale(cam_centers.std(dim=0).mean() * trajectory_scale).compose( + mean_camera.get_world_to_view_transform().inverse() + ) + + if traj_offset_canonical is not None: + traj_trans = traj_trans.translate( + torch.FloatTensor(traj_offset_canonical)[None].to(traj) + ) + + traj = traj_trans.transform_points(traj) + + plane_normal = _fit_plane(cam_centers)[:, 0] + if infer_up_as_plane_normal: + up = _disambiguate_normal(plane_normal, up) + + elif trajectory_type == "circular_lsq_fit": + ### fit plane to the camera centers + + # get the center of the plane as the median of the camera centers + cam_centers = train_cameras.get_camera_center() + + if time is not None: + angle = time + else: + angle = torch.linspace(0, 2.0 * math.pi, n_eval_cams).to(cam_centers) + + fit = fit_circle_in_3d( + cam_centers, + angles=angle, + offset=angle.new_tensor(traj_offset_canonical) + if traj_offset_canonical is not None + else None, + up=angle.new_tensor(up), + ) + traj = fit.generated_points + + # scalethe trajectory + _t_mu = traj.mean(dim=0, keepdim=True) + traj = (traj - _t_mu) * trajectory_scale + _t_mu + + plane_normal = fit.normal + + if infer_up_as_plane_normal: + up = _disambiguate_normal(plane_normal, up) + + else: + raise ValueError(f"Uknown trajectory_type {trajectory_type}.") + + if traj_offset is not None: + traj = traj + torch.FloatTensor(traj_offset)[None].to(traj) + + # point all cameras towards the center of the scene + R, T = look_at_view_transform( + eye=traj, + at=(scene_center,), # (1, 3) + up=(up,), # (1, 3) + device=traj.device, + ) + + # get the average focal length and principal point + if focal_length is None: + focal_length = train_cameras.focal_length.mean(dim=0).repeat(n_eval_cams, 1) + if principal_point is None: + principal_point = train_cameras.principal_point.mean(dim=0).repeat( + n_eval_cams, 1 + ) + + test_cameras = PerspectiveCameras( + focal_length=focal_length, + principal_point=principal_point, + R=R, + T=T, + device=focal_length.device, + ) + + # _visdom_plot_scene( + # train_cameras, + # test_cameras, + # ) + + return test_cameras + + +def _disambiguate_normal(normal, up): + up_t = torch.tensor(up).to(normal) + flip = (up_t * normal).sum().sign() + up = normal * flip + up = up.tolist() + return up + + +def _fit_plane(x): + x = x - x.mean(dim=0)[None] + cov = (x.t() @ x) / x.shape[0] + _, e_vec = eigh(cov) + return e_vec + + +def _visdom_plot_scene( + train_cameras, + test_cameras, +) -> None: + from pytorch3d.vis.plotly_vis import plot_scene + + p = plot_scene( + { + "scene": { + "train_cams": train_cameras, + "test_cams": test_cameras, + } + } + ) + from visdom import Visdom + + viz = Visdom() + viz.plotlyplot(p, env="cam_traj_dbg", win="cam_trajs") + import pdb + + pdb.set_trace() + + +def _figure_eight_knot(t: torch.Tensor, z_scale: float = 0.5): + x = (2 + (2 * t).cos()) * (3 * t).cos() + y = (2 + (2 * t).cos()) * (3 * t).sin() + z = (4 * t).sin() * z_scale + return torch.stack((x, y, z), dim=-1) + + +def _trefoil_knot(t: torch.Tensor, z_scale: float = 0.5): + x = t.sin() + 2 * (2 * t).sin() + y = t.cos() - 2 * (2 * t).cos() + z = -(3 * t).sin() * z_scale + return torch.stack((x, y, z), dim=-1) + + +def _figure_eight(t: torch.Tensor, z_scale: float = 0.5): + x = t.cos() + y = (2 * t).sin() / 2 + z = t.sin() * z_scale + return torch.stack((x, y, z), dim=-1) diff --git a/pytorch3d/implicitron/tools/image_utils.py b/pytorch3d/implicitron/tools/image_utils.py new file mode 100644 index 00000000..33926f31 --- /dev/null +++ b/pytorch3d/implicitron/tools/image_utils.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Union + +import torch + + +def mask_background( + image_rgb: torch.Tensor, + mask_fg: torch.Tensor, + dim_color: int = 1, + bg_color: Union[torch.Tensor, str, float] = 0.0, +) -> torch.Tensor: + """ + Mask the background input image tensor `image_rgb` with `bg_color`. + The background regions are obtained from the binary foreground segmentation + mask `mask_fg`. + """ + tgt_view = [1, 1, 1, 1] + tgt_view[dim_color] = 3 + # obtain the background color tensor + if isinstance(bg_color, torch.Tensor): + bg_color_t = bg_color.view(1, 3, 1, 1).clone().to(image_rgb) + elif isinstance(bg_color, float): + bg_color_t = torch.tensor( + [bg_color] * 3, device=image_rgb.device, dtype=image_rgb.dtype + ).view(*tgt_view) + elif isinstance(bg_color, str): + if bg_color == "white": + bg_color_t = image_rgb.new_ones(tgt_view) + elif bg_color == "black": + bg_color_t = image_rgb.new_zeros(tgt_view) + else: + raise ValueError(_invalid_color_error_msg(bg_color)) + else: + raise ValueError(_invalid_color_error_msg(bg_color)) + # cast to the image_rgb's type + mask_fg = mask_fg.type_as(image_rgb) + # mask the bg + image_masked = mask_fg * image_rgb + (1 - mask_fg) * bg_color_t + return image_masked + + +def _invalid_color_error_msg(bg_color) -> str: + return ( + f"Invalid bg_color={bg_color}. Plese set bg_color to a 3-element" + + " tensor. or a string (white | black), or a float." + ) diff --git a/pytorch3d/implicitron/tools/metric_utils.py b/pytorch3d/implicitron/tools/metric_utils.py new file mode 100644 index 00000000..05433027 --- /dev/null +++ b/pytorch3d/implicitron/tools/metric_utils.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Optional, Tuple + +import torch +from torch.nn import functional as F + + +def eval_depth( + pred: torch.Tensor, + gt: torch.Tensor, + crop: int = 1, + mask: Optional[torch.Tensor] = None, + get_best_scale: bool = True, + mask_thr: float = 0.5, + best_scale_clamp_thr: float = 1e-4, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Evaluate the depth error between the prediction `pred` and the ground + truth `gt`. + + Args: + pred: A tensor of shape (N, 1, H, W) denoting the predicted depth maps. + gt: A tensor of shape (N, 1, H, W) denoting the ground truth depth maps. + crop: The number of pixels to crop from the border. + mask: A mask denoting the valid regions of the gt depth. + get_best_scale: If `True`, estimates a scaling factor of the predicted depth + that yields the best mean squared error between `pred` and `gt`. + This is typically enabled for cases where predicted reconstructions + are inherently defined up to an arbitrary scaling factor. + mask_thr: A constant used to threshold the `mask` to specify the valid + regions. + best_scale_clamp_thr: The threshold for clamping the divisor in best + scale estimation. + + Returns: + mse_depth: Mean squared error between `pred` and `gt`. + abs_depth: Mean absolute difference between `pred` and `gt`. + """ + + # chuck out the border + if crop > 0: + gt = gt[:, :, crop:-crop, crop:-crop] + pred = pred[:, :, crop:-crop, crop:-crop] + + if mask is not None: + # mult gt by mask + if crop > 0: + mask = mask[:, :, crop:-crop, crop:-crop] + gt = gt * (mask > mask_thr).float() + + dmask = (gt > 0.0).float() + dmask_mass = torch.clamp(dmask.sum((1, 2, 3)), 1e-4) + + if get_best_scale: + # mult preds by a scalar "scale_best" + # s.t. we get best possible mse error + scale_best = estimate_depth_scale_factor(pred, gt, dmask, best_scale_clamp_thr) + pred = pred * scale_best[:, None, None, None] + + df = gt - pred + + mse_depth = (dmask * (df ** 2)).sum((1, 2, 3)) / dmask_mass + abs_depth = (dmask * df.abs()).sum((1, 2, 3)) / dmask_mass + + return mse_depth, abs_depth + + +def estimate_depth_scale_factor(pred, gt, mask, clamp_thr): + xy = pred * gt * mask + xx = pred * pred * mask + scale_best = xy.mean((1, 2, 3)) / torch.clamp(xx.mean((1, 2, 3)), clamp_thr) + return scale_best + + +def calc_psnr( + x: torch.Tensor, + y: torch.Tensor, + mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Calculates the Peak-signal-to-noise ratio between tensors `x` and `y`. + """ + mse = calc_mse(x, y, mask=mask) + psnr = torch.log10(mse.clamp(1e-10)) * (-10.0) + return psnr + + +def calc_mse( + x: torch.Tensor, + y: torch.Tensor, + mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Calculates the mean square error between tensors `x` and `y`. + """ + if mask is None: + return torch.mean((x - y) ** 2) + else: + return (((x - y) ** 2) * mask).sum() / mask.expand_as(x).sum().clamp(1e-5) + + +def calc_bce( + pred: torch.Tensor, + gt: torch.Tensor, + equal_w: bool = True, + pred_eps: float = 0.01, + mask: Optional[torch.Tensor] = None, + lerp_bound: Optional[float] = None, +) -> torch.Tensor: + """ + Calculates the binary cross entropy. + """ + if pred_eps > 0.0: + # up/low bound the predictions + pred = torch.clamp(pred, pred_eps, 1.0 - pred_eps) + + if mask is None: + mask = torch.ones_like(gt) + + if equal_w: + mask_fg = (gt > 0.5).float() * mask + mask_bg = (1 - mask_fg) * mask + weight = mask_fg / mask_fg.sum().clamp(1.0) + mask_bg / mask_bg.sum().clamp(1.0) + # weight sum should be at this point ~2 + weight = weight * (weight.numel() / weight.sum().clamp(1.0)) + else: + weight = torch.ones_like(gt) * mask + + if lerp_bound is not None: + return binary_cross_entropy_lerp(pred, gt, weight, lerp_bound) + else: + return F.binary_cross_entropy(pred, gt, reduction="mean", weight=weight) + + +def binary_cross_entropy_lerp( + pred: torch.Tensor, + gt: torch.Tensor, + weight: torch.Tensor, + lerp_bound: float, +): + """ + Binary cross entropy which avoids exploding gradients by linearly + extrapolating the log function for log(1-pred) mad log(pred) whenever + pred or 1-pred is smaller than lerp_bound. + """ + loss = log_lerp(1 - pred, lerp_bound) * (1 - gt) + log_lerp(pred, lerp_bound) * gt + loss_reduced = -(loss * weight).sum() / weight.sum().clamp(1e-4) + return loss_reduced + + +def log_lerp(x: torch.Tensor, b: float): + """ + Linearly extrapolated log for x < b. + """ + assert b > 0 + return torch.where(x >= b, x.log(), math.log(b) + (x - b) / b) + + +def rgb_l1( + pred: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None +) -> torch.Tensor: + """ + Calculates the mean absolute error between the predicted colors `pred` + and ground truth colors `target`. + """ + if mask is None: + mask = torch.ones_like(pred[:, :1]) + return ((pred - target).abs() * mask).sum(dim=(1, 2, 3)) / mask.sum( + dim=(1, 2, 3) + ).clamp(1) + + +def huber(dfsq: torch.Tensor, scaling: float = 0.03) -> torch.Tensor: + """ + Calculates the huber function of the input squared error `dfsq`. + The function smoothly transitions from a region with unit gradient + to a hyperbolic function at `dfsq=scaling`. + """ + loss = (safe_sqrt(1 + dfsq / (scaling * scaling), eps=1e-4) - 1) * scaling + return loss + + +def neg_iou_loss( + predict: torch.Tensor, + target: torch.Tensor, + mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This is a great loss because it emphasizes on the active + regions of the predict and targets + """ + return 1.0 - iou(predict, target, mask=mask) + + +def safe_sqrt(A: torch.Tensor, eps: float = float(1e-4)) -> torch.Tensor: + """ + performs safe differentiable sqrt + """ + return (torch.clamp(A, float(0)) + eps).sqrt() + + +def iou( + predict: torch.Tensor, + target: torch.Tensor, + mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This is a great loss because it emphasizes on the active + regions of the predict and targets + """ + dims = tuple(range(predict.dim())[1:]) + if mask is not None: + predict = predict * mask + target = target * mask + intersect = (predict * target).sum(dims) + union = (predict + target - predict * target).sum(dims) + 1e-4 + return (intersect / union).sum() / intersect.numel() + + +def beta_prior(pred: torch.Tensor, cap: float = 0.1) -> torch.Tensor: + if cap <= 0.0: + raise ValueError("capping should be positive to avoid unbound loss") + + min_value = math.log(cap) + math.log(cap + 1.0) + return (torch.log(pred + cap) + torch.log(1.0 - pred + cap)).mean() - min_value diff --git a/pytorch3d/implicitron/tools/model_io.py b/pytorch3d/implicitron/tools/model_io.py new file mode 100644 index 00000000..d38272d1 --- /dev/null +++ b/pytorch3d/implicitron/tools/model_io.py @@ -0,0 +1,163 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import glob +import os +import shutil +import tempfile + +import torch + + +def load_stats(flstats): + from pytorch3d.implicitron.tools.stats import Stats + + try: + stats = Stats.load(flstats) + except: + print("Cant load stats! %s" % flstats) + stats = None + return stats + + +def get_model_path(fl) -> str: + fl = os.path.splitext(fl)[0] + flmodel = "%s.pth" % fl + return flmodel + + +def get_optimizer_path(fl) -> str: + fl = os.path.splitext(fl)[0] + flopt = "%s_opt.pth" % fl + return flopt + + +def get_stats_path(fl, eval_results: bool = False): + fl = os.path.splitext(fl)[0] + if eval_results: + for postfix in ("_2", ""): + flstats = os.path.join(os.path.dirname(fl), f"stats_test{postfix}.jgz") + if os.path.isfile(flstats): + break + else: + flstats = "%s_stats.jgz" % fl + # pyre-fixme[61]: `flstats` is undefined, or not always defined. + return flstats + + +def safe_save_model(model, stats, fl, optimizer=None, cfg=None) -> None: + """ + This functions stores model files safely so that no model files exist on the + file system in case the saving procedure gets interrupted. + + This is done first by saving the model files to a temporary directory followed + by (atomic) moves to the target location. Note, that this can still result + in a corrupt set of model files in case interruption happens while performing + the moves. It is however quite improbable that a crash would occur right at + this time. + """ + print(f"saving model files safely to {fl}") + # first store everything to a tmpdir + with tempfile.TemporaryDirectory() as tmpdir: + tmpfl = os.path.join(tmpdir, os.path.split(fl)[-1]) + stored_tmp_fls = save_model(model, stats, tmpfl, optimizer=optimizer, cfg=cfg) + tgt_fls = [ + ( + os.path.join(os.path.split(fl)[0], os.path.split(tmpfl)[-1]) + if (tmpfl is not None) + else None + ) + for tmpfl in stored_tmp_fls + ] + # then move from the tmpdir to the right location + for tmpfl, tgt_fl in zip(stored_tmp_fls, tgt_fls): + if tgt_fl is None: + continue + # print(f'Moving {tmpfl} --> {tgt_fl}\n') + shutil.move(tmpfl, tgt_fl) + + +def save_model(model, stats, fl, optimizer=None, cfg=None): + flstats = get_stats_path(fl) + flmodel = get_model_path(fl) + print("saving model to %s" % flmodel) + torch.save(model.state_dict(), flmodel) + flopt = None + if optimizer is not None: + flopt = get_optimizer_path(fl) + print("saving optimizer to %s" % flopt) + torch.save(optimizer.state_dict(), flopt) + print("saving model stats to %s" % flstats) + stats.save(flstats) + + return flstats, flmodel, flopt + + +def load_model(fl): + flstats = get_stats_path(fl) + flmodel = get_model_path(fl) + flopt = get_optimizer_path(fl) + model_state_dict = torch.load(flmodel) + stats = load_stats(flstats) + if os.path.isfile(flopt): + optimizer = torch.load(flopt) + else: + optimizer = None + + return model_state_dict, stats, optimizer + + +def parse_epoch_from_model_path(model_path) -> int: + return int( + os.path.split(model_path)[-1].replace(".pth", "").replace("model_epoch_", "") + ) + + +def get_checkpoint(exp_dir, epoch): + fl = os.path.join(exp_dir, "model_epoch_%08d.pth" % epoch) + return fl + + +def find_last_checkpoint( + exp_dir, any_path: bool = False, all_checkpoints: bool = False +): + if any_path: + exts = [".pth", "_stats.jgz", "_opt.pth"] + else: + exts = [".pth"] + + for ext in exts: + fls = sorted( + glob.glob( + os.path.join(glob.escape(exp_dir), "model_epoch_" + "[0-9]" * 8 + ext) + ) + ) + if len(fls) > 0: + break + # pyre-fixme[61]: `fls` is undefined, or not always defined. + if len(fls) == 0: + fl = None + else: + if all_checkpoints: + # pyre-fixme[61]: `fls` is undefined, or not always defined. + fl = [f[0 : -len(ext)] + ".pth" for f in fls] + else: + fl = fls[-1][0 : -len(ext)] + ".pth" + + return fl + + +def purge_epoch(exp_dir, epoch) -> None: + model_path = get_checkpoint(exp_dir, epoch) + + for file_path in [ + model_path, + get_optimizer_path(model_path), + get_stats_path(model_path), + ]: + if os.path.isfile(file_path): + print("deleting %s" % file_path) + os.remove(file_path) diff --git a/pytorch3d/implicitron/tools/point_cloud_utils.py b/pytorch3d/implicitron/tools/point_cloud_utils.py new file mode 100644 index 00000000..0a051a13 --- /dev/null +++ b/pytorch3d/implicitron/tools/point_cloud_utils.py @@ -0,0 +1,168 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Optional, Tuple, cast + +import torch +import torch.nn.functional as Fu +from pytorch3d.renderer import ( + AlphaCompositor, + NDCMultinomialRaysampler, + PointsRasterizationSettings, + PointsRasterizer, + ray_bundle_to_ray_points, +) +from pytorch3d.renderer.cameras import CamerasBase +from pytorch3d.structures import Pointclouds + + +def get_rgbd_point_cloud( + camera: CamerasBase, + image_rgb: torch.Tensor, + depth_map: torch.Tensor, + mask: Optional[torch.Tensor] = None, + mask_thr: float = 0.5, + mask_points: bool = True, +) -> Pointclouds: + """ + Given a batch of images, depths, masks and cameras, generate a colored + point cloud by unprojecting depth maps to the and coloring with the source + pixel colors. + """ + imh, imw = image_rgb.shape[2:] + + # convert the depth maps to point clouds using the grid ray sampler + pts_3d = ray_bundle_to_ray_points( + NDCMultinomialRaysampler( + image_width=imw, + image_height=imh, + n_pts_per_ray=1, + min_depth=1.0, + max_depth=1.0, + )(camera)._replace(lengths=depth_map[:, 0, ..., None]) + ) + + pts_mask = depth_map > 0.0 + if mask is not None: + pts_mask *= mask > mask_thr + pts_mask = pts_mask.reshape(-1) + + pts_3d = pts_3d.reshape(-1, 3)[pts_mask] + + pts_colors = torch.nn.functional.interpolate( + image_rgb, + # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got + # `List[typing.Any]`. + size=[imh, imw], + mode="bilinear", + align_corners=False, + ) + pts_colors = pts_colors.permute(0, 2, 3, 1).reshape(-1, 3)[pts_mask] + + return Pointclouds(points=pts_3d[None], features=pts_colors[None]) + + +def render_point_cloud_pytorch3d( + camera, + point_cloud, + render_size: Tuple[int, int], + point_radius: float = 0.03, + topk: int = 10, + eps: float = 1e-2, + bg_color=None, + **kwargs +): + + # feature dimension + featdim = point_cloud.features_packed().shape[-1] + + # move to the camera coordinates; using identity cameras in the renderer + point_cloud = _transform_points(camera, point_cloud, eps, **kwargs) + camera_trivial = camera.clone() + camera_trivial.R[:] = torch.eye(3) + camera_trivial.T *= 0.0 + + rasterizer = PointsRasterizer( + cameras=camera_trivial, + raster_settings=PointsRasterizationSettings( + image_size=render_size, + radius=point_radius, + points_per_pixel=topk, + bin_size=64 if int(max(render_size)) > 1024 else None, + ), + ) + + fragments = rasterizer(point_cloud, **kwargs) + + # Construct weights based on the distance of a point to the true point. + # However, this could be done differently: e.g. predicted as opposed + # to a function of the weights. + r = rasterizer.raster_settings.radius + + # set up the blending weights + dists2 = fragments.dists + weights = 1 - dists2 / (r * r) + ok = cast(torch.BoolTensor, (fragments.idx >= 0)).float() + + weights = weights * ok + + fragments_prm = fragments.idx.long().permute(0, 3, 1, 2) + weights_prm = weights.permute(0, 3, 1, 2) + images = AlphaCompositor()( + fragments_prm, + weights_prm, + point_cloud.features_packed().permute(1, 0), + background_color=bg_color if bg_color is not None else [0.0] * featdim, + **kwargs, + ) + + # get the depths ... + # weighted_fs[b,c,i,j] = sum_k cum_alpha_k * features[c,pointsidx[b,k,i,j]] + # cum_alpha_k = alphas[b,k,i,j] * prod_l=0..k-1 (1 - alphas[b,l,i,j]) + cumprod = torch.cumprod(1 - weights, dim=-1) + cumprod = torch.cat((torch.ones_like(cumprod[..., :1]), cumprod[..., :-1]), dim=-1) + depths = (weights * cumprod * fragments.zbuf).sum(dim=-1) + # add the rendering mask + render_mask = -torch.prod(1.0 - weights, dim=-1) + 1.0 + + # cat depths and render mask + rendered_blob = torch.cat((images, depths[:, None], render_mask[:, None]), dim=1) + + # reshape back + rendered_blob = Fu.interpolate( + rendered_blob, + # pyre-fixme[6]: Expected `Optional[int]` for 2nd param but got `Tuple[int, + # ...]`. + size=tuple(render_size), + mode="bilinear", + ) + + data_rendered, depth_rendered, render_mask = rendered_blob.split( + [rendered_blob.shape[1] - 2, 1, 1], + dim=1, + ) + + return data_rendered, render_mask, depth_rendered + + +def _signed_clamp(x, eps): + sign = x.sign() + (x == 0.0).type_as(x) + x_clamp = sign * torch.clamp(x.abs(), eps) + return x_clamp + + +def _transform_points(cameras, point_clouds, eps, **kwargs): + pts_world = point_clouds.points_padded() + pts_view = cameras.get_world_to_view_transform(**kwargs).transform_points( + pts_world, eps=eps + ) + # it is crucial to actually clamp the points as well ... + pts_view = torch.cat( + (pts_view[..., :-1], _signed_clamp(pts_view[..., -1:], eps)), dim=-1 + ) + point_clouds = point_clouds.update_padded(pts_view) + return point_clouds diff --git a/pytorch3d/implicitron/tools/rasterize_mc.py b/pytorch3d/implicitron/tools/rasterize_mc.py new file mode 100644 index 00000000..20570a30 --- /dev/null +++ b/pytorch3d/implicitron/tools/rasterize_mc.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Optional, Tuple + +import torch +from pytorch3d.renderer import PerspectiveCameras +from pytorch3d.structures import Pointclouds + +from .point_cloud_utils import render_point_cloud_pytorch3d + + +def rasterize_mc_samples( + xys: torch.Tensor, + feats: torch.Tensor, + image_size_hw: Tuple[int, int], + radius: float = 0.03, + topk: int = 5, + masks: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Rasterizes Monte-Carlo sampled features back onto the image. + + Specifically, the code uses the PyTorch3D point rasterizer to render + a z-flat point cloud composed of the xy MC locations and their features. + + Args: + xys: B x N x 2 2D point locations in PyTorch3D NDC convention + feats: B x N x dim tensor containing per-point rendered features. + image_size_hw: Tuple[image_height, image_width] containing + the size of rasterized image. + radius: Rasterization point radius. + topk: The maximum z-buffer size for the PyTorch3D point cloud rasterizer. + masks: B x N x 1 tensor containing the alpha mask of the + rendered features. + """ + + if masks is None: + masks = torch.ones_like(xys[..., :1]) + + feats = torch.cat((feats, masks), dim=-1) + pointclouds = Pointclouds( + points=torch.cat([xys, torch.ones_like(xys[..., :1])], dim=-1), + features=feats, + ) + + data_rendered, render_mask, _ = render_point_cloud_pytorch3d( + PerspectiveCameras(device=feats.device), + pointclouds, + render_size=image_size_hw, + point_radius=radius, + topk=topk, + ) + + data_rendered, masks_pt = data_rendered.split( + [data_rendered.shape[1] - 1, 1], dim=1 + ) + render_mask = masks_pt * render_mask + + return data_rendered, render_mask diff --git a/pytorch3d/implicitron/tools/stats.py b/pytorch3d/implicitron/tools/stats.py new file mode 100644 index 00000000..16eebe9f --- /dev/null +++ b/pytorch3d/implicitron/tools/stats.py @@ -0,0 +1,491 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import gzip +import json +import time +import warnings +from collections.abc import Iterable +from itertools import cycle + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +from matplotlib import colors as mcolors +from pytorch3d.implicitron.tools.vis_utils import get_visdom_connection + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.history = [] + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1, epoch=0): + + # make sure the history is of the same len as epoch + while len(self.history) <= epoch: + self.history.append([]) + + self.history[epoch].append(val / n) + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def get_epoch_averages(self, epoch=-1): + if len(self.history) == 0: # no stats here + return None + elif epoch == -1: + return [ + (float(np.array(x).mean()) if len(x) > 0 else float("NaN")) + for x in self.history + ] + else: + return float(np.array(self.history[epoch]).mean()) + + def get_all_values(self): + all_vals = [np.array(x) for x in self.history] + all_vals = np.concatenate(all_vals) + return all_vals + + def get_epoch(self): + return len(self.history) + + @staticmethod + def from_json_str(json_str): + self = AverageMeter() + self.__dict__.update(json.loads(json_str)) + return self + + +class Stats(object): + # TODO: update this with context manager + """ + stats logging object useful for gathering statistics of training a deep net in pytorch + Example: + # init stats structure that logs statistics 'objective' and 'top1e' + stats = Stats( ('objective','top1e') ) + network = init_net() # init a pytorch module (=nueral network) + dataloader = init_dataloader() # init a dataloader + for epoch in range(10): + # start of epoch -> call new_epoch + stats.new_epoch() + + # iterate over batches + for batch in dataloader: + + output = network(batch) # run and save into a dict of output variables "output" + + # stats.update() automatically parses the 'objective' and 'top1e' from + # the "output" dict and stores this into the db + stats.update(output) + stats.print() # prints the averages over given epoch + # stores the training plots into '/tmp/epoch_stats.pdf' + # and plots into a visdom server running at localhost (if running) + stats.plot_stats(plot_file='/tmp/epoch_stats.pdf') + """ + + def __init__( + self, + log_vars, + verbose=False, + epoch=-1, + visdom_env="main", + do_plot=True, + plot_file=None, + visdom_server="http://localhost", + visdom_port=8097, + ): + + self.verbose = verbose + self.log_vars = log_vars + self.visdom_env = visdom_env + self.visdom_server = visdom_server + self.visdom_port = visdom_port + self.plot_file = plot_file + self.do_plot = do_plot + self.hard_reset(epoch=epoch) + + @staticmethod + def from_json_str(json_str): + self = Stats([]) + # load the global state + self.__dict__.update(json.loads(json_str)) + # recover the AverageMeters + for stat_set in self.stats: + self.stats[stat_set] = { + log_var: AverageMeter.from_json_str(log_vals_json_str) + for log_var, log_vals_json_str in self.stats[stat_set].items() + } + return self + + @staticmethod + def load(flpath, postfix=".jgz"): + flpath = _get_postfixed_filename(flpath, postfix) + with gzip.open(flpath, "r") as fin: + data = json.loads(fin.read().decode("utf-8")) + return Stats.from_json_str(data) + + def save(self, flpath, postfix=".jgz"): + flpath = _get_postfixed_filename(flpath, postfix) + # store into a gzipped-json + with gzip.open(flpath, "w") as fout: + fout.write(json.dumps(self, cls=StatsJSONEncoder).encode("utf-8")) + + # some sugar to be used with "with stats:" at the beginning of the epoch + def __enter__(self): + if self.do_plot and self.epoch >= 0: + self.plot_stats(self.visdom_env) + self.new_epoch() + + def __exit__(self, type, value, traceback): + iserr = type is not None and issubclass(type, Exception) + iserr = iserr or (type is KeyboardInterrupt) + if iserr: + print("error inside 'with' block") + return + if self.do_plot: + self.plot_stats(self.visdom_env) + + def reset(self): # to be called after each epoch + stat_sets = list(self.stats.keys()) + if self.verbose: + print("stats: epoch %d - reset" % self.epoch) + self.it = {k: -1 for k in stat_sets} + for stat_set in stat_sets: + for stat in self.stats[stat_set]: + self.stats[stat_set][stat].reset() + + def hard_reset(self, epoch=-1): # to be called during object __init__ + self.epoch = epoch + if self.verbose: + print("stats: epoch %d - hard reset" % self.epoch) + self.stats = {} + + # reset + self.reset() + + def new_epoch(self): + if self.verbose: + print("stats: new epoch %d" % (self.epoch + 1)) + self.epoch += 1 + self.reset() # zero the stats + increase epoch counter + + def gather_value(self, val): + if isinstance(val, (float, int)): + val = float(val) + else: + val = val.data.cpu().numpy() + val = float(val.sum()) + return val + + def add_log_vars(self, added_log_vars, verbose=True): + for add_log_var in added_log_vars: + if add_log_var not in self.stats: + if verbose: + print(f"Adding {add_log_var}") + self.log_vars.append(add_log_var) + # self.synchronize_logged_vars(self.log_vars, verbose=verbose) + + def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"): + + if self.epoch == -1: # uninitialized + print( + "warning: epoch==-1 means uninitialized stats structure -> new_epoch() called" + ) + self.new_epoch() + + if stat_set not in self.stats: + self.stats[stat_set] = {} + self.it[stat_set] = -1 + + if not freeze_iter: + self.it[stat_set] += 1 + + epoch = self.epoch + it = self.it[stat_set] + + for stat in self.log_vars: + + if stat not in self.stats[stat_set]: + self.stats[stat_set][stat] = AverageMeter() + + if stat == "sec/it": # compute speed + if time_start is None: + elapsed = 0.0 + else: + elapsed = time.time() - time_start + time_per_it = float(elapsed) / float(it + 1) + val = time_per_it + # self.stats[stat_set]['sec/it'].update(time_per_it,epoch=epoch,n=1) + else: + if stat in preds: + try: + val = self.gather_value(preds[stat]) + except KeyError: + raise ValueError( + "could not extract prediction %s\ + from the prediction dictionary" + % stat + ) + else: + val = None + + if val is not None: + self.stats[stat_set][stat].update(val, epoch=epoch, n=1) + + def get_epoch_averages(self, epoch=None): + + stat_sets = list(self.stats.keys()) + + if epoch is None: + epoch = self.epoch + if epoch == -1: + epoch = list(range(self.epoch)) + + outvals = {} + for stat_set in stat_sets: + outvals[stat_set] = { + "epoch": epoch, + "it": self.it[stat_set], + "epoch_max": self.epoch, + } + + for stat in self.stats[stat_set].keys(): + if self.stats[stat_set][stat].count == 0: + continue + if isinstance(epoch, Iterable): + avgs = self.stats[stat_set][stat].get_epoch_averages() + avgs = [avgs[e] for e in epoch] + else: + avgs = self.stats[stat_set][stat].get_epoch_averages(epoch=epoch) + outvals[stat_set][stat] = avgs + + return outvals + + def print( + self, + max_it=None, + stat_set="train", + vars_print=None, + get_str=False, + skip_nan=False, + stat_format=lambda s: s.replace("loss_", "").replace("prev_stage_", "ps_"), + ): + + epoch = self.epoch + stats = self.stats + + str_out = "" + + it = self.it[stat_set] + stat_str = "" + stats_print = sorted(stats[stat_set].keys()) + for stat in stats_print: + if stats[stat_set][stat].count == 0: + continue + if skip_nan and not np.isfinite(stats[stat_set][stat].avg): + continue + stat_str += " {0:.12}: {1:1.3f} |".format( + stat_format(stat), stats[stat_set][stat].avg + ) + + head_str = "[%s] | epoch %3d | it %5d" % (stat_set, epoch, it) + if max_it: + head_str += "/ %d" % max_it + + str_out = "%s | %s" % (head_str, stat_str) + + if get_str: + return str_out + else: + print(str_out) + + def plot_stats( + self, visdom_env=None, plot_file=None, visdom_server=None, visdom_port=None + ): + + # use the cached visdom env if none supplied + if visdom_env is None: + visdom_env = self.visdom_env + if visdom_server is None: + visdom_server = self.visdom_server + if visdom_port is None: + visdom_port = self.visdom_port + if plot_file is None: + plot_file = self.plot_file + + stat_sets = list(self.stats.keys()) + + print( + "printing charts to visdom env '%s' (%s:%d)" + % (visdom_env, visdom_server, visdom_port) + ) + + novisdom = False + + viz = get_visdom_connection(server=visdom_server, port=visdom_port) + if not viz.check_connection(): + print("no visdom server! -> skipping visdom plots") + novisdom = True + + lines = [] + + # plot metrics + if not novisdom: + viz.close(env=visdom_env, win=None) + + for stat in self.log_vars: + vals = [] + stat_sets_now = [] + for stat_set in stat_sets: + val = self.stats[stat_set][stat].get_epoch_averages() + if val is None: + continue + else: + val = np.array(val).reshape(-1) + stat_sets_now.append(stat_set) + vals.append(val) + + if len(vals) == 0: + continue + + lines.append((stat_sets_now, stat, vals)) + + if not novisdom: + for tmodes, stat, vals in lines: + title = "%s" % stat + opts = {"title": title, "legend": list(tmodes)} + for i, (tmode, val) in enumerate(zip(tmodes, vals)): + update = "append" if i > 0 else None + valid = np.where(np.isfinite(val))[0] + if len(valid) == 0: + continue + x = np.arange(len(val)) + viz.line( + Y=val[valid], + X=x[valid], + env=visdom_env, + opts=opts, + win=f"stat_plot_{title}", + name=tmode, + update=update, + ) + + if plot_file: + print("exporting stats to %s" % plot_file) + ncol = 3 + nrow = int(np.ceil(float(len(lines)) / ncol)) + matplotlib.rcParams.update({"font.size": 5}) + color = cycle(plt.cm.tab10(np.linspace(0, 1, 10))) + fig = plt.figure(1) + plt.clf() + for idx, (tmodes, stat, vals) in enumerate(lines): + c = next(color) + plt.subplot(nrow, ncol, idx + 1) + plt.gca() + for vali, vals_ in enumerate(vals): + c_ = c * (1.0 - float(vali) * 0.3) + valid = np.where(np.isfinite(vals_))[0] + if len(valid) == 0: + continue + x = np.arange(len(vals_)) + plt.plot(x[valid], vals_[valid], c=c_, linewidth=1) + plt.ylabel(stat) + plt.xlabel("epoch") + plt.gca().yaxis.label.set_color(c[0:3] * 0.75) + plt.legend(tmodes) + gcolor = np.array(mcolors.to_rgba("lightgray")) + plt.grid( + b=True, which="major", color=gcolor, linestyle="-", linewidth=0.4 + ) + plt.grid( + b=True, which="minor", color=gcolor, linestyle="--", linewidth=0.2 + ) + plt.minorticks_on() + + plt.tight_layout() + plt.show() + try: + fig.savefig(plot_file) + except PermissionError: + warnings.warn("Cant dump stats due to insufficient permissions!") + + def synchronize_logged_vars(self, log_vars, default_val=float("NaN"), verbose=True): + + stat_sets = list(self.stats.keys()) + + # remove the additional log_vars + for stat_set in stat_sets: + for stat in self.stats[stat_set].keys(): + if stat not in log_vars: + print("additional stat %s:%s -> removing" % (stat_set, stat)) + + self.stats[stat_set] = { + stat: v for stat, v in self.stats[stat_set].items() if stat in log_vars + } + + self.log_vars = log_vars # !!! + + for stat_set in stat_sets: + reference_stat = list(self.stats[stat_set].keys())[0] + for stat in log_vars: + if stat not in self.stats[stat_set]: + if verbose: + print( + "missing stat %s:%s -> filling with default values (%1.2f)" + % (stat_set, stat, default_val) + ) + elif len(self.stats[stat_set][stat].history) != self.epoch + 1: + h = self.stats[stat_set][stat].history + if len(h) == 0: # just never updated stat ... skip + continue + else: + if verbose: + print( + "incomplete stat %s:%s -> reseting with default values (%1.2f)" + % (stat_set, stat, default_val) + ) + else: + continue + + self.stats[stat_set][stat] = AverageMeter() + self.stats[stat_set][stat].reset() + + lastep = self.epoch + 1 + for ep in range(lastep): + self.stats[stat_set][stat].update(default_val, n=1, epoch=ep) + epoch_self = self.stats[stat_set][reference_stat].get_epoch() + epoch_generated = self.stats[stat_set][stat].get_epoch() + assert ( + epoch_self == epoch_generated + ), "bad epoch of synchronized log_var! %d vs %d" % ( + epoch_self, + epoch_generated, + ) + + +class StatsJSONEncoder(json.JSONEncoder): + def default(self, o): + if isinstance(o, (AverageMeter, Stats)): + enc = self.encode(o.__dict__) + return enc + else: + raise TypeError( + f"Object of type {o.__class__.__name__} " f"is not JSON serializable" + ) + + +def _get_postfixed_filename(fl, postfix): + return fl if fl.endswith(postfix) else fl + postfix diff --git a/pytorch3d/implicitron/tools/utils.py b/pytorch3d/implicitron/tools/utils.py new file mode 100644 index 00000000..5e70c1c5 --- /dev/null +++ b/pytorch3d/implicitron/tools/utils.py @@ -0,0 +1,183 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import collections +import dataclasses +import time +from contextlib import contextmanager +from typing import Any, Callable, Dict + +import torch + + +@contextmanager +def evaluating(net: torch.nn.Module): + """Temporarily switch to evaluation mode.""" + istrain = net.training + try: + net.eval() + yield net + finally: + if istrain: + net.train() + + +def try_to_cuda(t: Any) -> Any: + """ + Try to move the input variable `t` to a cuda device. + + Args: + t: Input. + + Returns: + t_cuda: `t` moved to a cuda device, if supported. + """ + try: + t = t.cuda() + except AttributeError: + pass + return t + + +def try_to_cpu(t: Any) -> Any: + """ + Try to move the input variable `t` to a cpu device. + + Args: + t: Input. + + Returns: + t_cpu: `t` moved to a cpu device, if supported. + """ + try: + t = t.cpu() + except AttributeError: + pass + return t + + +def dict_to_cuda(batch: Dict[Any, Any]) -> Dict[Any, Any]: + """ + Move all values in a dictionary to cuda if supported. + + Args: + batch: Input dict. + + Returns: + batch_cuda: `batch` moved to a cuda device, if supported. + """ + return {k: try_to_cuda(v) for k, v in batch.items()} + + +def dict_to_cpu(batch): + """ + Move all values in a dictionary to cpu if supported. + + Args: + batch: Input dict. + + Returns: + batch_cpu: `batch` moved to a cpu device, if supported. + """ + return {k: try_to_cpu(v) for k, v in batch.items()} + + +def dataclass_to_cuda_(obj): + """ + Move all contents of a dataclass to cuda inplace if supported. + + Args: + batch: Input dataclass. + + Returns: + batch_cuda: `batch` moved to a cuda device, if supported. + """ + for f in dataclasses.fields(obj): + setattr(obj, f.name, try_to_cuda(getattr(obj, f.name))) + return obj + + +def dataclass_to_cpu_(obj): + """ + Move all contents of a dataclass to cpu inplace if supported. + + Args: + batch: Input dataclass. + + Returns: + batch_cuda: `batch` moved to a cpu device, if supported. + """ + for f in dataclasses.fields(obj): + setattr(obj, f.name, try_to_cpu(getattr(obj, f.name))) + return obj + + +# TODO: test it +def cat_dataclass(batch, tensor_collator: Callable): + """ + Concatenate all fields of a list of dataclasses `batch` to a single + dataclass object using `tensor_collator`. + + Args: + batch: Input list of dataclasses. + + Returns: + concatenated_batch: All elements of `batch` concatenated to a single + dataclass object. + tensor_collator: The function used to concatenate tensor fields. + """ + + elem = batch[0] + collated = {} + + for f in dataclasses.fields(elem): + elem_f = getattr(elem, f.name) + if elem_f is None: + collated[f.name] = None + elif torch.is_tensor(elem_f): + collated[f.name] = tensor_collator([getattr(e, f.name) for e in batch]) + elif dataclasses.is_dataclass(elem_f): + collated[f.name] = cat_dataclass( + [getattr(e, f.name) for e in batch], tensor_collator + ) + elif isinstance(elem_f, collections.abc.Mapping): + collated[f.name] = { + k: tensor_collator([getattr(e, f.name)[k] for e in batch]) + if elem_f[k] is not None + else None + for k in elem_f + } + else: + raise ValueError("Unsupported field type for concatenation") + + return type(elem)(**collated) + + +class Timer: + """ + A simple class for timing execution. + + Example: + ``` + with Timer(): + print("This print statement is timed.") + ``` + """ + + def __init__(self, name="timer", quiet=False): + self.name = name + self.quiet = quiet + + def __enter__(self): + self.start = time.time() + return self + + def __exit__(self, *args): + self.end = time.time() + self.interval = self.end - self.start + if not self.quiet: + print("%20s: %1.6f sec" % (self.name, self.interval)) diff --git a/pytorch3d/implicitron/tools/video_writer.py b/pytorch3d/implicitron/tools/video_writer.py new file mode 100644 index 00000000..364dabee --- /dev/null +++ b/pytorch3d/implicitron/tools/video_writer.py @@ -0,0 +1,149 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import shutil +import tempfile +import warnings +from typing import Optional, Tuple, Union + +import matplotlib +import matplotlib.pyplot as plt +import numpy as np +from PIL import Image + + +matplotlib.use("Agg") + + +class VideoWriter: + """ + A class for exporting videos. + """ + + def __init__( + self, + cache_dir: Optional[str] = None, + ffmpeg_bin: str = "ffmpeg", + out_path: str = "/tmp/video.mp4", + fps: int = 20, + output_format: str = "visdom", + rmdir_allowed: bool = False, + **kwargs, + ): + """ + Args: + cache_dir: A directory for storing the video frames. If `None`, + a temporary directory will be used. + ffmpeg_bin: The path to an `ffmpeg` executable. + out_path: The path to the output video. + fps: The speed of the generated video in frames-per-second. + output_format: Format of the output video. Currently only `"visdom"` + is supported. + rmdir_allowed: If `True` delete and create `cache_dir` in case + it is not empty. + """ + self.rmdir_allowed = rmdir_allowed + self.output_format = output_format + self.fps = fps + self.out_path = out_path + self.cache_dir = cache_dir + self.ffmpeg_bin = ffmpeg_bin + self.frames = [] + self.regexp = "frame_%08d.png" + self.frame_num = 0 + + if self.cache_dir is not None: + self.tmp_dir = None + if os.path.isdir(self.cache_dir): + if rmdir_allowed: + shutil.rmtree(self.cache_dir) + else: + warnings.warn( + f"Warning: cache directory not empty ({self.cache_dir})." + ) + os.makedirs(self.cache_dir, exist_ok=True) + else: + self.tmp_dir = tempfile.TemporaryDirectory() + self.cache_dir = self.tmp_dir.name + + def write_frame( + self, + frame: Union[matplotlib.figure.Figure, np.ndarray, Image.Image, str], + resize: Optional[Union[float, Tuple[int, int]]] = None, + ): + """ + Write a frame to the video. + + Args: + frame: An object containing the frame image. + resize: Either a floating defining the image rescaling factor + or a 2-tuple defining the size of the output image. + """ + + outfile = os.path.join(self.cache_dir, self.regexp % self.frame_num) + + if isinstance(frame, matplotlib.figure.Figure): + plt.savefig(outfile) + im = Image.open(outfile) + elif isinstance(frame, np.ndarray): + if frame.dtype in (np.float64, np.float32, float): + frame = (np.transpose(frame, (1, 2, 0)) * 255.0).astype(np.uint8) + im = Image.fromarray(frame) + elif isinstance(frame, Image.Image): + im = frame + elif isinstance(frame, str): + im = Image.open(frame).convert("RGB") + else: + raise ValueError("Cant convert type %s" % str(type(frame))) + + if im is not None: + if resize is not None: + if isinstance(resize, float): + resize = [int(resize * s) for s in im.size] + else: + resize = im.size + # make sure size is divisible by 2 + resize = tuple([resize[i] + resize[i] % 2 for i in (0, 1)]) + im = im.resize(resize, Image.ANTIALIAS) + im.save(outfile) + + self.frames.append(outfile) + self.frame_num += 1 + + def get_video(self, quiet: bool = True): + """ + Generate the video from the written frames. + + Args: + quiet: If `True`, suppresses logging messages. + + Returns: + video_path: The path to the generated video. + """ + + regexp = os.path.join(self.cache_dir, self.regexp) + + if self.output_format == "visdom": # works for ppt too + ffmcmd_ = ( + "%s -r %d -i %s -vcodec h264 -f mp4 \ + -y -crf 18 -b 2000k -pix_fmt yuv420p '%s'" + % (self.ffmpeg_bin, self.fps, regexp, self.out_path) + ) + else: + raise ValueError("no such output type %s" % str(self.output_format)) + + if quiet: + ffmcmd_ += " > /dev/null 2>&1" + else: + print(ffmcmd_) + os.system(ffmcmd_) + + return self.out_path + + def __del__(self): + if self.tmp_dir is not None: + self.tmp_dir.cleanup() diff --git a/pytorch3d/implicitron/tools/vis_utils.py b/pytorch3d/implicitron/tools/vis_utils.py new file mode 100644 index 00000000..28990327 --- /dev/null +++ b/pytorch3d/implicitron/tools/vis_utils.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Dict, List + +import torch +from visdom import Visdom + + +def get_visdom_env(cfg): + """ + Parse out visdom environment name from the input config. + + Args: + cfg: The global config file. + + Returns: + visdom_env: The name of the visdom environment. + """ + if len(cfg.visdom_env) == 0: + visdom_env = cfg.exp_dir.split("/")[-1] + else: + visdom_env = cfg.visdom_env + return visdom_env + + +# TODO: a proper singleton +_viz_singleton = None + + +def get_visdom_connection( + server: str = "http://localhost", + port: int = 8097, +) -> Visdom: + """ + Obtain a connection to a visdom server. + + Args: + server: Server address. + port: Server port. + + Returns: + connection: The connection object. + """ + global _viz_singleton + if _viz_singleton is None: + _viz_singleton = Visdom(server=server, port=port) + return _viz_singleton + + +def visualize_basics( + viz: Visdom, + preds: Dict[str, Any], + visdom_env_imgs: str, + title: str = "", + visualize_preds_keys: List[str] = [ + "image_rgb", + "images_render", + "fg_probability", + "masks_render", + "depths_render", + "depth_map", + ], + store_history: bool = False, +) -> None: + """ + Visualize basic outputs of a `GenericModel` to visdom. + + Args: + viz: The visdom object. + preds: A dictionary containing `GenericModel` outputs. + visdom_env_imgs: Target visdom environment name. + title: The title of produced visdom window. + visualize_preds_keys: The list of keys of `preds` for visualization. + store_history: Store the history buffer in visdom windows. + """ + imout = {} + for k in visualize_preds_keys: + if k not in preds or preds[k] is None: + print(f"cant show {k}") + continue + v = preds[k].cpu().detach().clone() + if k.startswith("depth"): + # divide by 95th percentile + normfac = ( + v.view(v.shape[0], -1) + .topk(k=int(0.05 * (v.numel() // v.shape[0])), dim=-1) + .values[:, -1] + ) + v = v / normfac[:, None, None, None].clamp(1e-4) + if v.shape[1] == 1: + v = v.repeat(1, 3, 1, 1) + v = torch.nn.functional.interpolate( + v, + # pyre-fixme[6]: Expected `Optional[typing.List[float]]` for 2nd param + # but got `float`. + scale_factor=( + 600.0 + if ( + "_eval" in visdom_env_imgs + and k in ("images_render", "depths_render") + ) + else 200.0 + ) + / v.shape[2], + mode="bilinear", + ) + imout[k] = v + + # TODO: handle errors on the outside + try: + imout = {"all": torch.cat(list(imout.values()), dim=2)} + except: + print("cant cat!") + + for k, v in imout.items(): + viz.images( + v.clamp(0.0, 1.0), + win=k, + env=visdom_env_imgs, + opts={"title": title + "_" + k, "store_history": store_history}, + ) + + +def make_depth_image( + depths: torch.Tensor, + masks: torch.Tensor, + max_quantile: float = 0.98, + min_quantile: float = 0.02, + min_out_depth: float = 0.1, + max_out_depth: float = 0.9, +) -> torch.Tensor: + """ + Convert a batch of depth maps to a grayscale image. + + Args: + depths: A tensor of shape `(B, 1, H, W)` containing a batch of depth maps. + masks: A tensor of shape `(B, 1, H, W)` containing a batch of foreground masks. + max_quantile: The quantile of the input depth values which will + be mapped to `max_out_depth`. + min_quantile: The quantile of the input depth values which will + be mapped to `min_out_depth`. + min_out_depth: The minimal value in each depth map will be assigned this color. + max_out_depth: The maximal value in each depth map will be assigned this color. + + Returns: + depth_image: A tensor of shape `(B, 1, H, W)` a batch of grayscale + depth images. + """ + normfacs = [] + for d, m in zip(depths, masks): + ok = (d.view(-1) > 1e-6) * (m.view(-1) > 0.5) + if ok.sum() <= 1: + print("empty depth!") + normfacs.append(torch.zeros(2).type_as(depths)) + continue + dok = d.view(-1)[ok].view(-1) + _maxk = max(int(round((1 - max_quantile) * (dok.numel()))), 1) + _mink = max(int(round(min_quantile * (dok.numel()))), 1) + normfac_max = dok.topk(k=_maxk, dim=-1).values[-1] + normfac_min = dok.topk(k=_mink, dim=-1, largest=False).values[-1] + normfacs.append(torch.stack([normfac_min, normfac_max])) + normfacs = torch.stack(normfacs) + _min, _max = (normfacs[:, 0].view(-1, 1, 1, 1), normfacs[:, 1].view(-1, 1, 1, 1)) + depths = (depths - _min) / (_max - _min).clamp(1e-4) + depths = ( + (depths * (max_out_depth - min_out_depth) + min_out_depth) * masks.float() + ).clamp(0.0, 1.0) + return depths diff --git a/tests/implicitron/__init__.py b/tests/implicitron/__init__.py new file mode 100644 index 00000000..2e41cd71 --- /dev/null +++ b/tests/implicitron/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/implicitron/common_resources.py b/tests/implicitron/common_resources.py new file mode 100644 index 00000000..2c2620c1 --- /dev/null +++ b/tests/implicitron/common_resources.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import logging +import os +import tempfile +import unittest +from pathlib import Path +from typing import Generator, Tuple +from zipfile import ZipFile + +from iopath.common.file_io import PathManager + + +@contextlib.contextmanager +def get_skateboard_data( + avoid_manifold: bool = False, silence_logs: bool = False +) -> Generator[Tuple[str, PathManager], None, None]: + """ + Context manager for accessing Co3D dataset by tests, at least for + the first 5 skateboards. Internally, we want this to exercise the + normal way to access the data directly manifold, but on an RE + worker this is impossible so we use a workaround. + + Args: + avoid_manifold: Use the method used by RE workers even locally. + silence_logs: Whether to reduce log output from iopath library. + + Yields: + dataset_root: (str) path to dataset root. + path_manager: path_manager to access it with. + """ + path_manager = PathManager() + if silence_logs: + logging.getLogger("iopath.fb.manifold").setLevel(logging.CRITICAL) + logging.getLogger("iopath.common.file_io").setLevel(logging.CRITICAL) + + if not os.environ.get("FB_TEST", False): + if os.getenv("FAIR_ENV_CLUSTER", "") == "": + raise unittest.SkipTest("Unknown environment. Data not available.") + yield "/checkpoint/dnovotny/datasets/co3d/download_aws_22_02_18", path_manager + + elif avoid_manifold or os.environ.get("INSIDE_RE_WORKER", False): + from libfb.py.parutil import get_file_path + + par_path = "skateboard_first_5" + source = get_file_path(par_path) + assert Path(source).is_file() + with tempfile.TemporaryDirectory() as dest: + with ZipFile(source) as f: + f.extractall(dest) + yield os.path.join(dest, "extracted"), path_manager + else: + from iopath.fb.manifold import ManifoldPathHandler + + path_manager.register_handler(ManifoldPathHandler()) + + yield "manifold://co3d/tree/extracted", path_manager + + +def provide_lpips_vgg(): + """ + Ensure the weights files are available for lpips.LPIPS(net="vgg") + to be called. Specifically, torchvision's vgg16 + """ + # In OSS, torchvision looks for vgg16 weights in + # https://download.pytorch.org/models/vgg16-397923af.pth + # Inside fbcode, this is replaced by asking iopath for + # manifold://torchvision/tree/models/vgg16-397923af.pth + # (the code for this replacement is in + # fbcode/pytorch/vision/fb/_internally_replaced_utils.py ) + # + # iopath does this by looking for the file at the cache location + # and if it is not there getting it from manifold. + # (the code for this is in + # fbcode/fair_infra/data/iopath/iopath/fb/manifold.py ) + # + # On the remote execution worker, manifold is inaccessible. + # We solve this by making the cached file available before iopath + # looks. + # + # By default the cache location is + # ~/.torch/iopath_cache/manifold_cache/tree/models/vgg16-397923af.pth + # But we can't write to the home directory on the RE worker. + # We define FVCORE_CACHE to change the cache location to + # iopath_cache/manifold_cache/tree/models/vgg16-397923af.pth + # (Without it, manifold caches in unstable temporary locations on RE.) + # + # The file we want has been copied from + # tree/models/vgg16-397923af.pth in the torchvision bucket + # to + # tree/testing/vgg16-397923af.pth in the co3d bucket + # and the TARGETS file copies it somewhere in the PAR which we + # recover with get_file_path. + # (It can't copy straight to a nested location, see + # https://fb.workplace.com/groups/askbuck/posts/2644615728920359/) + # Here we symlink it to the new cache location. + if os.environ.get("INSIDE_RE_WORKER") is not None: + from libfb.py.parutil import get_file_path + + os.environ["FVCORE_CACHE"] = "iopath_cache" + + par_path = "vgg_weights_for_lpips" + source = Path(get_file_path(par_path)) + assert source.is_file() + + dest = Path("iopath_cache/manifold_cache/tree/models") + if not dest.exists(): + dest.mkdir(parents=True) + (dest / "vgg16-397923af.pth").symlink_to(source) diff --git a/tests/implicitron/data/overrides.yaml b/tests/implicitron/data/overrides.yaml new file mode 100644 index 00000000..1f748d3b --- /dev/null +++ b/tests/implicitron/data/overrides.yaml @@ -0,0 +1,122 @@ +mask_images: true +mask_depths: true +render_image_width: 400 +render_image_height: 400 +mask_threshold: 0.5 +output_rasterized_mc: false +bg_color: +- 0.0 +- 0.0 +- 0.0 +view_pool: false +num_passes: 1 +chunk_size_grid: 4096 +render_features_dimensions: 3 +tqdm_trigger_threshold: 16 +n_train_target_views: 1 +sampling_mode_training: mask_sample +sampling_mode_evaluation: full_grid +renderer_class_type: LSTMRenderer +feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator +implicit_function_class_type: IdrFeatureField +loss_weights: + loss_rgb_mse: 1.0 + loss_prev_stage_rgb_mse: 1.0 + loss_mask_bce: 0.0 + loss_prev_stage_mask_bce: 0.0 +log_vars: +- loss_rgb_psnr_fg +- loss_rgb_psnr +- loss_rgb_mse +- loss_rgb_huber +- loss_depth_abs +- loss_depth_abs_fg +- loss_mask_neg_iou +- loss_mask_bce +- loss_mask_beta_prior +- loss_eikonal +- loss_density_tv +- loss_depth_neg_penalty +- loss_autodecoder_norm +- loss_prev_stage_rgb_mse +- loss_prev_stage_rgb_psnr_fg +- loss_prev_stage_rgb_psnr +- loss_prev_stage_mask_bce +- objective +- epoch +- sec/it +sequence_autodecoder_args: + encoding_dim: 0 + n_instances: 0 + init_scale: 1.0 + ignore_input: false +raysampler_args: + image_width: 400 + image_height: 400 + scene_center: + - 0.0 + - 0.0 + - 0.0 + scene_extent: 0.0 + sampling_mode_training: mask_sample + sampling_mode_evaluation: full_grid + n_pts_per_ray_training: 64 + n_pts_per_ray_evaluation: 64 + n_rays_per_image_sampled_from_mask: 1024 + min_depth: 0.1 + max_depth: 8.0 + stratified_point_sampling_training: true + stratified_point_sampling_evaluation: false +renderer_LSTMRenderer_args: + num_raymarch_steps: 10 + init_depth: 17.0 + init_depth_noise_std: 0.0005 + hidden_size: 16 + n_feature_channels: 256 + verbose: false +image_feature_extractor_args: + name: resnet34 + pretrained: true + stages: + - 1 + - 2 + - 3 + - 4 + normalize_image: true + image_rescale: 0.16 + first_max_pool: true + proj_dim: 32 + l2_norm: true + add_masks: true + add_images: true + global_average_pool: false + feature_rescale: 1.0 +view_sampler_args: + masked_sampling: false + sampling_mode: bilinear +feature_aggregator_AngleWeightedIdentityFeatureAggregator_args: + exclude_target_view: true + exclude_target_view_mask_features: true + concatenate_output: true + weight_by_ray_angle_gamma: 1.0 + min_ray_angle_weight: 0.1 +implicit_function_IdrFeatureField_args: + feature_vector_size: 3 + d_in: 3 + d_out: 1 + dims: + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + - 512 + geometric_init: true + bias: 1.0 + skip_in: [] + weight_norm: true + n_harmonic_functions_xyz: 0 + pooled_feature_dim: 0 + encoding_dim: 0 diff --git a/tests/implicitron/test_batch_sampler.py b/tests/implicitron/test_batch_sampler.py new file mode 100644 index 00000000..a2ae074a --- /dev/null +++ b/tests/implicitron/test_batch_sampler.py @@ -0,0 +1,215 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import unittest +from collections import defaultdict +from dataclasses import dataclass + +from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler + + +@dataclass +class MockFrameAnnotation: + frame_number: int + frame_timestamp: float = 0.0 + + +class MockDataset: + def __init__(self, num_seq, max_frame_gap=1): + """ + Makes a gap of max_frame_gap frame numbers in the middle of each sequence + """ + self.seq_annots = {f"seq_{i}": None for i in range(num_seq)} + self.seq_to_idx = { + f"seq_{i}": list(range(i * 10, i * 10 + 10)) for i in range(num_seq) + } + + # frame numbers within sequence: [0, ..., 4, n, ..., n+4] + # where n - 4 == max_frame_gap + frame_nos = list(range(5)) + list(range(4 + max_frame_gap, 9 + max_frame_gap)) + self.frame_annots = [ + {"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq + ] + + def get_frame_numbers_and_timestamps(self, idxs): + out = [] + for idx in idxs: + frame_annotation = self.frame_annots[idx]["frame_annotation"] + out.append( + (frame_annotation.frame_number, frame_annotation.frame_timestamp) + ) + return out + + +class TestSceneBatchSampler(unittest.TestCase): + def setUp(self): + self.dataset_overfit = MockDataset(1) + + def test_overfit(self): + num_batches = 3 + batch_size = 10 + sampler = SceneBatchSampler( + self.dataset_overfit, + batch_size=batch_size, + num_batches=num_batches, + images_per_seq_options=[10], # will try to sample batch_size anyway + ) + + self.assertEqual(len(sampler), num_batches) + + it = iter(sampler) + for _ in range(num_batches): + batch = next(it) + self.assertIsNotNone(batch) + self.assertEqual(len(batch), batch_size) # true for our examples + self.assertTrue(all(idx // 10 == 0 for idx in batch)) + + with self.assertRaises(StopIteration): + batch = next(it) + + def test_multiseq(self): + for ips_options in [[10], [2], [3], [2, 3, 4]]: + for sample_consecutive_frames in [True, False]: + for consecutive_frames_max_gap in [0, 1, 3]: + self._test_multiseq_flavour( + ips_options, + sample_consecutive_frames, + consecutive_frames_max_gap, + ) + + def test_multiseq_gaps(self): + num_batches = 16 + batch_size = 10 + dataset_multiseq = MockDataset(5, max_frame_gap=3) + for ips_options in [[10], [2], [3], [2, 3, 4]]: + debug_info = f" Images per sequence: {ips_options}." + + sampler = SceneBatchSampler( + dataset_multiseq, + batch_size=batch_size, + num_batches=num_batches, + images_per_seq_options=ips_options, + sample_consecutive_frames=True, + consecutive_frames_max_gap=1, + ) + + self.assertEqual(len(sampler), num_batches, msg=debug_info) + + it = iter(sampler) + for _ in range(num_batches): + batch = next(it) + self.assertIsNotNone(batch, "batch is None in" + debug_info) + if max(ips_options) > 5: + # true for our examples + self.assertEqual(len(batch), 5, msg=debug_info) + else: + # true for our examples + self.assertEqual(len(batch), batch_size, msg=debug_info) + + self._check_frames_are_consecutive( + batch, dataset_multiseq.frame_annots, debug_info + ) + + def _test_multiseq_flavour( + self, + ips_options, + sample_consecutive_frames, + consecutive_frames_max_gap, + num_batches=16, + batch_size=10, + ): + debug_info = ( + f" Images per sequence: {ips_options}, " + f"sample_consecutive_frames: {sample_consecutive_frames}, " + f"consecutive_frames_max_gap: {consecutive_frames_max_gap}, " + ) + # in this test, either consecutive_frames_max_gap == max_frame_gap, + # or consecutive_frames_max_gap == 0, so segments consist of full sequences + frame_gap = consecutive_frames_max_gap if consecutive_frames_max_gap > 0 else 3 + dataset_multiseq = MockDataset(5, max_frame_gap=frame_gap) + sampler = SceneBatchSampler( + dataset_multiseq, + batch_size=batch_size, + num_batches=num_batches, + images_per_seq_options=ips_options, + sample_consecutive_frames=sample_consecutive_frames, + consecutive_frames_max_gap=consecutive_frames_max_gap, + ) + + self.assertEqual(len(sampler), num_batches, msg=debug_info) + + it = iter(sampler) + typical_counts = set() + for _ in range(num_batches): + batch = next(it) + self.assertIsNotNone(batch, "batch is None in" + debug_info) + # true for our examples + self.assertEqual(len(batch), batch_size, msg=debug_info) + # find distribution over sequences + counts = _count_by_quotient(batch, 10) + freqs = _count_by_quotient(counts.values(), 1) + self.assertLessEqual( + len(freqs), + 2, + msg="We should have maximum of 2 different " + "frequences of sequences in the batch." + debug_info, + ) + if len(freqs) == 2: + most_seq_count = max(*freqs.keys()) + last_seq = min(*freqs.keys()) + self.assertEqual( + freqs[last_seq], + 1, + msg="Only one odd sequence allowed." + debug_info, + ) + else: + self.assertEqual(len(freqs), 1) + most_seq_count = next(iter(freqs)) + + self.assertIn(most_seq_count, ips_options) + typical_counts.add(most_seq_count) + + if sample_consecutive_frames: + self._check_frames_are_consecutive( + batch, + dataset_multiseq.frame_annots, + debug_info, + max_gap=consecutive_frames_max_gap, + ) + + self.assertTrue( + all(i in typical_counts for i in ips_options), + "Some of the frequency options did not occur among " + f"the {num_batches} batches (could be just bad luck)." + debug_info, + ) + + with self.assertRaises(StopIteration): + batch = next(it) + + def _check_frames_are_consecutive(self, batch, annots, debug_info, max_gap=1): + # make sure that sampled frames are consecutive + for i in range(len(batch) - 1): + curr_idx, next_idx = batch[i : i + 2] + if curr_idx // 10 == next_idx // 10: # same sequence + if max_gap > 0: + curr_idx, next_idx = [ + annots[idx]["frame_annotation"].frame_number + for idx in (curr_idx, next_idx) + ] + gap = max_gap + else: + gap = 1 # we'll check that raw dataset indices are consecutive + + self.assertLessEqual(next_idx - curr_idx, gap, msg=debug_info) + + +def _count_by_quotient(indices, divisor): + counter = defaultdict(int) + for i in indices: + counter[i // divisor] += 1 + + return counter diff --git a/tests/implicitron/test_circle_fitting.py b/tests/implicitron/test_circle_fitting.py new file mode 100644 index 00000000..5177f2f3 --- /dev/null +++ b/tests/implicitron/test_circle_fitting.py @@ -0,0 +1,177 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import unittest +from math import pi + +import torch +from pytorch3d.implicitron.tools.circle_fitting import ( + _signed_area, + fit_circle_in_2d, + fit_circle_in_3d, +) +from pytorch3d.transforms import random_rotation + + +if os.environ.get("FB_TEST", False): + from common_testing import TestCaseMixin +else: + from tests.common_testing import TestCaseMixin + + +class TestCircleFitting(TestCaseMixin, unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + + def _assertParallel(self, a, b, **kwargs): + """ + Given a and b of shape (..., 3) each containing 3D vectors, + assert that correspnding vectors are parallel. Changed sign is ok. + """ + self.assertClose(torch.cross(a, b, dim=-1), torch.zeros_like(a), **kwargs) + + def test_simple_3d(self): + device = torch.device("cuda:0") + for _ in range(7): + radius = 10 * torch.rand(1, device=device)[0] + center = 10 * torch.rand(3, device=device) + rot = random_rotation(device=device) + offset = torch.rand(3, device=device) + up = torch.rand(3, device=device) + self._simple_3d_test(radius, center, rot, offset, up) + + def _simple_3d_test(self, radius, center, rot, offset, up): + # angles are increasing so the points move in a well defined direction. + angles = torch.cumsum(torch.rand(17, device=rot.device), dim=0) + many = torch.stack( + [torch.cos(angles), torch.sin(angles), torch.zeros_like(angles)], dim=1 + ) + source_points = (many * radius) @ rot + center[None] + + # case with no generation + result = fit_circle_in_3d(source_points) + self.assertClose(result.radius, radius) + self.assertClose(result.center, center) + self._assertParallel(result.normal, rot[2], atol=1e-5) + self.assertEqual(result.generated_points.shape, (0, 3)) + + # Generate 5 points around the circle + n_new_points = 5 + result2 = fit_circle_in_3d(source_points, n_points=n_new_points) + self.assertClose(result2.radius, radius) + self.assertClose(result2.center, center) + self.assertClose(result2.normal, result.normal) + self.assertEqual(result2.generated_points.shape, (5, 3)) + + observed_points = result2.generated_points + self.assertClose(observed_points[0], observed_points[4], atol=1e-4) + self.assertClose(observed_points[0], source_points[0], atol=1e-5) + observed_normal = torch.cross( + observed_points[0] - observed_points[2], + observed_points[1] - observed_points[3], + dim=-1, + ) + self._assertParallel(observed_normal, result.normal, atol=1e-4) + diameters = observed_points[:2] - observed_points[2:4] + self.assertClose( + torch.norm(diameters, dim=1), diameters.new_full((2,), 2 * radius) + ) + + # Regenerate the input points + result3 = fit_circle_in_3d(source_points, angles=angles - angles[0]) + self.assertClose(result3.radius, radius) + self.assertClose(result3.center, center) + self.assertClose(result3.normal, result.normal) + self.assertClose(result3.generated_points, source_points, atol=1e-5) + + # Test with offset + result4 = fit_circle_in_3d( + source_points, angles=angles - angles[0], offset=offset, up=up + ) + self.assertClose(result4.radius, radius) + self.assertClose(result4.center, center) + self.assertClose(result4.normal, result.normal) + observed_offsets = result4.generated_points - source_points + + # observed_offset is constant + self.assertClose( + observed_offsets.min(0).values, observed_offsets.max(0).values, atol=1e-5 + ) + # observed_offset has the right length + self.assertClose(observed_offsets[0].norm(), offset.norm()) + + self.assertClose(result.normal.norm(), torch.ones(())) + # component of observed_offset along normal + component = torch.dot(observed_offsets[0], result.normal) + self.assertClose(component.abs(), offset[2].abs(), atol=1e-5) + agree_normal = torch.dot(result.normal, up) > 0 + agree_signs = component * offset[2] > 0 + self.assertEqual(agree_normal, agree_signs) + + def test_simple_2d(self): + radius = 7.0 + center = torch.tensor([9, 2.5]) + angles = torch.cumsum(torch.rand(17), dim=0) + many = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1) + source_points = (many * radius) + center[None] + + result = fit_circle_in_2d(source_points) + self.assertClose(result.radius, torch.tensor(radius)) + self.assertClose(result.center, center) + self.assertEqual(result.generated_points.shape, (0, 2)) + + # Generate 5 points around the circle + n_new_points = 5 + result2 = fit_circle_in_2d(source_points, n_points=n_new_points) + self.assertClose(result2.radius, torch.tensor(radius)) + self.assertClose(result2.center, center) + self.assertEqual(result2.generated_points.shape, (5, 2)) + + observed_points = result2.generated_points + self.assertClose(observed_points[0], observed_points[4]) + self.assertClose(observed_points[0], source_points[0], atol=1e-5) + diameters = observed_points[:2] - observed_points[2:4] + self.assertClose(torch.norm(diameters, dim=1), torch.full((2,), 2 * radius)) + + # Regenerate the input points + result3 = fit_circle_in_2d(source_points, angles=angles - angles[0]) + self.assertClose(result3.radius, torch.tensor(radius)) + self.assertClose(result3.center, center) + self.assertClose(result3.generated_points, source_points, atol=1e-5) + + def test_minimum_inputs(self): + fit_circle_in_3d(torch.rand(3, 3), n_points=10) + + with self.assertRaisesRegex( + ValueError, "2 points are not enough to determine a circle" + ): + fit_circle_in_3d(torch.rand(2, 3)) + + def test_signed_area(self): + n_points = 1001 + angles = torch.linspace(0, 2 * pi, n_points) + radius = 0.85 + center = torch.rand(2) + circle = center + radius * torch.stack( + [torch.cos(angles), torch.sin(angles)], dim=1 + ) + circle_area = torch.tensor(pi * radius * radius) + self.assertClose(_signed_area(circle), circle_area) + # clockwise is negative + self.assertClose(_signed_area(circle.flip(0)), -circle_area) + + # Semicircles + self.assertClose(_signed_area(circle[: (n_points + 1) // 2]), circle_area / 2) + self.assertClose(_signed_area(circle[n_points // 2 :]), circle_area / 2) + + # A straight line bounds no area + self.assertClose(_signed_area(torch.rand(2, 2)), torch.tensor(0.0)) + + # Letter 'L' written anticlockwise. + L_shape = [[0, 1], [0, 0], [1, 0]] + # Triangle area is 0.5 * b * h. + self.assertClose(_signed_area(torch.tensor(L_shape)), torch.tensor(0.5)) diff --git a/tests/implicitron/test_config.py b/tests/implicitron/test_config.py new file mode 100644 index 00000000..a1dd7fd0 --- /dev/null +++ b/tests/implicitron/test_config.py @@ -0,0 +1,610 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import textwrap +import unittest +from dataclasses import dataclass, field, is_dataclass +from enum import Enum +from typing import List, Optional, Tuple + +from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError +from pytorch3d.implicitron.tools.config import ( + Configurable, + ReplaceableBase, + _is_actually_dataclass, + _Registry, + expand_args_fields, + get_default_args, + get_default_args_field, + registry, + remove_unused_components, + run_auto_creation, +) + + +@dataclass +class Animal(ReplaceableBase): + pass + + +class Fruit(ReplaceableBase): + pass + + +@registry.register +class Banana(Fruit): + pips: int + spots: int + bananame: str + + +@registry.register +class Pear(Fruit): + n_pips: int = 13 + + +class Pineapple(Fruit): + pass + + +@registry.register +class Orange(Fruit): + pass + + +@registry.register +class Kiwi(Fruit): + pass + + +@registry.register +class LargePear(Pear): + pass + + +class MainTest(Configurable): + the_fruit: Fruit + n_ids: int + n_reps: int = 8 + the_second_fruit: Fruit + + def create_the_second_fruit(self): + expand_args_fields(Pineapple) + self.the_second_fruit = Pineapple() + + def __post_init__(self): + run_auto_creation(self) + + +class TestConfig(unittest.TestCase): + def test_is_actually_dataclass(self): + @dataclass + class A: + pass + + self.assertTrue(_is_actually_dataclass(A)) + self.assertTrue(is_dataclass(A)) + + class B(A): + a: int + + self.assertFalse(_is_actually_dataclass(B)) + self.assertTrue(is_dataclass(B)) + + def test_simple_replacement(self): + struct = get_default_args(MainTest) + struct.n_ids = 9780 + struct.the_fruit_Pear_args.n_pips = 3 + struct.the_fruit_class_type = "Pear" + struct.the_second_fruit_class_type = "Pear" + + main = MainTest(**struct) + self.assertIsInstance(main.the_fruit, Pear) + self.assertEqual(main.n_reps, 8) + self.assertEqual(main.n_ids, 9780) + self.assertEqual(main.the_fruit.n_pips, 3) + self.assertIsInstance(main.the_second_fruit, Pineapple) + + struct2 = get_default_args(MainTest) + self.assertEqual(struct2.the_fruit_Pear_args.n_pips, 13) + + self.assertEqual( + MainTest._creation_functions, + ("create_the_fruit", "create_the_second_fruit"), + ) + + def test_detect_bases(self): + # testing the _base_class_from_class function + self.assertIsNone(_Registry._base_class_from_class(ReplaceableBase)) + self.assertIsNone(_Registry._base_class_from_class(MainTest)) + self.assertIs(_Registry._base_class_from_class(Fruit), Fruit) + self.assertIs(_Registry._base_class_from_class(Pear), Fruit) + + class PricklyPear(Pear): + pass + + self.assertIs(_Registry._base_class_from_class(PricklyPear), Fruit) + + def test_registry_entries(self): + self.assertIs(registry.get(Fruit, "Banana"), Banana) + with self.assertRaisesRegex(ValueError, "Banana has not been registered."): + registry.get(Animal, "Banana") + with self.assertRaisesRegex(ValueError, "PricklyPear has not been registered."): + registry.get(Fruit, "PricklyPear") + + self.assertIs(registry.get(Pear, "Pear"), Pear) + self.assertIs(registry.get(Pear, "LargePear"), LargePear) + with self.assertRaisesRegex(ValueError, "Banana resolves to"): + registry.get(Pear, "Banana") + + all_fruit = set(registry.get_all(Fruit)) + self.assertIn(Banana, all_fruit) + self.assertIn(Pear, all_fruit) + self.assertIn(LargePear, all_fruit) + self.assertEqual(set(registry.get_all(Pear)), {LargePear}) + + @registry.register + class Apple(Fruit): + pass + + @registry.register + class CrabApple(Apple): + pass + + self.assertEqual(set(registry.get_all(Apple)), {CrabApple}) + + self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple) + + with self.assertRaisesRegex(ValueError, "Cannot tell what it is."): + + @registry.register + class NotAFruit: + pass + + def test_recursion(self): + class Shape(ReplaceableBase): + pass + + @registry.register + class Triangle(Shape): + a: float = 5.0 + + @registry.register + class Square(Shape): + a: float = 3.0 + + @registry.register + class LargeShape(Shape): + inner: Shape + + def __post_init__(self): + run_auto_creation(self) + + class ShapeContainer(Configurable): + shape: Shape + + container = ShapeContainer(**get_default_args(ShapeContainer)) + # This is because ShapeContainer is missing __post_init__ + with self.assertRaises(AttributeError): + container.shape + + class ShapeContainer2(Configurable): + x: Shape + x_class_type: str = "LargeShape" + + def __post_init__(self): + self.x_LargeShape_args.inner_class_type = "Triangle" + run_auto_creation(self) + + container2_args = get_default_args(ShapeContainer2) + container2_args.x_LargeShape_args.inner_Triangle_args.a += 10 + self.assertIn("inner_Square_args", container2_args.x_LargeShape_args) + # We do not perform expansion that would result in an infinite recursion, + # so this member is not present. + self.assertNotIn("inner_LargeShape_args", container2_args.x_LargeShape_args) + container2_args.x_LargeShape_args.inner_Square_args.a += 100 + container2 = ShapeContainer2(**container2_args) + self.assertIsInstance(container2.x, LargeShape) + self.assertIsInstance(container2.x.inner, Triangle) + self.assertEqual(container2.x.inner.a, 15.0) + + def test_simpleclass_member(self): + # Members which are not dataclasses are + # tolerated. But it would be nice to be able to + # configure them. + class Foo: + def __init__(self, a=1, b=2): + self.a, self.b = a, b + + @dataclass() + class Bar: + aa: int = 9 + bb: int = 9 + + class Container(Configurable): + bar: Bar = Bar() + # TODO make this work? + # foo: Foo = Foo() + fruit: Fruit + fruit_class_type: str = "Orange" + + def __post_init__(self): + run_auto_creation(self) + + self.assertEqual(get_default_args(Foo), {"a": 1, "b": 2}) + container_args = get_default_args(Container) + container = Container(**container_args) + self.assertIsInstance(container.fruit, Orange) + # self.assertIsInstance(container.bar, Bar) + + container_defaulted = Container() + container_defaulted.fruit_Pear_args.n_pips += 4 + + container_args2 = get_default_args(Container) + container = Container(**container_args2) + self.assertEqual(container.fruit_Pear_args.n_pips, 13) + + def test_inheritance(self): + class FruitBowl(ReplaceableBase): + main_fruit: Fruit + main_fruit_class_type: str = "Orange" + + def __post_init__(self): + raise ValueError("This doesn't get called") + + class LargeFruitBowl(FruitBowl): + extra_fruit: Fruit + extra_fruit_class_type: str = "Kiwi" + + def __post_init__(self): + run_auto_creation(self) + + large_args = get_default_args(LargeFruitBowl) + self.assertNotIn("extra_fruit", large_args) + self.assertNotIn("main_fruit", large_args) + large = LargeFruitBowl(**large_args) + self.assertIsInstance(large.main_fruit, Orange) + self.assertIsInstance(large.extra_fruit, Kiwi) + + def test_inheritance2(self): + # This is a case where a class could contain an instance + # of a subclass, which is ignored. + class Parent(ReplaceableBase): + pass + + class Main(Configurable): + parent: Parent + # Note - no __post__init__ + + @registry.register + class Derived(Parent, Main): + pass + + args = get_default_args(Main) + # Derived has been ignored in processing Main. + self.assertCountEqual(args.keys(), ["parent_class_type"]) + + main = Main(**args) + + with self.assertRaisesRegex(ValueError, "UNDEFAULTED has not been registered."): + run_auto_creation(main) + + main.parent_class_type = "Derived" + # Illustrates that a dict works fine instead of a DictConfig. + main.parent_Derived_args = {} + with self.assertRaises(AttributeError): + main.parent + run_auto_creation(main) + self.assertIsInstance(main.parent, Derived) + + def test_redefine(self): + class FruitBowl(ReplaceableBase): + main_fruit: Fruit + main_fruit_class_type: str = "Grape" + + def __post_init__(self): + run_auto_creation(self) + + @registry.register + @dataclass + class Grape(Fruit): + large: bool = False + + def get_color(self): + return "red" + + def __post_init__(self): + raise ValueError("This doesn't get called") + + bowl_args = get_default_args(FruitBowl) + + @registry.register + @dataclass + class Grape(Fruit): # noqa: F811 + large: bool = True + + def get_color(self): + return "green" + + with self.assertWarnsRegex( + UserWarning, "New implementation of Grape is being chosen." + ): + bowl = FruitBowl(**bowl_args) + self.assertIsInstance(bowl.main_fruit, Grape) + + # Redefining the same class won't help with defaults because encoded in args + self.assertEqual(bowl.main_fruit.large, False) + + # But the override worked. + self.assertEqual(bowl.main_fruit.get_color(), "green") + + # 2. Try redefining without the dataclass modifier + # This relies on the fact that default creation processes the class. + # (otherwise incomprehensible messages) + @registry.register + class Grape(Fruit): # noqa: F811 + large: bool = True + + with self.assertWarnsRegex( + UserWarning, "New implementation of Grape is being chosen." + ): + bowl = FruitBowl(**bowl_args) + + # 3. Adding a new class doesn't get picked up, because the first + # get_default_args call has frozen FruitBowl. This is intrinsic to + # the way dataclass and expand_args_fields work in-place but + # expand_args_fields is not pure - it depends on the registry. + @registry.register + class Fig(Fruit): + pass + + bowl_args2 = get_default_args(FruitBowl) + self.assertIn("main_fruit_Grape_args", bowl_args2) + self.assertNotIn("main_fruit_Fig_args", bowl_args2) + + # TODO Is it possible to make this work? + # bowl_args2["main_fruit_Fig_args"] = get_default_args(Fig) + # bowl_args2.main_fruit_class_type = "Fig" + # bowl2 = FruitBowl(**bowl_args2) <= unexpected argument + + # Note that it is possible to use Fig if you can set + # bowl2.main_fruit_Fig_args explicitly (not in bowl_args2) + # before run_auto_creation happens. See test_inheritance2 + # for an example. + + def test_no_replacement(self): + # Test of Configurables without ReplaceableBase + class A(Configurable): + n: int = 9 + + class B(Configurable): + a: A + + def __post_init__(self): + run_auto_creation(self) + + class C(Configurable): + b: B + + def __post_init__(self): + run_auto_creation(self) + + c_args = get_default_args(C) + c = C(**c_args) + self.assertIsInstance(c.b.a, A) + self.assertEqual(c.b.a.n, 9) + + def test_doc(self): + # The case in the docstring. + class A(ReplaceableBase): + k: int = 1 + + @registry.register + class A1(A): + m: int = 3 + + @registry.register + class A2(A): + n: str = "2" + + class B(Configurable): + a: A + a_class_type: str = "A2" + + def __post_init__(self): + run_auto_creation(self) + + b_args = get_default_args(B) + self.assertNotIn("a", b_args) + b = B(**b_args) + self.assertEqual(b.a.n, "2") + + def test_raw_types(self): + @dataclass + class MyDataclass: + int_field: int = 0 + none_field: Optional[int] = None + float_field: float = 9.3 + bool_field: bool = True + tuple_field: tuple = (3, True, "j") + + class SimpleClass: + def __init__(self, tuple_member_=(3, 4)): + self.tuple_member = tuple_member_ + + def get_tuple(self): + return self.tuple_member + + def f(*, a: int = 3, b: str = "kj"): + self.assertEqual(a, 3) + self.assertEqual(b, "kj") + + class C(Configurable): + simple: DictConfig = get_default_args_field(SimpleClass) + # simple2: SimpleClass2 = SimpleClass2() + mydata: DictConfig = get_default_args_field(MyDataclass) + a_tuple: Tuple[float] = (4.0, 3.0) + f_args: DictConfig = get_default_args_field(f) + + args = get_default_args(C) + c = C(**args) + self.assertCountEqual(args.keys(), ["simple", "mydata", "a_tuple", "f_args"]) + + mydata = MyDataclass(**c.mydata) + simple = SimpleClass(**c.simple) + + # OmegaConf converts tuples to ListConfigs (which act like lists). + self.assertEqual(simple.get_tuple(), [3, 4]) + self.assertTrue(isinstance(simple.get_tuple(), ListConfig)) + self.assertEqual(c.a_tuple, [4.0, 3.0]) + self.assertTrue(isinstance(c.a_tuple, ListConfig)) + self.assertEqual(mydata.tuple_field, (3, True, "j")) + self.assertTrue(isinstance(mydata.tuple_field, ListConfig)) + f(**c.f_args) + + def test_irrelevant_bases(self): + class NotADataclass: + # Like torch.nn.Module, this class contains annotations + # but is not designed to be dataclass'd. + # This test ensures that such classes, when inherited fron, + # are not accidentally expand_args_fields. + a: int = 9 + b: int + + class LeftConfigured(Configurable, NotADataclass): + left: int = 1 + + class RightConfigured(NotADataclass, Configurable): + right: int = 2 + + class Outer(Configurable): + left: LeftConfigured + right: RightConfigured + + def __post_init__(self): + run_auto_creation(self) + + outer = Outer(**get_default_args(Outer)) + self.assertEqual(outer.left.left, 1) + self.assertEqual(outer.right.right, 2) + with self.assertRaisesRegex(TypeError, "non-default argument"): + dataclass(NotADataclass) + + def test_unprocessed(self): + # behavior of Configurable classes which need processing in __new__, + class Unprocessed(Configurable): + a: int = 9 + + class UnprocessedReplaceable(ReplaceableBase): + a: int = 1 + + with self.assertWarnsRegex(UserWarning, "must be processed"): + Unprocessed() + with self.assertWarnsRegex(UserWarning, "must be processed"): + UnprocessedReplaceable() + + def test_enum(self): + # Test that enum values are kept, i.e. that OmegaConf's runtime checks + # are in use. + + class A(Enum): + B1 = "b1" + B2 = "b2" + + class C(Configurable): + a: A = A.B1 + + base = get_default_args(C) + replaced = OmegaConf.merge(base, {"a": "B2"}) + self.assertEqual(replaced.a, A.B2) + with self.assertRaises(ValidationError): + # You can't use a value which is not one of the + # choices, even if it is the str representation + # of one of the choices. + OmegaConf.merge(base, {"a": "b2"}) + + remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base))) + self.assertEqual(remerged.a, A.B1) + + def test_remove_unused_components(self): + struct = get_default_args(MainTest) + struct.n_ids = 32 + struct.the_fruit_class_type = "Pear" + struct.the_second_fruit_class_type = "Banana" + remove_unused_components(struct) + expected_keys = [ + "n_ids", + "n_reps", + "the_fruit_Pear_args", + "the_fruit_class_type", + "the_second_fruit_Banana_args", + "the_second_fruit_class_type", + ] + expected_yaml = textwrap.dedent( + """\ + n_ids: 32 + n_reps: 8 + the_fruit_class_type: Pear + the_fruit_Pear_args: + n_pips: 13 + the_second_fruit_class_type: Banana + the_second_fruit_Banana_args: + pips: ??? + spots: ??? + bananame: ??? + """ + ) + self.assertEqual(sorted(struct.keys()), expected_keys) + + # Check that struct is what we expect + expected = OmegaConf.create(expected_yaml) + self.assertEqual(struct, expected) + + # Check that we get what we expect when writing to yaml. + self.assertEqual(OmegaConf.to_yaml(struct, sort_keys=False), expected_yaml) + + main = MainTest(**struct) + instance_data = OmegaConf.structured(main) + remove_unused_components(instance_data) + self.assertEqual(sorted(instance_data.keys()), expected_keys) + self.assertEqual(instance_data, expected) + + +@dataclass(eq=False) +class MockDataclass: + field_no_default: int + field_primitive_type: int = 42 + field_reference_type: List[int] = field(default_factory=lambda: []) + + +class MockClassWithInit: # noqa: B903 + def __init__( + self, + field_no_default: int, + field_primitive_type: int = 42, + field_reference_type: List[int] = [], # noqa: B006 + ): + self.field_no_default = field_no_default + self.field_primitive_type = field_primitive_type + self.field_reference_type = field_reference_type + + +class TestRawClasses(unittest.TestCase): + def test_get_default_args(self): + for cls in [MockDataclass, MockClassWithInit]: + dataclass_defaults = get_default_args(cls) + inst = cls(field_no_default=0) + dataclass_defaults.field_no_default = 0 + for name, val in dataclass_defaults.items(): + self.assertTrue(hasattr(inst, name)) + self.assertEqual(val, getattr(inst, name)) + + def test_get_default_args_readonly(self): + for cls in [MockDataclass, MockClassWithInit]: + dataclass_defaults = get_default_args(cls) + dataclass_defaults["field_reference_type"].append(13) + inst = cls(field_no_default=0) + self.assertEqual(inst.field_reference_type, []) diff --git a/tests/implicitron/test_config_use.py b/tests/implicitron/test_config_use.py new file mode 100644 index 00000000..6ab5ecb2 --- /dev/null +++ b/tests/implicitron/test_config_use.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import unittest + +from omegaconf import OmegaConf +from pytorch3d.implicitron.models.autodecoder import Autodecoder +from pytorch3d.implicitron.models.base import GenericModel +from pytorch3d.implicitron.models.implicit_function.idr_feature_field import ( + IdrFeatureField, +) +from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import ( + NeuralRadianceFieldImplicitFunction, +) +from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer +from pytorch3d.implicitron.models.renderer.multipass_ea import ( + MultiPassEmissionAbsorptionRenderer, +) +from pytorch3d.implicitron.models.view_pooling.feature_aggregation import ( + AngleWeightedIdentityFeatureAggregator, + AngleWeightedReductionFeatureAggregator, +) +from pytorch3d.implicitron.tools.config import ( + get_default_args, + remove_unused_components, +) + + +if os.environ.get("FB_TEST", False): + from common_testing import get_tests_dir +else: + from tests.common_testing import get_tests_dir + +DATA_DIR = get_tests_dir() / "implicitron/data" +DEBUG: bool = False + +# Tests the use of the config system in implicitron + + +class TestGenericModel(unittest.TestCase): + def setUp(self): + self.maxDiff = None + + def test_create_gm(self): + args = get_default_args(GenericModel) + gm = GenericModel(**args) + self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer) + self.assertIsInstance( + gm.feature_aggregator, AngleWeightedReductionFeatureAggregator + ) + self.assertIsInstance( + gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction + ) + self.assertIsInstance(gm.sequence_autodecoder, Autodecoder) + self.assertFalse(hasattr(gm, "implicit_function")) + self.assertFalse(hasattr(gm, "image_feature_extractor")) + + def test_create_gm_overrides(self): + args = get_default_args(GenericModel) + args.feature_aggregator_class_type = "AngleWeightedIdentityFeatureAggregator" + args.implicit_function_class_type = "IdrFeatureField" + args.renderer_class_type = "LSTMRenderer" + gm = GenericModel(**args) + self.assertIsInstance(gm.renderer, LSTMRenderer) + self.assertIsInstance( + gm.feature_aggregator, AngleWeightedIdentityFeatureAggregator + ) + self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField) + self.assertIsInstance(gm.sequence_autodecoder, Autodecoder) + self.assertFalse(hasattr(gm, "implicit_function")) + + instance_args = OmegaConf.structured(gm) + remove_unused_components(instance_args) + yaml = OmegaConf.to_yaml(instance_args, sort_keys=False) + if DEBUG: + (DATA_DIR / "overrides.yaml_").write_text(yaml) + self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text()) diff --git a/tests/implicitron/test_dataset_visualize.py b/tests/implicitron/test_dataset_visualize.py new file mode 100644 index 00000000..aa64f8c8 --- /dev/null +++ b/tests/implicitron/test_dataset_visualize.py @@ -0,0 +1,191 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import contextlib +import copy +import os +import unittest + +import torch +import torchvision +from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset +from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud +from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d +from pytorch3d.vis.plotly_vis import plot_scene +from visdom import Visdom + +if os.environ.get("FB_TEST", False): + from .common_resources import get_skateboard_data +else: + from common_resources import get_skateboard_data + + +class TestDatasetVisualize(unittest.TestCase): + def setUp(self): + if os.environ.get("INSIDE_RE_WORKER") is not None: + raise unittest.SkipTest("Visdom not available") + category = "skateboard" + stack = contextlib.ExitStack() + dataset_root, path_manager = stack.enter_context(get_skateboard_data()) + self.addCleanup(stack.close) + frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz") + sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz") + self.image_size = 256 + self.datasets = { + "simple": ImplicitronDataset( + frame_annotations_file=frame_file, + sequence_annotations_file=sequence_file, + dataset_root=dataset_root, + image_height=self.image_size, + image_width=self.image_size, + box_crop=True, + load_point_clouds=True, + path_manager=path_manager, + ), + "nonsquare": ImplicitronDataset( + frame_annotations_file=frame_file, + sequence_annotations_file=sequence_file, + dataset_root=dataset_root, + image_height=self.image_size, + image_width=self.image_size // 2, + box_crop=True, + load_point_clouds=True, + path_manager=path_manager, + ), + "nocrop": ImplicitronDataset( + frame_annotations_file=frame_file, + sequence_annotations_file=sequence_file, + dataset_root=dataset_root, + image_height=self.image_size, + image_width=self.image_size // 2, + box_crop=False, + load_point_clouds=True, + path_manager=path_manager, + ), + } + self.datasets.update( + { + k + "_newndc": _change_annotations_to_new_ndc(dataset) + for k, dataset in self.datasets.items() + } + ) + self.visdom = Visdom() + if not self.visdom.check_connection(): + print("Visdom server not running! Disabling visdom visualizations.") + self.visdom = None + + def _render_one_pointcloud(self, point_cloud, cameras, render_size): + (_image_render, _, _) = render_point_cloud_pytorch3d( + cameras, + point_cloud, + render_size=render_size, + point_radius=1e-2, + topk=10, + bg_color=0.0, + ) + return _image_render.clamp(0.0, 1.0) + + def test_one(self): + """Test dataset visualization.""" + for max_frames in (16, -1): + for load_dataset_point_cloud in (True, False): + for dataset_key in self.datasets: + self._gen_and_render_pointcloud( + max_frames, load_dataset_point_cloud, dataset_key + ) + + def _gen_and_render_pointcloud( + self, max_frames, load_dataset_point_cloud, dataset_key + ): + dataset = self.datasets[dataset_key] + # load the point cloud of the first sequence + sequence_show = list(dataset.seq_annots.keys())[0] + device = torch.device("cuda:0") + + point_cloud, sequence_frame_data = get_implicitron_sequence_pointcloud( + dataset, + sequence_name=sequence_show, + mask_points=True, + max_frames=max_frames, + num_workers=10, + load_dataset_point_cloud=load_dataset_point_cloud, + ) + + # render on gpu + point_cloud = point_cloud.to(device) + cameras = sequence_frame_data.camera.to(device) + + # render the point_cloud from the viewpoint of loaded cameras + images_render = torch.cat( + [ + self._render_one_pointcloud( + point_cloud, + cameras[frame_i], + ( + dataset.image_height, + dataset.image_width, + ), + ) + for frame_i in range(len(cameras)) + ] + ).cpu() + images_gt_and_render = torch.cat( + [sequence_frame_data.image_rgb, images_render], dim=3 + ) + + imfile = os.path.join( + os.path.split(os.path.abspath(__file__))[0], + "test_dataset_visualize" + + f"_max_frames={max_frames}" + + f"_load_pcl={load_dataset_point_cloud}.png", + ) + print(f"Exporting image {imfile}.") + torchvision.utils.save_image(images_gt_and_render, imfile, nrow=2) + + if self.visdom is not None: + test_name = f"{max_frames}_{load_dataset_point_cloud}_{dataset_key}" + self.visdom.images( + images_gt_and_render, + env="test_dataset_visualize", + win=f"pcl_renders_{test_name}", + opts={"title": f"pcl_renders_{test_name}"}, + ) + plotlyplot = plot_scene( + { + "scene_batch": { + "cameras": cameras, + "point_cloud": point_cloud, + } + }, + camera_scale=1.0, + pointcloud_max_points=10000, + pointcloud_marker_size=1.0, + ) + self.visdom.plotlyplot( + plotlyplot, + env="test_dataset_visualize", + win=f"pcl_{test_name}", + ) + + +def _change_annotations_to_new_ndc(dataset): + dataset = copy.deepcopy(dataset) + for frame in dataset.frame_annots: + vp = frame["frame_annotation"].viewpoint + vp.intrinsics_format = "ndc_isotropic" + # this assume the focal length to be equal on x and y (ok for a test) + max_flength = max(vp.focal_length) + vp.principal_point = ( + vp.principal_point[0] * max_flength / vp.focal_length[0], + vp.principal_point[1] * max_flength / vp.focal_length[1], + ) + vp.focal_length = ( + max_flength, + max_flength, + ) + + return dataset diff --git a/tests/implicitron/test_eval_cameras.py b/tests/implicitron/test_eval_cameras.py new file mode 100644 index 00000000..6ef3aa43 --- /dev/null +++ b/tests/implicitron/test_eval_cameras.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import unittest + +import torch +from pytorch3d.implicitron.tools.eval_video_trajectory import ( + generate_eval_video_cameras, +) +from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform +from pytorch3d.transforms import axis_angle_to_matrix + + +if os.environ.get("FB_TEST", False): + from common_testing import TestCaseMixin +else: + from tests.common_testing import TestCaseMixin + + +class TestEvalCameras(TestCaseMixin, unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + + def test_circular(self): + n_train_cameras = 10 + n_test_cameras = 100 + R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360) + amplitude = 0.01 + R_jiggled = torch.bmm( + R, axis_angle_to_matrix(torch.rand(n_train_cameras, 3) * amplitude) + ) + cameras_train = PerspectiveCameras(R=R_jiggled, T=T) + cameras_test = generate_eval_video_cameras( + cameras_train, trajectory_type="circular_lsq_fit", trajectory_scale=1.0 + ) + + positions_test = cameras_test.get_camera_center() + center = positions_test.mean(0) + self.assertClose(center, torch.zeros(3), atol=0.1) + self.assertClose( + (positions_test - center).norm(dim=[1]), + torch.ones(n_test_cameras), + atol=0.1, + ) diff --git a/tests/implicitron/test_evaluation.py b/tests/implicitron/test_evaluation.py new file mode 100644 index 00000000..9d50aff8 --- /dev/null +++ b/tests/implicitron/test_evaluation.py @@ -0,0 +1,290 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import contextlib +import copy +import dataclasses +import math +import os +import unittest + +import lpips +import torch +from pytorch3d.implicitron.dataset.implicitron_dataset import ( + FrameData, + ImplicitronDataset, +) +from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch +from pytorch3d.implicitron.models.model_dbir import ModelDBIR +from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth +from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_ + +if os.environ.get("FB_TEST", False): + from .common_resources import get_skateboard_data, provide_lpips_vgg +else: + from common_resources import get_skateboard_data, provide_lpips_vgg + + +class TestEvaluation(unittest.TestCase): + def setUp(self): + # initialize evaluation dataset/dataloader + torch.manual_seed(42) + + stack = contextlib.ExitStack() + dataset_root, path_manager = stack.enter_context(get_skateboard_data()) + self.addCleanup(stack.close) + + category = "skateboard" + frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz") + sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz") + self.image_size = 256 + self.dataset = ImplicitronDataset( + frame_annotations_file=frame_file, + sequence_annotations_file=sequence_file, + dataset_root=dataset_root, + image_height=self.image_size, + image_width=self.image_size, + box_crop=True, + path_manager=path_manager, + ) + self.bg_color = 0.0 + + # init the lpips model for eval + provide_lpips_vgg() + self.lpips_model = lpips.LPIPS(net="vgg") + + def test_eval_depth(self): + """ + Check that eval_depth correctly masks errors and that, for get_best_scale=True, + the error with scaled prediction equals the error without scaling the + predicted depth. Finally, test that the error values are as expected + for prediction and gt differing by a constant offset. + """ + gt = (torch.randn(10, 1, 300, 400, device="cuda") * 5.0).clamp(0.0) + mask = (torch.rand_like(gt) > 0.5).type_as(gt) + + for diff in 10 ** torch.linspace(-5, 0, 6): + for crop in (0, 5): + + pred = gt + (torch.rand_like(gt) - 0.5) * 2 * diff + + # scaled prediction test + mse_depth, abs_depth = eval_depth( + pred, + gt, + crop=crop, + mask=mask, + get_best_scale=True, + ) + mse_depth_scale, abs_depth_scale = eval_depth( + pred * 10.0, + gt, + crop=crop, + mask=mask, + get_best_scale=True, + ) + self.assertAlmostEqual( + float(mse_depth.sum()), float(mse_depth_scale.sum()), delta=1e-4 + ) + self.assertAlmostEqual( + float(abs_depth.sum()), float(abs_depth_scale.sum()), delta=1e-4 + ) + + # error masking test + pred_masked_err = gt + (torch.rand_like(gt) + diff) * (1 - mask) + mse_depth_masked, abs_depth_masked = eval_depth( + pred_masked_err, + gt, + crop=crop, + mask=mask, + get_best_scale=True, + ) + self.assertAlmostEqual( + float(mse_depth_masked.sum()), float(0.0), delta=1e-4 + ) + self.assertAlmostEqual( + float(abs_depth_masked.sum()), float(0.0), delta=1e-4 + ) + mse_depth_unmasked, abs_depth_unmasked = eval_depth( + pred_masked_err, + gt, + crop=crop, + mask=1 - mask, + get_best_scale=True, + ) + self.assertGreater( + float(mse_depth_unmasked.sum()), + float(diff ** 2), + ) + self.assertGreater( + float(abs_depth_unmasked.sum()), + float(diff), + ) + + # tests with constant error + pred_fix_diff = gt + diff * mask + for _mask_gt in (mask, None): + mse_depth_fix_diff, abs_depth_fix_diff = eval_depth( + pred_fix_diff, + gt, + crop=crop, + mask=_mask_gt, + get_best_scale=False, + ) + if _mask_gt is not None: + expected_err_abs = diff + expected_err_mse = diff ** 2 + else: + err_mask = (gt > 0.0).float() * mask + if crop > 0: + err_mask = err_mask[:, :, crop:-crop, crop:-crop] + gt_cropped = gt[:, :, crop:-crop, crop:-crop] + else: + gt_cropped = gt + gt_mass = (gt_cropped > 0.0).float().sum(dim=(1, 2, 3)) + expected_err_abs = ( + diff * err_mask.sum(dim=(1, 2, 3)) / (gt_mass) + ) + expected_err_mse = diff * expected_err_abs + self.assertTrue( + torch.allclose( + abs_depth_fix_diff, + expected_err_abs * torch.ones_like(abs_depth_fix_diff), + atol=1e-4, + ) + ) + self.assertTrue( + torch.allclose( + mse_depth_fix_diff, + expected_err_mse * torch.ones_like(mse_depth_fix_diff), + atol=1e-4, + ) + ) + + def test_psnr(self): + """ + Compare against opencv and check that the psnr is above + the minimum possible value. + """ + import cv2 + + im1 = torch.rand(100, 3, 256, 256).cuda() + im1_uint8 = (im1 * 255).to(torch.uint8) + im1_rounded = im1_uint8.float() / 255 + for max_diff in 10 ** torch.linspace(-5, 0, 6): + im2 = im1 + (torch.rand_like(im1) - 0.5) * 2 * max_diff + im2 = im2.clamp(0.0, 1.0) + im2_uint8 = (im2 * 255).to(torch.uint8) + im2_rounded = im2_uint8.float() / 255 + # check that our psnr matches the output of opencv + psnr = calc_psnr(im1_rounded, im2_rounded) + # some versions of cv2 can only take uint8 input + psnr_cv2 = cv2.PSNR( + im1_uint8.cpu().numpy(), + im2_uint8.cpu().numpy(), + ) + self.assertAlmostEqual(float(psnr), float(psnr_cv2), delta=1e-4) + # check that all PSNRs are bigger than the minimum possible PSNR + max_mse = max_diff ** 2 + min_psnr = 10 * math.log10(1.0 / max_mse) + for _im1, _im2 in zip(im1, im2): + _psnr = calc_psnr(_im1, _im2) + self.assertGreaterEqual(float(_psnr) + 1e-6, min_psnr) + + def _one_sequence_test( + self, + seq_dataset, + n_batches=2, + min_batch_size=5, + max_batch_size=10, + ): + # form a list of random batches + batch_indices = [] + for _ in range(n_batches): + batch_size = torch.randint( + low=min_batch_size, high=max_batch_size, size=(1,) + ) + batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size]) + + loader = torch.utils.data.DataLoader( + seq_dataset, + # batch_size=1, + shuffle=False, + batch_sampler=batch_indices, + collate_fn=FrameData.collate, + ) + + model = ModelDBIR(image_size=self.image_size, bg_color=self.bg_color) + model.cuda() + self.lpips_model.cuda() + + for frame_data in loader: + self.assertIsNone(frame_data.frame_type) + self.assertIsNotNone(frame_data.image_rgb) + # override the frame_type + frame_data.frame_type = [ + "train_unseen", + *(["train_known"] * (len(frame_data.image_rgb) - 1)), + ] + + # move frame_data to gpu + frame_data = dataclass_to_cuda_(frame_data) + preds = model(**dataclasses.asdict(frame_data)) + + nvs_prediction = copy.deepcopy(preds["nvs_prediction"]) + eval_result = eval_batch( + frame_data, + nvs_prediction, + bg_color=self.bg_color, + lpips_model=self.lpips_model, + ) + + # Make a terribly bad NVS prediction and check that this is worse + # than the DBIR prediction. + nvs_prediction_bad = copy.deepcopy(preds["nvs_prediction"]) + nvs_prediction_bad.depth_render += ( + torch.randn_like(nvs_prediction.depth_render) * 100.0 + ) + nvs_prediction_bad.image_render += ( + torch.randn_like(nvs_prediction.image_render) * 100.0 + ) + nvs_prediction_bad.mask_render = ( + torch.randn_like(nvs_prediction.mask_render) > 0.0 + ).float() + eval_result_bad = eval_batch( + frame_data, + nvs_prediction_bad, + bg_color=self.bg_color, + lpips_model=self.lpips_model, + ) + + lower_better = { + "psnr": False, + "psnr_fg": False, + "depth_abs_fg": True, + "iou": False, + "rgb_l1": True, + "rgb_l1_fg": True, + } + + for metric in lower_better.keys(): + m_better = eval_result[metric] + m_worse = eval_result_bad[metric] + if m_better != m_better or m_worse != m_worse: + continue # metric is missing, i.e. NaN + _assert = ( + self.assertLessEqual + if lower_better[metric] + else self.assertGreaterEqual + ) + _assert(m_better, m_worse) + + def test_full_eval(self, n_sequences=5): + """Test evaluation.""" + for _, idx in list(self.dataset.seq_to_idx.items())[:n_sequences]: + seq_dataset = torch.utils.data.Subset(self.dataset, idx) + self._one_sequence_test(seq_dataset) diff --git a/tests/implicitron/test_forward_pass.py b/tests/implicitron/test_forward_pass.py new file mode 100644 index 00000000..eda733d6 --- /dev/null +++ b/tests/implicitron/test_forward_pass.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from pytorch3d.implicitron.models.base import GenericModel +from pytorch3d.implicitron.models.renderer.base import EvaluationMode +from pytorch3d.implicitron.tools.config import expand_args_fields +from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform + + +class TestGenericModel(unittest.TestCase): + def test_gm(self): + # Simple test of a forward pass of the default GenericModel. + device = torch.device("cuda:1") + expand_args_fields(GenericModel) + model = GenericModel() + model.to(device) + + n_train_cameras = 2 + R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360) + cameras = PerspectiveCameras(R=R, T=T, device=device) + + # TODO: make these default to None? + defaulted_args = { + "fg_probability": None, + "depth_map": None, + "mask_crop": None, + "sequence_name": None, + } + + with self.assertWarnsRegex(UserWarning, "No main objective found"): + model( + camera=cameras, + evaluation_mode=EvaluationMode.TRAINING, + **defaulted_args, + image_rgb=None, + ) + target_image_rgb = torch.rand( + (n_train_cameras, 3, model.render_image_height, model.render_image_width), + device=device, + ) + train_preds = model( + camera=cameras, + evaluation_mode=EvaluationMode.TRAINING, + image_rgb=target_image_rgb, + **defaulted_args, + ) + self.assertGreater(train_preds["objective"].item(), 0) + + model.eval() + with torch.no_grad(): + # TODO: perhaps this warning should be skipped in eval mode? + with self.assertWarnsRegex(UserWarning, "No main objective found"): + eval_preds = model( + camera=cameras[0], + **defaulted_args, + image_rgb=None, + ) + self.assertEqual( + eval_preds["images_render"].shape, + (1, 3, model.render_image_height, model.render_image_width), + ) diff --git a/tests/implicitron/test_ray_point_refiner.py b/tests/implicitron/test_ray_point_refiner.py new file mode 100644 index 00000000..cc16197e --- /dev/null +++ b/tests/implicitron/test_ray_point_refiner.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import unittest + +import torch +from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner +from pytorch3d.renderer import RayBundle + + +if os.environ.get("FB_TEST", False): + from common_testing import TestCaseMixin +else: + from tests.common_testing import TestCaseMixin + + +class TestRayPointRefiner(TestCaseMixin, unittest.TestCase): + def test_simple(self): + length = 15 + n_pts_per_ray = 10 + + for add_input_samples in [False, True]: + ray_point_refiner = RayPointRefiner( + n_pts_per_ray=n_pts_per_ray, + random_sampling=False, + add_input_samples=add_input_samples, + ) + lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length) + bundle = RayBundle(lengths=lengths, origins=None, directions=None, xys=None) + weights = torch.ones(3, 25, length) + refined = ray_point_refiner(bundle, weights) + + self.assertIsNone(refined.directions) + self.assertIsNone(refined.origins) + self.assertIsNone(refined.xys) + expected = torch.linspace(0.5, length - 1.5, n_pts_per_ray) + expected = expected.expand(3, 25, n_pts_per_ray) + if add_input_samples: + full_expected = torch.cat((lengths, expected), dim=-1).sort()[0] + else: + full_expected = expected + self.assertClose(refined.lengths, full_expected) + + ray_point_refiner_random = RayPointRefiner( + n_pts_per_ray=n_pts_per_ray, + random_sampling=True, + add_input_samples=add_input_samples, + ) + refined_random = ray_point_refiner_random(bundle, weights) + lengths_random = refined_random.lengths + self.assertEqual(lengths_random.shape, full_expected.shape) + if not add_input_samples: + self.assertGreater(lengths_random.min().item(), 0.5) + self.assertLess(lengths_random.max().item(), length - 1.5) + + # Check sorted + self.assertTrue( + (lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all() + ) diff --git a/tests/implicitron/test_srn.py b/tests/implicitron/test_srn.py new file mode 100644 index 00000000..062d549c --- /dev/null +++ b/tests/implicitron/test_srn.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import unittest + +import torch +from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import ( + SRNHyperNetImplicitFunction, + SRNImplicitFunction, + SRNPixelGenerator, +) +from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper +from pytorch3d.implicitron.tools.config import get_default_args +from pytorch3d.renderer import RayBundle + + +if os.environ.get("FB_TEST", False): + from common_testing import TestCaseMixin +else: + from tests.common_testing import TestCaseMixin + +_BATCH_SIZE: int = 3 +_N_RAYS: int = 100 +_N_POINTS_ON_RAY: int = 10 + + +class TestSRN(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(42) + get_default_args(SRNHyperNetImplicitFunction) + get_default_args(SRNImplicitFunction) + + def test_pixel_generator(self): + SRNPixelGenerator() + + def _get_bundle(self, *, device) -> RayBundle: + origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device) + directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device) + lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device) + bundle = RayBundle( + lengths=lengths, origins=origins, directions=directions, xys=None + ) + return bundle + + def test_srn_implicit_function(self): + implicit_function = SRNImplicitFunction() + device = torch.device("cpu") + bundle = self._get_bundle(device=device) + rays_densities, rays_colors = implicit_function(bundle) + out_features = implicit_function.raymarch_function.out_features + self.assertEqual( + rays_densities.shape, + (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features), + ) + self.assertIsNone(rays_colors) + + def test_srn_hypernet_implicit_function(self): + # TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core? + latent_dim_hypernet = 39 + hypernet_args = {"latent_dim_hypernet": latent_dim_hypernet} + device = torch.device("cuda:0") + implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hypernet_args) + implicit_function.to(device) + global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device) + bundle = self._get_bundle(device=device) + rays_densities, rays_colors = implicit_function(bundle, global_code=global_code) + out_features = implicit_function.hypernet.out_features + self.assertEqual( + rays_densities.shape, + (_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features), + ) + self.assertIsNone(rays_colors) + + def test_srn_hypernet_implicit_function_optim(self): + # Test optimization loop, requiring that the cache is properly + # cleared in new_args_bound + latent_dim_hypernet = 39 + hyper_args = {"latent_dim_hypernet": latent_dim_hypernet} + device = torch.device("cuda:0") + global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device) + bundle = self._get_bundle(device=device) + + implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hyper_args) + implicit_function2 = SRNHyperNetImplicitFunction(hypernet_args=hyper_args) + implicit_function.to(device) + implicit_function2.to(device) + + wrapper = ImplicitFunctionWrapper(implicit_function) + optimizer = torch.optim.Adam(implicit_function.parameters()) + for _step in range(3): + optimizer.zero_grad() + wrapper.bind_args(global_code=global_code) + rays_densities, _rays_colors = wrapper(bundle) + wrapper.unbind_args() + loss = rays_densities.sum() + loss.backward() + optimizer.step() + + wrapper2 = ImplicitFunctionWrapper(implicit_function) + optimizer2 = torch.optim.Adam(implicit_function2.parameters()) + implicit_function2.load_state_dict(implicit_function.state_dict()) + optimizer2.load_state_dict(optimizer.state_dict()) + for _step in range(3): + optimizer2.zero_grad() + wrapper2.bind_args(global_code=global_code) + rays_densities, _rays_colors = wrapper2(bundle) + wrapper2.unbind_args() + loss = rays_densities.sum() + loss.backward() + optimizer2.step() diff --git a/tests/implicitron/test_types.py b/tests/implicitron/test_types.py new file mode 100644 index 00000000..91338edc --- /dev/null +++ b/tests/implicitron/test_types.py @@ -0,0 +1,93 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import dataclasses +import unittest +from typing import Dict, List, NamedTuple, Tuple + +from pytorch3d.implicitron.dataset import types +from pytorch3d.implicitron.dataset.types import FrameAnnotation + + +class _NT(NamedTuple): + annot: FrameAnnotation + + +class TestDatasetTypes(unittest.TestCase): + def setUp(self): + self.entry = FrameAnnotation( + frame_number=23, + sequence_name="1", + frame_timestamp=1.2, + image=types.ImageAnnotation(path="/tmp/1.jpg", size=(224, 224)), + mask=types.MaskAnnotation(path="/tmp/1.png", mass=42.0), + viewpoint=types.ViewpointAnnotation( + R=( + (1, 0, 0), + (1, 0, 0), + (1, 0, 0), + ), + T=(0, 0, 0), + principal_point=(100, 100), + focal_length=(200, 200), + ), + ) + + def test_asdict_rec(self): + first = [dataclasses.asdict(self.entry)] + second = types._asdict_rec([self.entry]) + self.assertEqual(first, second) + + def test_parsing(self): + """Test that we handle collections enclosing dataclasses.""" + + dct = dataclasses.asdict(self.entry) + + parsed = types._dataclass_from_dict(dct, FrameAnnotation) + self.assertEqual(parsed, self.entry) + + # namedtuple + parsed = types._dataclass_from_dict(_NT(dct), _NT) + self.assertEqual(parsed.annot, self.entry) + + # tuple + parsed = types._dataclass_from_dict((dct,), Tuple[FrameAnnotation]) + self.assertEqual(parsed, (self.entry,)) + + # list + parsed = types._dataclass_from_dict( + [ + dct, + ], + List[FrameAnnotation], + ) + self.assertEqual( + parsed, + [ + self.entry, + ], + ) + + # dict + parsed = types._dataclass_from_dict({"k": dct}, Dict[str, FrameAnnotation]) + self.assertEqual(parsed, {"k": self.entry}) + + def test_parsing_vectorized(self): + dct = dataclasses.asdict(self.entry) + + self._compare_with_scalar(dct, FrameAnnotation) + self._compare_with_scalar(_NT(dct), _NT) + self._compare_with_scalar((dct,), Tuple[FrameAnnotation]) + self._compare_with_scalar([dct], List[FrameAnnotation]) + self._compare_with_scalar({"k": dct}, Dict[str, FrameAnnotation]) + + def _compare_with_scalar(self, obj, typeannot, repeat=3): + input = [obj] * 3 + vect_output = types._dataclass_list_from_dict_list(input, typeannot) + self.assertEqual(len(input), repeat) + gt = types._dataclass_from_dict(obj, typeannot) + self.assertTrue(all(res == gt for res in vect_output)) diff --git a/tests/implicitron/test_viewsampling.py b/tests/implicitron/test_viewsampling.py new file mode 100644 index 00000000..dd438eb9 --- /dev/null +++ b/tests/implicitron/test_viewsampling.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import unittest + +import pytorch3d as pt3d +import torch +from pytorch3d.implicitron.models.view_pooling.view_sampling import ViewSampler +from pytorch3d.implicitron.tools.config import expand_args_fields + + +class TestViewsampling(unittest.TestCase): + def setUp(self): + torch.manual_seed(42) + expand_args_fields(ViewSampler) + + def _init_view_sampler_problem(self, random_masks): + """ + Generates a view-sampling problem: + - 4 source views, 1st/2nd from the first sequence 'seq1', the rest from 'seq2' + - 3 sets of 3D points from sequences 'seq1', 'seq2', 'seq2' respectively. + - first 50 points in each batch correctly project to the source views, + while the remaining 50 do not land in any projection plane. + - each source view is labeled with image feature tensors of shape 7x100x50, + where all elements of the n-th tensor are set to `n+1`. + - the elements of the source view masks are either set to random binary number + (if `random_masks==True`), or all set to 1 (`random_masks==False`). + - the source view cameras are uniformly distributed on a unit circle + in the x-z plane and look at (0,0,0). + """ + seq_id_camera = ["seq1", "seq1", "seq2", "seq2"] + seq_id_pts = ["seq1", "seq2", "seq2"] + pts_batch = 3 + n_pts = 100 + n_views = 4 + fdim = 7 + H = 100 + W = 50 + + # points that land into the projection planes of all cameras + pts_inside = ( + torch.nn.functional.normalize( + torch.randn(pts_batch, n_pts // 2, 3, device="cuda"), + dim=-1, + ) + * 0.1 + ) + + # move the outside points far above the scene + pts_outside = pts_inside.clone() + pts_outside[:, :, 1] += 1e8 + pts = torch.cat([pts_inside, pts_outside], dim=1) + + R, T = pt3d.renderer.look_at_view_transform( + dist=1.0, + elev=0.0, + azim=torch.linspace(0, 360, n_views + 1)[:n_views], + degrees=True, + device=pts.device, + ) + focal_length = R.new_ones(n_views, 2) + principal_point = R.new_zeros(n_views, 2) + camera = pt3d.renderer.PerspectiveCameras( + R=R, + T=T, + focal_length=focal_length, + principal_point=principal_point, + device=pts.device, + ) + + feats_map = torch.arange(n_views, device=pts.device, dtype=pts.dtype) + 1 + feats = {"feats": feats_map[:, None, None, None].repeat(1, fdim, H, W)} + + masks = ( + torch.rand(n_views, 1, H, W, device=pts.device, dtype=pts.dtype) > 0.5 + ).type_as(R) + + if not random_masks: + masks[:] = 1.0 + + return pts, camera, feats, masks, seq_id_camera, seq_id_pts + + def test_compare_with_naive(self): + """ + Compares the outputs of the efficient ViewSampler module with a + naive implementation. + """ + + ( + pts, + camera, + feats, + masks, + seq_id_camera, + seq_id_pts, + ) = self._init_view_sampler_problem(True) + + for masked_sampling in (True, False): + feats_sampled_n, masks_sampled_n = _view_sample_naive( + pts, + seq_id_pts, + camera, + seq_id_camera, + feats, + masks, + masked_sampling, + ) + # make sure we generate the constructor for ViewSampler + expand_args_fields(ViewSampler) + view_sampler = ViewSampler(masked_sampling=masked_sampling) + feats_sampled, masks_sampled = view_sampler( + pts=pts, + seq_id_pts=seq_id_pts, + camera=camera, + seq_id_camera=seq_id_camera, + feats=feats, + masks=masks, + ) + for k in feats_sampled.keys(): + self.assertTrue(torch.allclose(feats_sampled[k], feats_sampled_n[k])) + self.assertTrue(torch.allclose(masks_sampled, masks_sampled_n)) + + def test_viewsampling(self): + """ + Generates a viewsampling problem with predictable outcome, and compares + the ViewSampler's output to the expected result. + """ + + ( + pts, + camera, + feats, + masks, + seq_id_camera, + seq_id_pts, + ) = self._init_view_sampler_problem(False) + + expand_args_fields(ViewSampler) + + for masked_sampling in (True, False): + + view_sampler = ViewSampler(masked_sampling=masked_sampling) + + feats_sampled, masks_sampled = view_sampler( + pts=pts, + seq_id_pts=seq_id_pts, + camera=camera, + seq_id_camera=seq_id_camera, + feats=feats, + masks=masks, + ) + + n_views = camera.R.shape[0] + n_pts = pts.shape[1] + feat_dim = feats["feats"].shape[1] + pts_batch = pts.shape[0] + n_pts_away = n_pts // 2 + + for pts_i in range(pts_batch): + for view_i in range(n_views): + if seq_id_pts[pts_i] != seq_id_camera[view_i]: + # points / cameras come from different sequences + gt_masks = pts.new_zeros(n_pts, 1) + gt_feats = pts.new_zeros(n_pts, feat_dim) + else: + gt_masks = pts.new_ones(n_pts, 1) + gt_feats = pts.new_ones(n_pts, feat_dim) * (view_i + 1) + gt_feats[n_pts_away:] = 0.0 + if masked_sampling: + gt_masks[n_pts_away:] = 0.0 + + for k in feats_sampled: + self.assertTrue( + torch.allclose( + feats_sampled[k][pts_i, view_i], + gt_feats, + ) + ) + self.assertTrue( + torch.allclose( + masks_sampled[pts_i, view_i], + gt_masks, + ) + ) + + +def _view_sample_naive( + pts, + seq_id_pts, + camera, + seq_id_camera, + feats, + masks, + masked_sampling, +): + """ + A naive implementation of the forward pass of ViewSampler. + Refer to ViewSampler's docstring for description of the arguments. + """ + + pts_batch = pts.shape[0] + n_views = camera.R.shape[0] + n_pts = pts.shape[1] + + feats_sampled = [[[] for _ in range(n_views)] for _ in range(pts_batch)] + masks_sampled = [[[] for _ in range(n_views)] for _ in range(pts_batch)] + + for pts_i in range(pts_batch): + for view_i in range(n_views): + if seq_id_pts[pts_i] != seq_id_camera[view_i]: + # points/cameras come from different sequences + feats_sampled_ = { + k: f.new_zeros(n_pts, f.shape[1]) for k, f in feats.items() + } + masks_sampled_ = masks.new_zeros(n_pts, 1) + else: + # same sequence of pts and cameras -> sample + feats_sampled_, masks_sampled_ = _sample_one_view_naive( + camera[view_i], + pts[pts_i], + {k: f[view_i] for k, f in feats.items()}, + masks[view_i], + masked_sampling, + sampling_mode="bilinear", + ) + feats_sampled[pts_i][view_i] = feats_sampled_ + masks_sampled[pts_i][view_i] = masks_sampled_ + + masks_sampled_cat = torch.stack([torch.stack(m) for m in masks_sampled]) + feats_sampled_cat = {} + for k in feats_sampled[0][0].keys(): + feats_sampled_cat[k] = torch.stack( + [torch.stack([f_[k] for f_ in f]) for f in feats_sampled] + ) + return feats_sampled_cat, masks_sampled_cat + + +def _sample_one_view_naive( + camera, + pts, + feats, + masks, + masked_sampling, + sampling_mode="bilinear", +): + """ + Sample a single source view. + """ + proj_ndc = camera.transform_points(pts[None])[None, ..., :-1] # 1 x 1 x n_pts x 2 + feats_sampled = { + k: pt3d.renderer.ndc_grid_sample(f[None], proj_ndc, mode=sampling_mode).permute( + 0, 3, 1, 2 + )[0, :, :, 0] + for k, f in feats.items() + } # n_pts x dim + if not masked_sampling: + n_pts = pts.shape[0] + masks_sampled = proj_ndc.new_ones(n_pts, 1) + else: + masks_sampled = pt3d.renderer.ndc_grid_sample( + masks[None], + proj_ndc, + mode=sampling_mode, + align_corners=False, + )[0, 0, 0, :][:, None] + return feats_sampled, masks_sampled