Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-11-04 09:52:11 +08:00)
loading llff and blender datasets

Summary: Copy code from NeRF for loading LLFF data and blender synthetic data, and create dataset objects for them.

Reviewed By: shapovalov
Differential Revision: D35581039
fbshipit-source-id: af7a6f3e9a42499700693381b5b147c991f57e5d

parent 7978ffd1e4
commit 65f667fd2e
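For orientation, here is a minimal sketch (not from the diff itself) of driving one of the new providers directly, outside the trainer. It assumes a local copy of a NeRF synthetic scene, and that SingleSceneDatasetMapProviderBase implements get_dataset_map() returning a DatasetMap with train/val/test members, as the config entries below suggest:

# Hedged usage sketch; "datasets/nerf_synthetic/lego" is a placeholder path.
from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
    BlenderDatasetMapProvider,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

expand_args_fields(BlenderDatasetMapProvider)  # turn the Configurable into a dataclass
provider = BlenderDatasetMapProvider(
    base_dir="datasets/nerf_synthetic/lego",
    object_name="lego",
)
dataset_map = provider.get_dataset_map()
print(len(dataset_map.train), len(dataset_map.val), len(dataset_map.test))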
					
LICENSE-3RD-PARTY
@@ -46,3 +46,26 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
+
+
+NeRF https://github.com/bmild/nerf/
+
+Copyright (c) 2020 bmild
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md (Implicitron)
@@ -5,7 +5,7 @@ Implicitron is a PyTorch3D-based framework for new-view synthesis via modeling t
 # License
 
 Implicitron is distributed as part of PyTorch3D under the [BSD license](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE).
-It includes code from [SRN](http://github.com/vsitzmann/scene-representation-networks) and [IDR](http://github.com/lioryariv/idr) repos.
+It includes code from the [NeRF](https://github.com/bmild/nerf), [SRN](http://github.com/vsitzmann/scene-representation-networks) and [IDR](http://github.com/lioryariv/idr) repos.
 See [LICENSE-3RD-PARTY](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE-3RD-PARTY) for their licenses.
 
 
experiment.py (Implicitron trainer)
@@ -315,7 +315,7 @@ def trainvalidate(
     epoch,
     loader,
     optimizer,
-    validation,
+    validation: bool,
     bp_var: str = "objective",
     metric_print_interval: int = 5,
     visualize_interval: int = 100,
experiment.yaml (Implicitron trainer default config)
@@ -95,13 +95,6 @@ generic_model_args:
     append_coarse_samples_to_fine: true
     density_noise_std_train: 0.0
     return_weights: false
-    raymarcher_EmissionAbsorptionRaymarcher_args:
-      surface_thickness: 1
-      bg_color:
-      - 0.0
-      background_opacity: 10000000000.0
-      density_relu: true
-      blend_output: false
     raymarcher_CumsumRaymarcher_args:
       surface_thickness: 1
       bg_color:
@@ -109,6 +102,13 @@ generic_model_args:
       background_opacity: 0.0
       density_relu: true
       blend_output: false
+    raymarcher_EmissionAbsorptionRaymarcher_args:
+      surface_thickness: 1
+      bg_color:
+      - 0.0
+      background_opacity: 10000000000.0
+      density_relu: true
+      blend_output: false
   renderer_SignedDistanceFunctionRenderer_args:
     render_features_dimensions: 3
     ray_tracer_args:
@@ -157,6 +157,21 @@ generic_model_args:
     view_sampler_args:
       masked_sampling: false
       sampling_mode: bilinear
+    feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
+      exclude_target_view: true
+      exclude_target_view_mask_features: true
+      concatenate_output: true
+      weight_by_ray_angle_gamma: 1.0
+      min_ray_angle_weight: 0.1
+    feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
+      exclude_target_view: true
+      exclude_target_view_mask_features: true
+      concatenate_output: true
+      reduction_functions:
+      - AVG
+      - STD
+      weight_by_ray_angle_gamma: 1.0
+      min_ray_angle_weight: 0.1
     feature_aggregator_IdentityFeatureAggregator_args:
       exclude_target_view: true
       exclude_target_view_mask_features: true
@@ -168,21 +183,6 @@ generic_model_args:
       reduction_functions:
       - AVG
       - STD
-    feature_aggregator_AngleWeightedReductionFeatureAggregator_args:
-      exclude_target_view: true
-      exclude_target_view_mask_features: true
-      concatenate_output: true
-      reduction_functions:
-      - AVG
-      - STD
-      weight_by_ray_angle_gamma: 1.0
-      min_ray_angle_weight: 0.1
-    feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
-      exclude_target_view: true
-      exclude_target_view_mask_features: true
-      concatenate_output: true
-      weight_by_ray_angle_gamma: 1.0
-      min_ray_angle_weight: 0.1
   implicit_function_IdrFeatureField_args:
     feature_vector_size: 3
     d_in: 3
@@ -203,19 +203,6 @@ generic_model_args:
     n_harmonic_functions_xyz: 0
     pooled_feature_dim: 0
     encoding_dim: 0
-  implicit_function_NeuralRadianceFieldImplicitFunction_args:
-    n_harmonic_functions_xyz: 10
-    n_harmonic_functions_dir: 4
-    n_hidden_neurons_dir: 128
-    latent_dim: 0
-    input_xyz: true
-    xyz_ray_dir_in_camera_coords: false
-    color_dim: 3
-    transformer_dim_down_factor: 1.0
-    n_hidden_neurons_xyz: 256
-    n_layers_xyz: 8
-    append_xyz:
-    - 5
   implicit_function_NeRFormerImplicitFunction_args:
     n_harmonic_functions_xyz: 10
     n_harmonic_functions_dir: 4
@@ -229,24 +216,19 @@ generic_model_args:
     n_layers_xyz: 2
     append_xyz:
     - 1
-  implicit_function_SRNImplicitFunction_args:
-    raymarch_function_args:
-      n_harmonic_functions: 3
-      n_hidden_units: 256
-      n_layers: 2
-      in_features: 3
-      out_features: 256
-      latent_dim: 0
-      xyz_in_camera_coords: false
-      raymarch_function: null
-    pixel_generator_args:
-      n_harmonic_functions: 4
-      n_hidden_units: 256
-      n_hidden_units_color: 128
-      n_layers: 2
-      in_features: 256
-      out_features: 3
-      ray_dir_in_camera_coords: false
+  implicit_function_NeuralRadianceFieldImplicitFunction_args:
+    n_harmonic_functions_xyz: 10
+    n_harmonic_functions_dir: 4
+    n_hidden_neurons_dir: 128
+    latent_dim: 0
+    input_xyz: true
+    xyz_ray_dir_in_camera_coords: false
+    color_dim: 3
+    transformer_dim_down_factor: 1.0
+    n_hidden_neurons_xyz: 256
+    n_layers_xyz: 8
+    append_xyz:
+    - 5
   implicit_function_SRNHyperNetImplicitFunction_args:
     hypernet_args:
       n_harmonic_functions: 3
@@ -267,6 +249,24 @@ generic_model_args:
       in_features: 256
       out_features: 3
       ray_dir_in_camera_coords: false
+  implicit_function_SRNImplicitFunction_args:
+    raymarch_function_args:
+      n_harmonic_functions: 3
+      n_hidden_units: 256
+      n_layers: 2
+      in_features: 3
+      out_features: 256
+      latent_dim: 0
+      xyz_in_camera_coords: false
+      raymarch_function: null
+    pixel_generator_args:
+      n_harmonic_functions: 4
+      n_hidden_units: 256
+      n_hidden_units_color: 128
+      n_layers: 2
+      in_features: 256
+      out_features: 3
+      ray_dir_in_camera_coords: false
 solver_args:
   breed: adam
   weight_decay: 0.0
@@ -282,6 +282,13 @@ solver_args:
 data_source_args:
   dataset_map_provider_class_type: ???
   data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
+  dataset_map_provider_BlenderDatasetMapProvider_args:
+    base_dir: ???
+    object_name: ???
+    path_manager_factory_class_type: PathManagerFactory
+    n_known_frames_for_test: null
+    path_manager_factory_PathManagerFactory_args:
+      silence_logs: true
   dataset_map_provider_JsonIndexDatasetMapProvider_args:
     category: ???
     task_str: singlesequence
@@ -317,6 +324,13 @@ data_source_args:
       sort_frames: false
     path_manager_factory_PathManagerFactory_args:
       silence_logs: true
+  dataset_map_provider_LlffDatasetMapProvider_args:
+    base_dir: ???
+    object_name: ???
+    path_manager_factory_class_type: PathManagerFactory
+    n_known_frames_for_test: null
+    path_manager_factory_PathManagerFactory_args:
+      silence_logs: true
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 1
     num_workers: 0
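The ??? entries above are OmegaConf's marker for a mandatory value: the config loads, but the field must be overridden before it is read. A small illustration (plain omegaconf, independent of this commit):

from omegaconf import OmegaConf
from omegaconf.errors import MissingMandatoryValue

cfg = OmegaConf.create({"base_dir": "???", "object_name": "???"})
cfg.object_name = "fern"        # supply a value before use
print(cfg.object_name)          # -> fern
try:
    cfg.base_dir                # still missing
except MissingMandatoryValue:
    print("base_dir must be set")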
pytorch3d/implicitron/dataset/blender_dataset_map_provider.py (new file)
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
from pytorch3d.implicitron.tools.config import registry

from .load_blender import load_blender_data
from .single_sequence_dataset import (
    _interpret_blender_cameras,
    SingleSceneDatasetMapProviderBase,
)


@registry.register
class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
    """
    Provides data for one scene from Blender synthetic dataset.
    Uses the code in load_blender.py

    Members:
        base_dir: directory holding the data for the scene.
        object_name: The name of the scene (e.g. "lego"). This is just used as a label.
            It will typically be equal to the name of the directory self.base_dir.
        path_manager_factory: Creates path manager which may be used for
            interpreting paths.
        n_known_frames_for_test: If set, training frames are included in the val
            and test datasets, and this many random training frames are added to
            each test batch. If not set, test batches each contain just a single
            testing frame.
    """

    def _load_data(self) -> None:
        path_manager = self.path_manager_factory.get()
        images, poses, _, hwf, i_split = load_blender_data(
            self.base_dir,
            testskip=1,
            path_manager=path_manager,
        )
        H, W, focal = hwf
        H, W = int(H), int(W)
        images = torch.from_numpy(images)

        # pyre-ignore[16]
        self.poses = _interpret_blender_cameras(poses, H, W, focal)
        # pyre-ignore[16]
        self.images = images
        # pyre-ignore[16]
        self.i_split = i_split
pytorch3d/implicitron/dataset/data_source.py
@@ -8,9 +8,11 @@ from typing import Tuple
 
 from pytorch3d.implicitron.tools.config import ReplaceableBase, run_auto_creation
 
-from . import json_index_dataset_map_provider  # noqa
+from .blender_dataset_map_provider import BlenderDatasetMapProvider  # noqa
 from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
 from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
+from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider  # noqa
+from .llff_dataset_map_provider import LlffDatasetMapProvider  # noqa
 
 
 class DataSourceBase(ReplaceableBase):
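The # noqa imports exist for their side effects: importing each module executes its @registry.register decorator, which is what makes "BlenderDatasetMapProvider" and "LlffDatasetMapProvider" legal values of dataset_map_provider_class_type. A sketch of the lookup this enables (assuming registry.get, as provided by implicitron's config tooling):

from pytorch3d.implicitron.tools.config import registry
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMapProviderBase
import pytorch3d.implicitron.dataset.data_source  # noqa: F401  (runs the registrations)

ProviderClass = registry.get(DatasetMapProviderBase, "LlffDatasetMapProvider")
print(ProviderClass.__name__)  # -> LlffDatasetMapProvider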
pytorch3d/implicitron/dataset/dataset_base.py
@@ -36,10 +36,11 @@ class FrameData(Mapping[str, Any]):
     Args:
         frame_number: The number of the frame within its sequence.
             0-based continuous integers.
-        frame_timestamp: The time elapsed since the start of a sequence in sec.
         sequence_name: The unique name of the frame's sequence.
         sequence_category: The object category of the sequence.
-        image_size_hw: The size of the image in pixels; (height, width) tuple.
+        frame_timestamp: The time elapsed since the start of a sequence in sec.
+        image_size_hw: The size of the image in pixels; (height, width) tensor
+                        of shape (2,).
         image_path: The qualified path to the loaded image (with dataset_root).
         image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
             of the frame; elements are floats in [0, 1].
@@ -81,9 +82,9 @@ class FrameData(Mapping[str, Any]):
     """
 
     frame_number: Optional[torch.LongTensor]
-    frame_timestamp: Optional[torch.Tensor]
     sequence_name: Union[str, List[str]]
     sequence_category: Union[str, List[str]]
+    frame_timestamp: Optional[torch.Tensor] = None
     image_size_hw: Optional[torch.Tensor] = None
     image_path: Union[str, List[str], None] = None
     image_rgb: Optional[torch.Tensor] = None
@@ -101,7 +102,7 @@ class FrameData(Mapping[str, Any]):
     sequence_point_cloud_path: Union[str, List[str], None] = None
     sequence_point_cloud: Optional[Pointclouds] = None
     sequence_point_cloud_idx: Optional[torch.Tensor] = None
-    frame_type: Union[str, List[str], None] = None  # seen | unseen
+    frame_type: Union[str, List[str], None] = None  # known | unseen
     meta: dict = field(default_factory=lambda: {})
 
     def to(self, *args, **kwargs):
pytorch3d/implicitron/dataset/llff_dataset_map_provider.py (new file)
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import numpy as np
import torch
from pytorch3d.implicitron.tools.config import registry

from .load_llff import load_llff_data

from .single_sequence_dataset import (
    _interpret_blender_cameras,
    SingleSceneDatasetMapProviderBase,
)


@registry.register
class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
    """
    Provides data for one scene from the LLFF dataset.

    Members:
        base_dir: directory holding the data for the scene.
        object_name: The name of the scene (e.g. "fern"). This is just used as a label.
            It will typically be equal to the name of the directory self.base_dir.
        path_manager_factory: Creates path manager which may be used for
            interpreting paths.
        n_known_frames_for_test: If set, training frames are included in the val
            and test datasets, and this many random training frames are added to
            each test batch. If not set, test batches each contain just a single
            testing frame.
    """

    def _load_data(self) -> None:
        path_manager = self.path_manager_factory.get()
        images, poses, _ = load_llff_data(
            self.base_dir, factor=8, path_manager=path_manager
        )
        hwf = poses[0, :3, -1]
        poses = poses[:, :3, :4]

        i_test = np.arange(images.shape[0])[::8]
        i_test_index = set(i_test.tolist())
        i_train = np.array(
            [i for i in np.arange(images.shape[0]) if i not in i_test_index]
        )
        i_split = (i_train, i_test, i_test)
        H, W, focal = hwf
        H, W = int(H), int(W)
        images = torch.from_numpy(images)
        poses = torch.from_numpy(poses)

        # pyre-ignore[16]
        self.poses = _interpret_blender_cameras(poses, H, W, focal)
        # pyre-ignore[16]
        self.images = images
        # pyre-ignore[16]
        self.i_split = i_split
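The split in _load_data follows the usual LLFF evaluation protocol: every 8th frame is held out, and the held-out frames serve as both val and test. Worked through for a 20-frame scene such as "fern":

import numpy as np

n_frames = 20
i_test = np.arange(n_frames)[::8]          # array([ 0,  8, 16])
i_test_index = set(i_test.tolist())
i_train = np.array([i for i in range(n_frames) if i not in i_test_index])
i_split = (i_train, i_test, i_test)        # (train, val, test) frame indices
print(len(i_train), len(i_test))           # 17 3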
							
								
								
									
pytorch3d/implicitron/dataset/load_blender.py (new file)
@@ -0,0 +1,131 @@
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/bmild/nerf/blob/master/load_blender.py
# Copyright (c) 2020 bmild
import json
import os

import numpy as np
import torch
from PIL import Image


def translate_by_t_along_z(t):
    tform = np.eye(4).astype(np.float32)
    tform[2][3] = t
    return tform


def rotate_by_phi_along_x(phi):
    tform = np.eye(4).astype(np.float32)
    tform[1, 1] = tform[2, 2] = np.cos(phi)
    tform[1, 2] = -np.sin(phi)
    tform[2, 1] = -tform[1, 2]
    return tform


def rotate_by_theta_along_y(theta):
    tform = np.eye(4).astype(np.float32)
    tform[0, 0] = tform[2, 2] = np.cos(theta)
    tform[0, 2] = -np.sin(theta)
    tform[2, 0] = -tform[0, 2]
    return tform


def pose_spherical(theta, phi, radius):
    c2w = translate_by_t_along_z(radius)
    c2w = rotate_by_phi_along_x(phi / 180.0 * np.pi) @ c2w
    c2w = rotate_by_theta_along_y(theta / 180 * np.pi) @ c2w
    c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
    return c2w


def _local_path(path_manager, path):
    if path_manager is None:
        return path
    return path_manager.get_local_path(path)


def load_blender_data(
    basedir, half_res=False, testskip=1, debug=False, path_manager=None
):
    splits = ["train", "val", "test"]
    metas = {}
    for s in splits:
        path = os.path.join(basedir, f"transforms_{s}.json")
        with open(_local_path(path_manager, path)) as fp:
            metas[s] = json.load(fp)

    all_imgs = []
    all_poses = []
    counts = [0]
    for s in splits:
        meta = metas[s]
        imgs = []
        poses = []
        if s == "train" or testskip == 0:
            skip = 1
        else:
            skip = testskip

        for frame in meta["frames"][::skip]:
            fname = os.path.join(basedir, frame["file_path"] + ".png")
            imgs.append(np.array(Image.open(_local_path(path_manager, fname))))
            poses.append(np.array(frame["transform_matrix"]))
        imgs = (np.array(imgs) / 255.0).astype(np.float32)
        poses = np.array(poses).astype(np.float32)
        counts.append(counts[-1] + imgs.shape[0])
        all_imgs.append(imgs)
        all_poses.append(poses)

    i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)]

    imgs = np.concatenate(all_imgs, 0)
    poses = np.concatenate(all_poses, 0)

    H, W = imgs[0].shape[:2]
    camera_angle_x = float(meta["camera_angle_x"])
    focal = 0.5 * W / np.tan(0.5 * camera_angle_x)

    render_poses = torch.stack(
        [
            torch.from_numpy(pose_spherical(angle, -30.0, 4.0))
            for angle in np.linspace(-180, 180, 40 + 1)[:-1]
        ],
        0,
    )

    # In debug mode, return extremely tiny images
    if debug:
        import cv2

        H = H // 32
        W = W // 32
        focal = focal / 32.0
        imgs = [
            torch.from_numpy(
                cv2.resize(imgs[i], dsize=(25, 25), interpolation=cv2.INTER_AREA)
            )
            for i in range(imgs.shape[0])
        ]
        imgs = torch.stack(imgs, 0)
        poses = torch.from_numpy(poses)
        return imgs, poses, render_poses, [H, W, focal], i_split

    if half_res:
        import cv2

        # TODO: resize images using INTER_AREA (cv2)
        H = H // 2
        W = W // 2
        focal = focal / 2.0
        imgs = [
            torch.from_numpy(
                cv2.resize(imgs[i], dsize=(400, 400), interpolation=cv2.INTER_AREA)
            )
            for i in range(imgs.shape[0])
        ]
        imgs = torch.stack(imgs, 0)

    poses = torch.from_numpy(poses)

    return imgs, poses, render_poses, [H, W, focal], i_split
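The focal length comes from the pinhole relation focal = 0.5 * W / tan(0.5 * camera_angle_x), converting the horizontal field of view stored in the JSON into pixels. For the standard 800x800 synthetic scenes this lands near 1111 px:

import numpy as np

W = 800
camera_angle_x = 0.6911112070083618  # value found in the synthetic scenes' transforms_*.json
focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
print(focal)  # ~1111.11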
							
								
								
									
pytorch3d/implicitron/dataset/load_llff.py (new file)
@@ -0,0 +1,343 @@
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/bmild/nerf/blob/master/load_llff.py
# Copyright (c) 2020 bmild
import logging
import os
import warnings

import numpy as np

from PIL import Image


# Slightly modified version of LLFF data loading code
#  see https://github.com/Fyusion/LLFF for original

logger = logging.getLogger(__name__)


def _minify(basedir, path_manager, factors=(), resolutions=()):
    needtoload = False
    for r in factors:
        imgdir = os.path.join(basedir, "images_{}".format(r))
        if not _exists(path_manager, imgdir):
            needtoload = True
    for r in resolutions:
        imgdir = os.path.join(basedir, "images_{}x{}".format(r[1], r[0]))
        if not _exists(path_manager, imgdir):
            needtoload = True
    if not needtoload:
        return
    assert path_manager is None

    from subprocess import check_output

    imgdir = os.path.join(basedir, "images")
    imgs = [os.path.join(imgdir, f) for f in sorted(_ls(path_manager, imgdir))]
    imgs = [
        f
        for f in imgs
        if any([f.endswith(ex) for ex in ["JPG", "jpg", "png", "jpeg", "PNG"]])
    ]
    imgdir_orig = imgdir

    wd = os.getcwd()

    for r in factors + resolutions:
        if isinstance(r, int):
            name = "images_{}".format(r)
            resizearg = "{}%".format(100.0 / r)
        else:
            name = "images_{}x{}".format(r[1], r[0])
            resizearg = "{}x{}".format(r[1], r[0])
        imgdir = os.path.join(basedir, name)
        if os.path.exists(imgdir):
            continue

        logger.info(f"Minifying {r}, {basedir}")

        os.makedirs(imgdir)
        check_output("cp {}/* {}".format(imgdir_orig, imgdir), shell=True)

        ext = imgs[0].split(".")[-1]
        args = " ".join(
            ["mogrify", "-resize", resizearg, "-format", "png", "*.{}".format(ext)]
        )
        logger.info(args)
        os.chdir(imgdir)
        check_output(args, shell=True)
        os.chdir(wd)

        if ext != "png":
            check_output("rm {}/*.{}".format(imgdir, ext), shell=True)
            logger.info("Removed duplicates")
        logger.info("Done")


def _load_data(
    basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
):

    poses_arr = np.load(
        _local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
    )
    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
    bds = poses_arr[:, -2:].transpose([1, 0])

    img0 = [
        os.path.join(basedir, "images", f)
        for f in sorted(_ls(path_manager, os.path.join(basedir, "images")))
        if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
    ][0]

    def imread(f):
        return np.array(Image.open(f))

    sh = imread(_local_path(path_manager, img0)).shape

    sfx = ""

    if factor is not None:
        sfx = "_{}".format(factor)
        _minify(basedir, path_manager, factors=[factor])
        factor = factor
    elif height is not None:
        factor = sh[0] / float(height)
        width = int(sh[1] / factor)
        _minify(basedir, path_manager, resolutions=[[height, width]])
        sfx = "_{}x{}".format(width, height)
    elif width is not None:
        factor = sh[1] / float(width)
        height = int(sh[0] / factor)
        _minify(basedir, path_manager, resolutions=[[height, width]])
        sfx = "_{}x{}".format(width, height)
    else:
        factor = 1

    imgdir = os.path.join(basedir, "images" + sfx)
    if not _exists(path_manager, imgdir):
        raise ValueError(f"{imgdir} does not exist, returning")

    imgfiles = [
        _local_path(path_manager, os.path.join(imgdir, f))
        for f in sorted(_ls(path_manager, imgdir))
        if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
    ]
    if poses.shape[-1] != len(imgfiles):
        raise ValueError(
            "Mismatch between imgs {} and poses {} !!!!".format(
                len(imgfiles), poses.shape[-1]
            )
        )

    sh = imread(imgfiles[0]).shape
    poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
    poses[2, 4, :] = poses[2, 4, :] * 1.0 / factor

    if not load_imgs:
        return poses, bds

    imgs = [imread(f)[..., :3] / 255.0 for f in imgfiles]
    imgs = np.stack(imgs, -1)

    logger.info(f"Loaded image data, shape {imgs.shape}")
    return poses, bds, imgs


def normalize(x):
    denom = np.linalg.norm(x)
    if denom < 0.001:
        warnings.warn("unsafe normalize()")
    return x / denom


def viewmatrix(z, up, pos):
    vec2 = normalize(z)
    vec1_avg = up
    vec0 = normalize(np.cross(vec1_avg, vec2))
    vec1 = normalize(np.cross(vec2, vec0))
    m = np.stack([vec0, vec1, vec2, pos], 1)
    return m


def ptstocam(pts, c2w):
    tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0]
    return tt


def poses_avg(poses):

    hwf = poses[0, :3, -1:]

    center = poses[:, :3, 3].mean(0)
    vec2 = normalize(poses[:, :3, 2].sum(0))
    up = poses[:, :3, 1].sum(0)
    c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)

    return c2w


def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
    render_poses = []
    rads = np.array(list(rads) + [1.0])
    hwf = c2w[:, 4:5]

    for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]:
        c = np.dot(
            c2w[:3, :4],
            np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0])
            * rads,
        )
        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0])))
        render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
    return render_poses


def recenter_poses(poses):

    poses_ = poses + 0
    bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
    c2w = poses_avg(poses)
    c2w = np.concatenate([c2w[:3, :4], bottom], -2)
    bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
    poses = np.concatenate([poses[:, :3, :4], bottom], -2)

    poses = np.linalg.inv(c2w) @ poses
    poses_[:, :3, :4] = poses[:, :3, :4]
    poses = poses_
    return poses


def spherify_poses(poses, bds):
    def add_row_to_homogenize_transform(p):
        r"""Add the last row to homogenize 3 x 4 transformation matrices."""
        return np.concatenate(
            [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
        )

    # p34_to_44 = lambda p: np.concatenate(
    #     [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
    # )

    p34_to_44 = add_row_to_homogenize_transform

    rays_d = poses[:, :3, 2:3]
    rays_o = poses[:, :3, 3:4]

    def min_line_dist(rays_o, rays_d):
        A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
        b_i = -A_i @ rays_o
        pt_mindist = np.squeeze(
            -np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0)
        )
        return pt_mindist

    pt_mindist = min_line_dist(rays_o, rays_d)

    center = pt_mindist
    up = (poses[:, :3, 3] - center).mean(0)

    vec0 = normalize(up)
    vec1 = normalize(np.cross([0.1, 0.2, 0.3], vec0))
    vec2 = normalize(np.cross(vec0, vec1))
    pos = center
    c2w = np.stack([vec1, vec2, vec0, pos], 1)

    poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])

    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))

    sc = 1.0 / rad
    poses_reset[:, :3, 3] *= sc
    bds *= sc
    rad *= sc

    centroid = np.mean(poses_reset[:, :3, 3], 0)
    zh = centroid[2]
    radcircle = np.sqrt(rad**2 - zh**2)
    new_poses = []

    for th in np.linspace(0.0, 2.0 * np.pi, 120):

        camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
        up = np.array([0, 0, -1.0])

        vec2 = normalize(camorigin)
        vec0 = normalize(np.cross(vec2, up))
        vec1 = normalize(np.cross(vec2, vec0))
        pos = camorigin
        p = np.stack([vec0, vec1, vec2, pos], 1)

        new_poses.append(p)

    new_poses = np.stack(new_poses, 0)

    new_poses = np.concatenate(
        [new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1
    )
    poses_reset = np.concatenate(
        [
            poses_reset[:, :3, :4],
            np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape),
        ],
        -1,
    )

    return poses_reset, new_poses, bds


def _local_path(path_manager, path):
    if path_manager is None:
        return path
    return path_manager.get_local_path(path)


def _ls(path_manager, path):
    if path_manager is None:
        # os.listdir lists directory contents when no path manager is given
        return os.listdir(path)
    return path_manager.ls(path)


def _exists(path_manager, path):
    if path_manager is None:
        return os.path.exists(path)
    return path_manager.exists(path)


def load_llff_data(
    basedir,
    factor=8,
    recenter=True,
    bd_factor=0.75,
    spherify=False,
    path_zflat=False,
    path_manager=None,
):

    poses, bds, imgs = _load_data(
        basedir, factor=factor, path_manager=path_manager
    )  # factor=8 downsamples original imgs by 8x
    logger.info(f"Loaded {basedir}, {bds.min()}, {bds.max()}")

    # Correct rotation matrix ordering and move variable dim to axis 0
    poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
    imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
    images = imgs
    bds = np.moveaxis(bds, -1, 0).astype(np.float32)

    # Rescale if bd_factor is provided
    sc = 1.0 if bd_factor is None else 1.0 / (bds.min() * bd_factor)
    poses[:, :3, 3] *= sc
    bds *= sc

    if recenter:
        poses = recenter_poses(poses)

    if spherify:
        poses, render_poses, bds = spherify_poses(poses, bds)

    images = images.astype(np.float32)
    poses = poses.astype(np.float32)

    return images, poses, bds
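_load_data above decodes poses_bounds.npy, the file LLFF-style scenes ship with: one row per image holding a flattened 3x5 matrix (a 3x4 camera-to-world pose with a [height, width, focal] column appended) followed by two depth bounds. A self-contained sketch of the decode on synthetic values:

import numpy as np

# Fake poses_bounds.npy contents for 2 images: 15 pose/hwf numbers + near/far.
poses_arr = np.zeros((2, 17))
poses_arr[:, :15] = np.eye(3, 5).ravel()   # identity-ish pose, hwf column zeroed
poses_arr[:, 15:] = [0.5, 10.0]            # near and far depth bounds

poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])  # (3, 5, N)
bds = poses_arr[:, -2:].transpose([1, 0])                           # (2, N)
hwf = poses[:, 4, :]  # per-image [height, width, focal], as read by load_llff_data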
							
								
								
									
181  pytorch3d/implicitron/dataset/single_sequence_dataset.py  Normal file
@@ -0,0 +1,181 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


# This file defines a base class for dataset map providers which
# provide data for a single scene.

from dataclasses import field
from typing import Iterable, List, Optional

import numpy as np
import torch
from pytorch3d.implicitron.tools.config import (
    Configurable,
    expand_args_fields,
    run_auto_creation,
)
from pytorch3d.renderer import PerspectiveCameras

from .dataset_base import DatasetBase, FrameData
from .dataset_map_provider import (
    DatasetMap,
    DatasetMapProviderBase,
    PathManagerFactory,
    Task,
)
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN

_SINGLE_SEQUENCE_NAME: str = "one_sequence"


class SingleSceneDataset(DatasetBase, Configurable):
    """
    A dataset of images from a single scene.
    """

    images: List[torch.Tensor] = field()
    poses: List[PerspectiveCameras] = field()
    object_name: str = field()
    frame_types: List[str] = field()
    eval_batches: Optional[List[List[int]]] = field()

    def sequence_names(self) -> Iterable[str]:
        return [_SINGLE_SEQUENCE_NAME]

    def __len__(self) -> int:
        return len(self.poses)

    def __getitem__(self, index) -> FrameData:
        if index >= len(self):
            raise IndexError(f"index {index} out of range {len(self)}")
        image = self.images[index]
        pose = self.poses[index]
        frame_type = self.frame_types[index]

        frame_data = FrameData(
            frame_number=index,
            sequence_name=_SINGLE_SEQUENCE_NAME,
            sequence_category=self.object_name,
            camera=pose,
            image_size_hw=torch.tensor(image.shape[1:]),
            image_rgb=image,
            frame_type=frame_type,
        )
        return frame_data

    def get_eval_batches(self) -> Optional[List[List[int]]]:
        return self.eval_batches


# pyre-fixme[13]: Uninitialized attribute
class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
    """
    Base for providers of data for one scene from the LLFF or blender datasets.

    Members:
        base_dir: directory holding the data for the scene.
        object_name: The name of the scene (e.g. "lego"). This is just used as a label.
            It will typically be equal to the name of the directory self.base_dir.
        path_manager_factory: Creates a path manager which may be used for
            interpreting paths.
        n_known_frames_for_test: If set, training frames are included in the val
            and test datasets, and this many random training frames are added to
            each test batch. If not set, test batches each contain just a single
            testing frame.
    """

    base_dir: str
    object_name: str
    path_manager_factory: PathManagerFactory
    path_manager_factory_class_type: str = "PathManagerFactory"
    n_known_frames_for_test: Optional[int] = None

    def __post_init__(self) -> None:
        run_auto_creation(self)
        self._load_data()

    def _load_data(self) -> None:
        # This must be defined by each subclass,
        # and should set poses, images and i_split on self.
        raise NotImplementedError

    def _get_dataset(
        self, split_idx: int, frame_type: str, set_eval_batches: bool = False
    ) -> SingleSceneDataset:
        expand_args_fields(SingleSceneDataset)
        # pyre-ignore[16]
        split = self.i_split[split_idx]
        frame_types = [frame_type] * len(split)
        eval_batches = [[i] for i in range(len(split))]
        if split_idx != 0 and self.n_known_frames_for_test is not None:
            train_split = self.i_split[0]
            if set_eval_batches:
                generator = np.random.default_rng(seed=0)
                for batch in eval_batches:
                    to_add = generator.choice(
                        len(train_split), self.n_known_frames_for_test
                    )
                    batch.extend((to_add + len(split)).tolist())
            split = np.concatenate([split, train_split])
            frame_types.extend([DATASET_TYPE_KNOWN] * len(train_split))

        # pyre-ignore[28]
        return SingleSceneDataset(
            object_name=self.object_name,
            # pyre-ignore[16]
            images=self.images[split],
            # pyre-ignore[16]
            poses=[self.poses[i] for i in split],
            frame_types=frame_types,
            eval_batches=eval_batches if set_eval_batches else None,
        )

    def get_dataset_map(self) -> DatasetMap:
        return DatasetMap(
            train=self._get_dataset(0, DATASET_TYPE_KNOWN),
            val=self._get_dataset(1, DATASET_TYPE_UNKNOWN),
            test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
        )

    def get_task(self) -> Task:
        return Task.SINGLE_SEQUENCE


def _interpret_blender_cameras(
    poses: torch.Tensor, H: int, W: int, focal: float
) -> List[PerspectiveCameras]:
    """
    Convert 4x4 matrices representing cameras in blender format
    to PyTorch3D format.

    Args:
        poses: N x 3 x 4 camera matrices
    """
    pose_target_cameras = []
    for pose_target in poses:
        pose_target = pose_target[:3, :4]
        mtx = torch.eye(4, dtype=pose_target.dtype)
        mtx[:3, :3] = pose_target[:3, :3].t()
        mtx[3, :3] = pose_target[:, 3]
        mtx = mtx.inverse()

        # flip the XZ coordinates.
        mtx[:, [0, 2]] *= -1.0

        Rpt3, Tpt3 = mtx[:, :3].split([3, 1], dim=0)

        focal_length_pt3 = torch.FloatTensor([[-focal, focal]])
        principal_point_pt3 = torch.FloatTensor([[W / 2, H / 2]])

        cameras = PerspectiveCameras(
            focal_length=focal_length_pt3,
            principal_point=principal_point_pt3,
            R=Rpt3[None],
            T=Tpt3,
        )
        pose_target_cameras.append(cameras)
    return pose_target_cameras
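
To make the _load_data contract concrete (it "should set poses, images and i_split on self"), here is a hedged sketch of a minimal subclass; the class name, random tensors, and 3/3/3 split are invented for illustration and are not part of this commit.

import numpy as np
import torch
from pytorch3d.implicitron.dataset.single_sequence_dataset import (
    SingleSceneDatasetMapProviderBase,
)
from pytorch3d.renderer import PerspectiveCameras


class _ToySceneProvider(SingleSceneDatasetMapProviderBase):
    # Hypothetical provider: fabricates 9 frames and a train/val/test split.
    def _load_data(self) -> None:
        n = 9
        self.images = torch.rand(n, 3, 8, 8)  # N x C x H x W dummy frames
        self.poses = [PerspectiveCameras() for _ in range(n)]
        self.i_split = [np.arange(0, 3), np.arange(3, 6), np.arange(6, 9)]

As with the real providers, expand_args_fields(_ToySceneProvider) would be needed before instantiating it through the config system.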
@@ -220,6 +220,7 @@ class Configurable:
 
 
 _X = TypeVar("X", bound=ReplaceableBase)
+_Y = TypeVar("Y", bound=Union[ReplaceableBase, Configurable])
 
 
 class _Registry:
@@ -307,20 +308,23 @@ class _Registry:
                         It determines the namespace.
                         This will typically be a direct subclass of ReplaceableBase.
         Returns:
-            list of class types
+            list of class types in alphabetical order of registered name.
         """
         if self._is_base_class(base_class_wanted):
-            return list(self._mapping[base_class_wanted].values())
+            source = self._mapping[base_class_wanted]
+            return [source[key] for key in sorted(source)]
+
         base_class = self._base_class_from_class(base_class_wanted)
         if base_class is None:
             raise ValueError(
                 f"Cannot look up {base_class_wanted}. Cannot tell what it is."
             )
+        source = self._mapping[base_class]
         return [
-            class_
-            for class_ in self._mapping[base_class].values()
-            if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted
+            source[key]
+            for key in sorted(source)
+            if issubclass(source[key], base_class_wanted)
+            and source[key] is not base_class_wanted
         ]
 
     @staticmethod
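
The substance of this change is determinism: the registry now returns classes in alphabetical order of their registered name rather than dict insertion order. A standalone illustration of the pattern (toy mapping, names invented):

# Iterating a mapping in sorted-key order instead of insertion order
# makes the resulting list independent of registration order.
source = {"Zebra": type("Zebra", (), {}), "Apple": type("Apple", (), {})}
ordered = [source[key] for key in sorted(source)]
print([cls.__name__ for cls in ordered])  # ['Apple', 'Zebra']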
@@ -647,8 +651,8 @@ def _is_actually_dataclass(some_class) -> bool:
 
 
 def expand_args_fields(
-    some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = ()
-) -> Type[_X]:
+    some_class: Type[_Y], *, _do_not_process: Tuple[type, ...] = ()
+) -> Type[_Y]:
     """
     This expands a class which inherits Configurable or ReplaceableBase classes,
     including dataclass processing. some_class is modified in place by this function.
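
This widening from _X to _Y lets classes that inherit Configurable directly (rather than ReplaceableBase) pass through expand_args_fields without a typing error; SingleSceneDataset above is exactly such a class. A sketch:

from pytorch3d.implicitron.dataset.single_sequence_dataset import SingleSceneDataset
from pytorch3d.implicitron.tools.config import expand_args_fields

# SingleSceneDataset inherits Configurable, not ReplaceableBase, so it only
# satisfies the annotation of expand_args_fields after this change.
expand_args_fields(SingleSceneDataset)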
@@ -13,6 +13,7 @@ from .blending import (
 from .camera_utils import join_cameras_as_batch, rotate_on_spot
 from .cameras import (  # deprecated  # deprecated  # deprecated  # deprecated
     camera_position_from_spherical_angles,
+    CamerasBase,
     FoVOrthographicCameras,
     FoVPerspectiveCameras,
     get_world_to_view_transform,
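
The effect of this one-line export is that CamerasBase becomes importable from the top-level renderer package:

from pytorch3d.renderer import CamerasBase, PerspectiveCameras

# CamerasBase is the common base class of the camera types, so it can be
# used for isinstance checks on any camera object.
assert isinstance(PerspectiveCameras(), CamerasBase)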
@@ -1,5 +1,12 @@
 dataset_map_provider_class_type: ???
 data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
+dataset_map_provider_BlenderDatasetMapProvider_args:
+  base_dir: ???
+  object_name: ???
+  path_manager_factory_class_type: PathManagerFactory
+  n_known_frames_for_test: null
+  path_manager_factory_PathManagerFactory_args:
+    silence_logs: true
 dataset_map_provider_JsonIndexDatasetMapProvider_args:
   category: ???
   task_str: singlesequence
@@ -35,6 +42,13 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args:
     sort_frames: false
   path_manager_factory_PathManagerFactory_args:
     silence_logs: true
+dataset_map_provider_LlffDatasetMapProvider_args:
+  base_dir: ???
+  object_name: ???
+  path_manager_factory_class_type: PathManagerFactory
+  n_known_frames_for_test: null
+  path_manager_factory_PathManagerFactory_args:
+    silence_logs: true
 data_loader_map_provider_SequenceDataLoaderMapProvider_args:
   batch_size: 1
   num_workers: 0
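
A hedged sketch of filling in the new provider block: the "???" placeholders above must be overridden with a data directory and scene name. The values below are assumptions for illustration, and the surrounding experiment plumbing is omitted.

from omegaconf import OmegaConf

# Minimal config fragment selecting the new Blender provider.
cfg = OmegaConf.create(
    {
        "dataset_map_provider_class_type": "BlenderDatasetMapProvider",
        "dataset_map_provider_BlenderDatasetMapProvider_args": {
            "base_dir": "nerf_synthetic/lego",  # assumed local data path
            "object_name": "lego",
        },
    }
)
print(OmegaConf.to_yaml(cfg))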
							
								
								
									
97  tests/implicitron/test_data_llff.py  Normal file
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import unittest

from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
    BlenderDatasetMapProvider,
)
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
    LlffDatasetMapProvider,
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from tests.common_testing import TestCaseMixin


# These tests are only run internally, where the data is available.
internal = os.environ.get("FB_TEST", False)
inside_re_worker = os.environ.get("INSIDE_RE_WORKER", False)
skip_tests = not internal or inside_re_worker


@unittest.skipIf(skip_tests, "no data")
class TestDataLlff(TestCaseMixin, unittest.TestCase):
    def test_synthetic(self):
        expand_args_fields(BlenderDatasetMapProvider)

        provider = BlenderDatasetMapProvider(
            base_dir="manifold://co3d/tree/nerf_data/nerf_synthetic/lego",
            object_name="lego",
        )
        dataset_map = provider.get_dataset_map()

        for name, length in [("train", 100), ("val", 100), ("test", 200)]:
            dataset = getattr(dataset_map, name)
            self.assertEqual(len(dataset), length)
            # try getting a value
            value = dataset[0]
            self.assertIsInstance(value, FrameData)

    def test_llff(self):
        expand_args_fields(LlffDatasetMapProvider)

        provider = LlffDatasetMapProvider(
            base_dir="manifold://co3d/tree/nerf_data/nerf_llff_data/fern",
            object_name="fern",
        )
        dataset_map = provider.get_dataset_map()

        for name, length, frame_type in [
            ("train", 17, "known"),
            ("test", 3, "unseen"),
            ("val", 3, "unseen"),
        ]:
            dataset = getattr(dataset_map, name)
            self.assertEqual(len(dataset), length)
            # try getting a value
            value = dataset[0]
            self.assertIsInstance(value, FrameData)
            self.assertEqual(value.frame_type, frame_type)

        self.assertEqual(len(dataset_map.test.get_eval_batches()), 3)
        for batch in dataset_map.test.get_eval_batches():
            self.assertEqual(len(batch), 1)
            self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen")

    def test_include_known_frames(self):
        expand_args_fields(LlffDatasetMapProvider)

        provider = LlffDatasetMapProvider(
            base_dir="manifold://co3d/tree/nerf_data/nerf_llff_data/fern",
            object_name="fern",
            n_known_frames_for_test=2,
        )
        dataset_map = provider.get_dataset_map()

        for name, types in [
            ("train", ["known"] * 17),
            ("val", ["unseen"] * 3 + ["known"] * 17),
            ("test", ["unseen"] * 3 + ["known"] * 17),
        ]:
            dataset = getattr(dataset_map, name)
            self.assertEqual(len(dataset), len(types))
            for i, frame_type in enumerate(types):
                value = dataset[i]
                self.assertEqual(value.frame_type, frame_type)

        self.assertEqual(len(dataset_map.test.get_eval_batches()), 3)
        for batch in dataset_map.test.get_eval_batches():
            self.assertEqual(len(batch), 3)
            self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen")
            for i in batch[1:]:
                self.assertEqual(dataset_map.test[i].frame_type, "known")
@@ -6,6 +6,7 @@
 
 import os
 import unittest
+import unittest.mock
 
 from omegaconf import OmegaConf
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource