diff --git a/LICENSE-3RD-PARTY b/LICENSE-3RD-PARTY index 147c9748..f55b7dce 100644 --- a/LICENSE-3RD-PARTY +++ b/LICENSE-3RD-PARTY @@ -46,3 +46,26 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +NeRF https://github.com/bmild/nerf/ + +Copyright (c) 2020 bmild + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/projects/implicitron_trainer/README.md b/projects/implicitron_trainer/README.md index 5106f72c..5376c422 100644 --- a/projects/implicitron_trainer/README.md +++ b/projects/implicitron_trainer/README.md @@ -5,7 +5,7 @@ Implicitron is a PyTorch3D-based framework for new-view synthesis via modeling t # License Implicitron is distributed as part of PyTorch3D under the [BSD license](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE). -It includes code from [SRN](http://github.com/vsitzmann/scene-representation-networks) and [IDR](http://github.com/lioryariv/idr) repos. +It includes code from the [NeRF](https://github.com/bmild/nerf), [SRN](http://github.com/vsitzmann/scene-representation-networks) and [IDR](http://github.com/lioryariv/idr) repos. See [LICENSE-3RD-PARTY](https://github.com/facebookresearch/pytorch3d/blob/main/LICENSE-3RD-PARTY) for their licenses. 
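For orientation before the code changes: this patch registers two new single-scene dataset map providers (Blender synthetic scenes and LLFF scenes) with Implicitron's config system, which is why the yaml files below gain `dataset_map_provider_BlenderDatasetMapProvider_args` and `dataset_map_provider_LlffDatasetMapProvider_args` blocks. A minimal sketch of selecting one of them programmatically, assuming the existing `get_default_args` helper; the `base_dir` value is a placeholder, not a real path:

```python
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.tools.config import get_default_args

# Build the default config; after this patch it contains args blocks
# for the Blender and LLFF providers alongside the existing ones.
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "BlenderDatasetMapProvider"
provider_args = args.dataset_map_provider_BlenderDatasetMapProvider_args
provider_args.base_dir = "/data/nerf_synthetic/lego"  # placeholder path
provider_args.object_name = "lego"

# The trainer then constructs the data source from this config.
data_source = ImplicitronDataSource(**args)
```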
diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py index 198eb77d..c4c80c0a 100755 --- a/projects/implicitron_trainer/experiment.py +++ b/projects/implicitron_trainer/experiment.py @@ -315,7 +315,7 @@ def trainvalidate( epoch, loader, optimizer, - validation, + validation: bool, bp_var: str = "objective", metric_print_interval: int = 5, visualize_interval: int = 100, diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index 1cc510ed..d560c8d7 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -95,13 +95,6 @@ generic_model_args: append_coarse_samples_to_fine: true density_noise_std_train: 0.0 return_weights: false - raymarcher_EmissionAbsorptionRaymarcher_args: - surface_thickness: 1 - bg_color: - - 0.0 - background_opacity: 10000000000.0 - density_relu: true - blend_output: false raymarcher_CumsumRaymarcher_args: surface_thickness: 1 bg_color: @@ -109,6 +102,13 @@ generic_model_args: background_opacity: 0.0 density_relu: true blend_output: false + raymarcher_EmissionAbsorptionRaymarcher_args: + surface_thickness: 1 + bg_color: + - 0.0 + background_opacity: 10000000000.0 + density_relu: true + blend_output: false renderer_SignedDistanceFunctionRenderer_args: render_features_dimensions: 3 ray_tracer_args: @@ -157,6 +157,21 @@ generic_model_args: view_sampler_args: masked_sampling: false sampling_mode: bilinear + feature_aggregator_AngleWeightedIdentityFeatureAggregator_args: + exclude_target_view: true + exclude_target_view_mask_features: true + concatenate_output: true + weight_by_ray_angle_gamma: 1.0 + min_ray_angle_weight: 0.1 + feature_aggregator_AngleWeightedReductionFeatureAggregator_args: + exclude_target_view: true + exclude_target_view_mask_features: true + concatenate_output: true + reduction_functions: + - AVG + - STD + weight_by_ray_angle_gamma: 1.0 + min_ray_angle_weight: 0.1 feature_aggregator_IdentityFeatureAggregator_args: exclude_target_view: true exclude_target_view_mask_features: true @@ -168,21 +183,6 @@ generic_model_args: reduction_functions: - AVG - STD - feature_aggregator_AngleWeightedReductionFeatureAggregator_args: - exclude_target_view: true - exclude_target_view_mask_features: true - concatenate_output: true - reduction_functions: - - AVG - - STD - weight_by_ray_angle_gamma: 1.0 - min_ray_angle_weight: 0.1 - feature_aggregator_AngleWeightedIdentityFeatureAggregator_args: - exclude_target_view: true - exclude_target_view_mask_features: true - concatenate_output: true - weight_by_ray_angle_gamma: 1.0 - min_ray_angle_weight: 0.1 implicit_function_IdrFeatureField_args: feature_vector_size: 3 d_in: 3 @@ -203,19 +203,6 @@ generic_model_args: n_harmonic_functions_xyz: 0 pooled_feature_dim: 0 encoding_dim: 0 - implicit_function_NeuralRadianceFieldImplicitFunction_args: - n_harmonic_functions_xyz: 10 - n_harmonic_functions_dir: 4 - n_hidden_neurons_dir: 128 - latent_dim: 0 - input_xyz: true - xyz_ray_dir_in_camera_coords: false - color_dim: 3 - transformer_dim_down_factor: 1.0 - n_hidden_neurons_xyz: 256 - n_layers_xyz: 8 - append_xyz: - - 5 implicit_function_NeRFormerImplicitFunction_args: n_harmonic_functions_xyz: 10 n_harmonic_functions_dir: 4 @@ -229,24 +216,19 @@ generic_model_args: n_layers_xyz: 2 append_xyz: - 1 - implicit_function_SRNImplicitFunction_args: - raymarch_function_args: - n_harmonic_functions: 3 - n_hidden_units: 256 - n_layers: 2 - in_features: 3 - out_features: 256 
- latent_dim: 0 - xyz_in_camera_coords: false - raymarch_function: null - pixel_generator_args: - n_harmonic_functions: 4 - n_hidden_units: 256 - n_hidden_units_color: 128 - n_layers: 2 - in_features: 256 - out_features: 3 - ray_dir_in_camera_coords: false + implicit_function_NeuralRadianceFieldImplicitFunction_args: + n_harmonic_functions_xyz: 10 + n_harmonic_functions_dir: 4 + n_hidden_neurons_dir: 128 + latent_dim: 0 + input_xyz: true + xyz_ray_dir_in_camera_coords: false + color_dim: 3 + transformer_dim_down_factor: 1.0 + n_hidden_neurons_xyz: 256 + n_layers_xyz: 8 + append_xyz: + - 5 implicit_function_SRNHyperNetImplicitFunction_args: hypernet_args: n_harmonic_functions: 3 @@ -267,6 +249,24 @@ generic_model_args: in_features: 256 out_features: 3 ray_dir_in_camera_coords: false + implicit_function_SRNImplicitFunction_args: + raymarch_function_args: + n_harmonic_functions: 3 + n_hidden_units: 256 + n_layers: 2 + in_features: 3 + out_features: 256 + latent_dim: 0 + xyz_in_camera_coords: false + raymarch_function: null + pixel_generator_args: + n_harmonic_functions: 4 + n_hidden_units: 256 + n_hidden_units_color: 128 + n_layers: 2 + in_features: 256 + out_features: 3 + ray_dir_in_camera_coords: false solver_args: breed: adam weight_decay: 0.0 @@ -282,6 +282,13 @@ solver_args: data_source_args: dataset_map_provider_class_type: ??? data_loader_map_provider_class_type: SequenceDataLoaderMapProvider + dataset_map_provider_BlenderDatasetMapProvider_args: + base_dir: ??? + object_name: ??? + path_manager_factory_class_type: PathManagerFactory + n_known_frames_for_test: null + path_manager_factory_PathManagerFactory_args: + silence_logs: true dataset_map_provider_JsonIndexDatasetMapProvider_args: category: ??? task_str: singlesequence @@ -317,6 +324,13 @@ data_source_args: sort_frames: false path_manager_factory_PathManagerFactory_args: silence_logs: true + dataset_map_provider_LlffDatasetMapProvider_args: + base_dir: ??? + object_name: ??? + path_manager_factory_class_type: PathManagerFactory + n_known_frames_for_test: null + path_manager_factory_PathManagerFactory_args: + silence_logs: true data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 1 num_workers: 0 diff --git a/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py b/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py new file mode 100644 index 00000000..c37a3a60 --- /dev/null +++ b/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from pytorch3d.implicitron.tools.config import registry + +from .load_blender import load_blender_data +from .single_sequence_dataset import ( + _interpret_blender_cameras, + SingleSceneDatasetMapProviderBase, +) + + +@registry.register +class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase): + """ + Provides data for one scene from Blender synthetic dataset. + Uses the code in load_blender.py + + Members: + base_dir: directory holding the data for the scene. + object_name: The name of the scene (e.g. "lego"). This is just used as a label. + It will typically be equal to the name of the directory self.base_dir. + path_manager_factory: Creates path manager which may be used for + interpreting paths. 
+ n_known_frames_for_test: If set, training frames are included in the val + and test datasets, and this many random training frames are added to + each test batch. If not set, test batches each contain just a single + testing frame. + """ + + def _load_data(self) -> None: + path_manager = self.path_manager_factory.get() + images, poses, _, hwf, i_split = load_blender_data( + self.base_dir, + testskip=1, + path_manager=path_manager, + ) + H, W, focal = hwf + H, W = int(H), int(W) + images = torch.from_numpy(images) + + # pyre-ignore[16] + self.poses = _interpret_blender_cameras(poses, H, W, focal) + # pyre-ignore[16] + self.images = images + # pyre-ignore[16] + self.i_split = i_split diff --git a/pytorch3d/implicitron/dataset/data_source.py b/pytorch3d/implicitron/dataset/data_source.py index dc2e7054..d8c1f805 100644 --- a/pytorch3d/implicitron/dataset/data_source.py +++ b/pytorch3d/implicitron/dataset/data_source.py @@ -8,9 +8,11 @@ from typing import Tuple from pytorch3d.implicitron.tools.config import ReplaceableBase, run_auto_creation -from . import json_index_dataset_map_provider # noqa +from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task +from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa +from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa class DataSourceBase(ReplaceableBase): diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index 11a0cbae..4b5501ae 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -36,10 +36,11 @@ class FrameData(Mapping[str, Any]): Args: frame_number: The number of the frame within its sequence. 0-based continuous integers. - frame_timestamp: The time elapsed since the start of a sequence in sec. sequence_name: The unique name of the frame's sequence. sequence_category: The object category of the sequence. - image_size_hw: The size of the image in pixels; (height, width) tuple. + frame_timestamp: The time elapsed since the start of a sequence in sec. + image_size_hw: The size of the image in pixels; (height, width) tensor + of shape (2,). image_path: The qualified path to the loaded image (with dataset_root). image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image of the frame; elements are floats in [0, 1]. 
@@ -81,9 +82,9 @@ class FrameData(Mapping[str, Any]): """ frame_number: Optional[torch.LongTensor] - frame_timestamp: Optional[torch.Tensor] sequence_name: Union[str, List[str]] sequence_category: Union[str, List[str]] + frame_timestamp: Optional[torch.Tensor] = None image_size_hw: Optional[torch.Tensor] = None image_path: Union[str, List[str], None] = None image_rgb: Optional[torch.Tensor] = None @@ -101,7 +102,7 @@ class FrameData(Mapping[str, Any]): sequence_point_cloud_path: Union[str, List[str], None] = None sequence_point_cloud: Optional[Pointclouds] = None sequence_point_cloud_idx: Optional[torch.Tensor] = None - frame_type: Union[str, List[str], None] = None # seen | unseen + frame_type: Union[str, List[str], None] = None # known | unseen meta: dict = field(default_factory=lambda: {}) def to(self, *args, **kwargs): diff --git a/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py b/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py new file mode 100644 index 00000000..c4e180f3 --- /dev/null +++ b/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py @@ -0,0 +1,61 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import numpy as np +import torch +from pytorch3d.implicitron.tools.config import registry + +from .load_llff import load_llff_data + +from .single_sequence_dataset import ( + _interpret_blender_cameras, + SingleSceneDatasetMapProviderBase, +) + + +@registry.register +class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase): + """ + Provides data for one scene from the LLFF dataset. + + Members: + base_dir: directory holding the data for the scene. + object_name: The name of the scene (e.g. "fern"). This is just used as a label. + It will typically be equal to the name of the directory self.base_dir. + path_manager_factory: Creates path manager which may be used for + interpreting paths. + n_known_frames_for_test: If set, training frames are included in the val + and test datasets, and this many random training frames are added to + each test batch. If not set, test batches each contain just a single + testing frame. 
+ """ + + def _load_data(self) -> None: + path_manager = self.path_manager_factory.get() + images, poses, _ = load_llff_data( + self.base_dir, factor=8, path_manager=path_manager + ) + hwf = poses[0, :3, -1] + poses = poses[:, :3, :4] + + i_test = np.arange(images.shape[0])[::8] + i_test_index = set(i_test.tolist()) + i_train = np.array( + [i for i in np.arange(images.shape[0]) if i not in i_test_index] + ) + i_split = (i_train, i_test, i_test) + H, W, focal = hwf + H, W = int(H), int(W) + images = torch.from_numpy(images) + poses = torch.from_numpy(poses) + + # pyre-ignore[16] + self.poses = _interpret_blender_cameras(poses, H, W, focal) + # pyre-ignore[16] + self.images = images + # pyre-ignore[16] + self.i_split = i_split diff --git a/pytorch3d/implicitron/dataset/load_blender.py b/pytorch3d/implicitron/dataset/load_blender.py new file mode 100644 index 00000000..f1bdeb1c --- /dev/null +++ b/pytorch3d/implicitron/dataset/load_blender.py @@ -0,0 +1,131 @@ +# @lint-ignore-every LICENSELINT +# Adapted from https://github.com/bmild/nerf/blob/master/load_blender.py +# Copyright (c) 2020 bmild +import json +import os + +import numpy as np +import torch +from PIL import Image + + +def translate_by_t_along_z(t): + tform = np.eye(4).astype(np.float32) + tform[2][3] = t + return tform + + +def rotate_by_phi_along_x(phi): + tform = np.eye(4).astype(np.float32) + tform[1, 1] = tform[2, 2] = np.cos(phi) + tform[1, 2] = -np.sin(phi) + tform[2, 1] = -tform[1, 2] + return tform + + +def rotate_by_theta_along_y(theta): + tform = np.eye(4).astype(np.float32) + tform[0, 0] = tform[2, 2] = np.cos(theta) + tform[0, 2] = -np.sin(theta) + tform[2, 0] = -tform[0, 2] + return tform + + +def pose_spherical(theta, phi, radius): + c2w = translate_by_t_along_z(radius) + c2w = rotate_by_phi_along_x(phi / 180.0 * np.pi) @ c2w + c2w = rotate_by_theta_along_y(theta / 180 * np.pi) @ c2w + c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w + return c2w + + +def _local_path(path_manager, path): + if path_manager is None: + return path + return path_manager.get_local_path(path) + + +def load_blender_data( + basedir, half_res=False, testskip=1, debug=False, path_manager=None +): + splits = ["train", "val", "test"] + metas = {} + for s in splits: + path = os.path.join(basedir, f"transforms_{s}.json") + with open(_local_path(path_manager, path)) as fp: + metas[s] = json.load(fp) + + all_imgs = [] + all_poses = [] + counts = [0] + for s in splits: + meta = metas[s] + imgs = [] + poses = [] + if s == "train" or testskip == 0: + skip = 1 + else: + skip = testskip + + for frame in meta["frames"][::skip]: + fname = os.path.join(basedir, frame["file_path"] + ".png") + imgs.append(np.array(Image.open(_local_path(path_manager, fname)))) + poses.append(np.array(frame["transform_matrix"])) + imgs = (np.array(imgs) / 255.0).astype(np.float32) + poses = np.array(poses).astype(np.float32) + counts.append(counts[-1] + imgs.shape[0]) + all_imgs.append(imgs) + all_poses.append(poses) + + i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)] + + imgs = np.concatenate(all_imgs, 0) + poses = np.concatenate(all_poses, 0) + + H, W = imgs[0].shape[:2] + camera_angle_x = float(meta["camera_angle_x"]) + focal = 0.5 * W / np.tan(0.5 * camera_angle_x) + + render_poses = torch.stack( + [ + torch.from_numpy(pose_spherical(angle, -30.0, 4.0)) + for angle in np.linspace(-180, 180, 40 + 1)[:-1] + ], + 0, + ) + + # In debug mode, return extremely tiny images + if debug: + import cv2 + + H = H // 32 + W = W // 32 + 
focal = focal / 32.0
+        imgs = [
+            torch.from_numpy(
+                cv2.resize(imgs[i], dsize=(W, H), interpolation=cv2.INTER_AREA)
+            )
+            for i in range(imgs.shape[0])
+        ]
+        imgs = torch.stack(imgs, 0)
+        poses = torch.from_numpy(poses)
+        return imgs, poses, render_poses, [H, W, focal], i_split
+
+    if half_res:
+        import cv2
+
+        # Downsample images by a factor of 2 with area interpolation.
+        H = H // 2
+        W = W // 2
+        focal = focal / 2.0
+        imgs = [
+            torch.from_numpy(
+                cv2.resize(imgs[i], dsize=(W, H), interpolation=cv2.INTER_AREA)
+            )
+            for i in range(imgs.shape[0])
+        ]
+        imgs = torch.stack(imgs, 0)
+
+    poses = torch.from_numpy(poses)
+
+    return imgs, poses, render_poses, [H, W, focal], i_split
diff --git a/pytorch3d/implicitron/dataset/load_llff.py b/pytorch3d/implicitron/dataset/load_llff.py
new file mode 100644
index 00000000..eb508a8e
--- /dev/null
+++ b/pytorch3d/implicitron/dataset/load_llff.py
@@ -0,0 +1,342 @@
+# @lint-ignore-every LICENSELINT
+# Adapted from https://github.com/bmild/nerf/blob/master/load_llff.py
+# Copyright (c) 2020 bmild
+import logging
+import os
+import warnings
+
+import numpy as np
+
+from PIL import Image
+
+
+# Slightly modified version of LLFF data loading code
+# see https://github.com/Fyusion/LLFF for original
+
+logger = logging.getLogger(__name__)
+
+
+def _minify(basedir, path_manager, factors=(), resolutions=()):
+    needtoload = False
+    for r in factors:
+        imgdir = os.path.join(basedir, "images_{}".format(r))
+        if not _exists(path_manager, imgdir):
+            needtoload = True
+    for r in resolutions:
+        imgdir = os.path.join(basedir, "images_{}x{}".format(r[1], r[0]))
+        if not _exists(path_manager, imgdir):
+            needtoload = True
+    if not needtoload:
+        return
+    assert path_manager is None
+
+    from subprocess import check_output
+
+    imgdir = os.path.join(basedir, "images")
+    imgs = [os.path.join(imgdir, f) for f in sorted(_ls(path_manager, imgdir))]
+    imgs = [
+        f
+        for f in imgs
+        if any([f.endswith(ex) for ex in ["JPG", "jpg", "png", "jpeg", "PNG"]])
+    ]
+    imgdir_orig = imgdir
+
+    wd = os.getcwd()
+
+    for r in factors + resolutions:
+        if isinstance(r, int):
+            name = "images_{}".format(r)
+            resizearg = "{}%".format(100.0 / r)
+        else:
+            name = "images_{}x{}".format(r[1], r[0])
+            resizearg = "{}x{}".format(r[1], r[0])
+        imgdir = os.path.join(basedir, name)
+        if os.path.exists(imgdir):
+            continue
+
+        logger.info(f"Minifying {r}, {basedir}")
+
+        os.makedirs(imgdir)
+        check_output("cp {}/* {}".format(imgdir_orig, imgdir), shell=True)
+
+        ext = imgs[0].split(".")[-1]
+        args = " ".join(
+            ["mogrify", "-resize", resizearg, "-format", "png", "*.{}".format(ext)]
+        )
+        logger.info(args)
+        os.chdir(imgdir)
+        check_output(args, shell=True)
+        os.chdir(wd)
+
+        if ext != "png":
+            check_output("rm {}/*.{}".format(imgdir, ext), shell=True)
+            logger.info("Removed duplicates")
+        logger.info("Done")
+
+
+def _load_data(
+    basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
+):
+
+    poses_arr = np.load(
+        _local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
+    )
+    poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
+    bds = poses_arr[:, -2:].transpose([1, 0])
+
+    img0 = [
+        os.path.join(basedir, "images", f)
+        for f in sorted(_ls(path_manager, os.path.join(basedir, "images")))
+        if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
+    ][0]
+
+    def imread(f):
+        return np.array(Image.open(f))
+
+    sh = imread(_local_path(path_manager, img0)).shape
+
+    sfx = ""
+
+    if factor is not None:
+        sfx = "_{}".format(factor)
+        _minify(basedir, path_manager, factors=[factor])
+    elif height is not None:
+        factor = sh[0] / float(height)
+        width = int(sh[1] / factor)
+        _minify(basedir, path_manager, resolutions=[[height, width]])
+        sfx = "_{}x{}".format(width, height)
+    elif width is not None:
+        factor = sh[1] / float(width)
+        height = int(sh[0] / factor)
+        _minify(basedir, path_manager, resolutions=[[height, width]])
+        sfx = "_{}x{}".format(width, height)
+    else:
+        factor = 1
+
+    imgdir = os.path.join(basedir, "images" + sfx)
+    if not _exists(path_manager, imgdir):
+        raise ValueError(f"{imgdir} does not exist, returning")
+
+    imgfiles = [
+        _local_path(path_manager, os.path.join(imgdir, f))
+        for f in sorted(_ls(path_manager, imgdir))
+        if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
+    ]
+    if poses.shape[-1] != len(imgfiles):
+        raise ValueError(
+            "Mismatch between imgs {} and poses {}".format(
+                len(imgfiles), poses.shape[-1]
+            )
+        )
+
+    sh = imread(imgfiles[0]).shape
+    poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
+    poses[2, 4, :] = poses[2, 4, :] * 1.0 / factor
+
+    if not load_imgs:
+        return poses, bds
+
+    imgs = [imread(f)[..., :3] / 255.0 for f in imgfiles]
+    imgs = np.stack(imgs, -1)
+
+    logger.info(f"Loaded image data, shape {imgs.shape}")
+    return poses, bds, imgs
+
+
+def normalize(x):
+    denom = np.linalg.norm(x)
+    if denom < 0.001:
+        warnings.warn("unsafe normalize()")
+    return x / denom
+
+
+def viewmatrix(z, up, pos):
+    vec2 = normalize(z)
+    vec1_avg = up
+    vec0 = normalize(np.cross(vec1_avg, vec2))
+    vec1 = normalize(np.cross(vec2, vec0))
+    m = np.stack([vec0, vec1, vec2, pos], 1)
+    return m
+
+
+def ptstocam(pts, c2w):
+    tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0]
+    return tt
+
+
+def poses_avg(poses):
+
+    hwf = poses[0, :3, -1:]
+
+    center = poses[:, :3, 3].mean(0)
+    vec2 = normalize(poses[:, :3, 2].sum(0))
+    up = poses[:, :3, 1].sum(0)
+    c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
+
+    return c2w
+
+
+def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
+    render_poses = []
+    rads = np.array(list(rads) + [1.0])
+    hwf = c2w[:, 4:5]
+
+    for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]:
+        c = np.dot(
+            c2w[:3, :4],
+            np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0])
+            * rads,
+        )
+        z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0])))
+        render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
+    return render_poses
+
+
+def recenter_poses(poses):
+
+    poses_ = poses + 0
+    bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
+    c2w = poses_avg(poses)
+    c2w = np.concatenate([c2w[:3, :4], bottom], -2)
+    bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
+    poses = np.concatenate([poses[:, :3, :4], bottom], -2)
+
+    poses = np.linalg.inv(c2w) @ poses
+    poses_[:, :3, :4] = poses[:, :3, :4]
+    poses = poses_
+    return poses
+
+
+def spherify_poses(poses, bds):
+    def add_row_to_homogenize_transform(p):
+        r"""Add the last row to homogenize 3 x 4 transformation matrices."""
+        return np.concatenate(
+            [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
+        )
+
+    # p34_to_44 = lambda p: np.concatenate(
+    #     [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
+    # )
+
+    p34_to_44 = add_row_to_homogenize_transform
+
+    rays_d = poses[:, :3, 2:3]
+    rays_o = poses[:, :3, 3:4]
+
+    def min_line_dist(rays_o, rays_d):
+        A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
+        b_i = -A_i @ rays_o
+        pt_mindist = np.squeeze(
+            -np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0)
+        )
+        return pt_mindist
+
+    pt_mindist = min_line_dist(rays_o, rays_d)
+
+    center = pt_mindist
+    up = (poses[:, :3, 3] - center).mean(0)
+
+    vec0 = normalize(up)
+    vec1 = normalize(np.cross([0.1, 0.2, 0.3], vec0))
+    vec2 = normalize(np.cross(vec0, vec1))
+    pos = center
+    c2w = np.stack([vec1, vec2, vec0, pos], 1)
+
+    poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])
+
+    rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
+
+    sc = 1.0 / rad
+    poses_reset[:, :3, 3] *= sc
+    bds *= sc
+    rad *= sc
+
+    centroid = np.mean(poses_reset[:, :3, 3], 0)
+    zh = centroid[2]
+    radcircle = np.sqrt(rad**2 - zh**2)
+    new_poses = []
+
+    for th in np.linspace(0.0, 2.0 * np.pi, 120):
+
+        camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
+        up = np.array([0, 0, -1.0])
+
+        vec2 = normalize(camorigin)
+        vec0 = normalize(np.cross(vec2, up))
+        vec1 = normalize(np.cross(vec2, vec0))
+        pos = camorigin
+        p = np.stack([vec0, vec1, vec2, pos], 1)
+
+        new_poses.append(p)
+
+    new_poses = np.stack(new_poses, 0)
+
+    new_poses = np.concatenate(
+        [new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1
+    )
+    poses_reset = np.concatenate(
+        [
+            poses_reset[:, :3, :4],
+            np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape),
+        ],
+        -1,
+    )
+
+    return poses_reset, new_poses, bds
+
+
+def _local_path(path_manager, path):
+    if path_manager is None:
+        return path
+    return path_manager.get_local_path(path)
+
+
+def _ls(path_manager, path):
+    if path_manager is None:
+        return os.listdir(path)
+    return path_manager.ls(path)
+
+
+def _exists(path_manager, path):
+    if path_manager is None:
+        return os.path.exists(path)
+    return path_manager.exists(path)
+
+
+def load_llff_data(
+    basedir,
+    factor=8,
+    recenter=True,
+    bd_factor=0.75,
+    spherify=False,
+    path_zflat=False,
+    path_manager=None,
+):
+
+    poses, bds, imgs = _load_data(
+        basedir, factor=factor, path_manager=path_manager
+    )  # factor=8 downsamples original imgs by 8x
+    logger.info(f"Loaded {basedir}, {bds.min()}, {bds.max()}")
+
+    # Correct rotation matrix ordering and move variable dim to axis 0
+    poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
+    poses = np.moveaxis(poses, -1, 0).astype(np.float32)
+    imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
+    images = imgs
+    bds = np.moveaxis(bds, -1, 0).astype(np.float32)
+
+    # Rescale if bd_factor is provided
+    sc = 1.0 if bd_factor is None else 1.0 / (bds.min() * bd_factor)
+    poses[:, :3, 3] *= sc
+    bds *= sc
+
+    if recenter:
+        poses = recenter_poses(poses)
+
+    if spherify:
+        poses, render_poses, bds = spherify_poses(poses, bds)
+
+    images = images.astype(np.float32)
+    poses = poses.astype(np.float32)
+
+    return images, poses, bds
diff --git a/pytorch3d/implicitron/dataset/single_sequence_dataset.py b/pytorch3d/implicitron/dataset/single_sequence_dataset.py
new file mode 100644
index 00000000..b52118ce
--- /dev/null
+++ b/pytorch3d/implicitron/dataset/single_sequence_dataset.py
@@ -0,0 +1,185 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+# This file defines a base class for dataset map providers which
+# provide data for a single scene.
+
+from dataclasses import field
+from typing import Iterable, List, Optional
+
+import numpy as np
+import torch
+from pytorch3d.implicitron.tools.config import (
+    Configurable,
+    expand_args_fields,
+    run_auto_creation,
+)
+from pytorch3d.renderer import PerspectiveCameras
+
+from .dataset_base import DatasetBase, FrameData
+from .dataset_map_provider import (
+    DatasetMap,
+    DatasetMapProviderBase,
+    PathManagerFactory,
+    Task,
+)
+from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
+
+_SINGLE_SEQUENCE_NAME: str = "one_sequence"
+
+
+class SingleSceneDataset(DatasetBase, Configurable):
+    """
+    A dataset of images from a single scene.
+    """
+
+    images: List[torch.Tensor] = field()
+    poses: List[PerspectiveCameras] = field()
+    object_name: str = field()
+    frame_types: List[str] = field()
+    eval_batches: Optional[List[List[int]]] = field()
+
+    def sequence_names(self) -> Iterable[str]:
+        return [_SINGLE_SEQUENCE_NAME]
+
+    def __len__(self) -> int:
+        return len(self.poses)
+
+    def __getitem__(self, index) -> FrameData:
+        if index >= len(self):
+            raise IndexError(f"index {index} out of range {len(self)}")
+        image = self.images[index]
+        pose = self.poses[index]
+        frame_type = self.frame_types[index]
+
+        frame_data = FrameData(
+            frame_number=index,
+            sequence_name=_SINGLE_SEQUENCE_NAME,
+            sequence_category=self.object_name,
+            camera=pose,
+            image_size_hw=torch.tensor(image.shape[1:]),
+            image_rgb=image,
+            frame_type=frame_type,
+        )
+        return frame_data
+
+    def get_eval_batches(self) -> Optional[List[List[int]]]:
+        return self.eval_batches
+
+
+# pyre-fixme[13]: Uninitialized attribute
+class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
+    """
+    Base class for providers of data for one scene from an LLFF or Blender dataset.
+
+    Members:
+        base_dir: directory holding the data for the scene.
+        object_name: The name of the scene (e.g. "lego"). This is just used as a label.
+            It will typically be equal to the name of the directory self.base_dir.
+        path_manager_factory: Creates path manager which may be used for
+            interpreting paths.
+        n_known_frames_for_test: If set, training frames are included in the val
+            and test datasets, and this many random training frames are added to
+            each test batch. If not set, test batches each contain just a single
+            testing frame.
+    """
+
+    base_dir: str
+    object_name: str
+    path_manager_factory: PathManagerFactory
+    path_manager_factory_class_type: str = "PathManagerFactory"
+    n_known_frames_for_test: Optional[int] = None
+
+    def __post_init__(self) -> None:
+        run_auto_creation(self)
+        self._load_data()
+
+    def _load_data(self) -> None:
+        # This must be defined by each subclass,
+        # and should set poses, images and i_split on self.
+        raise NotImplementedError
+
+    def _get_dataset(
+        self, split_idx: int, frame_type: str, set_eval_batches: bool = False
+    ) -> SingleSceneDataset:
+        expand_args_fields(SingleSceneDataset)
+        # pyre-ignore[16]
+        split = self.i_split[split_idx]
+        frame_types = [frame_type] * len(split)
+        eval_batches = [[i] for i in range(len(split))]
+        if split_idx != 0 and self.n_known_frames_for_test is not None:
+            train_split = self.i_split[0]
+            if set_eval_batches:
+                generator = np.random.default_rng(seed=0)
+                for batch in eval_batches:
+                    to_add = generator.choice(
+                        len(train_split), self.n_known_frames_for_test
+                    )
+                    batch.extend((to_add + len(split)).tolist())
+            split = np.concatenate([split, train_split])
+            frame_types.extend([DATASET_TYPE_KNOWN] * len(train_split))
+
+        # pyre-ignore[28]
+        return SingleSceneDataset(
+            object_name=self.object_name,
+            # pyre-ignore[16]
+            images=self.images[split],
+            # pyre-ignore[16]
+            poses=[self.poses[i] for i in split],
+            frame_types=frame_types,
+            eval_batches=eval_batches if set_eval_batches else None,
+        )
+
+    def get_dataset_map(self) -> DatasetMap:
+        return DatasetMap(
+            train=self._get_dataset(0, DATASET_TYPE_KNOWN),
+            val=self._get_dataset(1, DATASET_TYPE_UNKNOWN),
+            test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
+        )
+
+    def get_task(self) -> Task:
+        return Task.SINGLE_SEQUENCE
+
+
+def _interpret_blender_cameras(
+    poses: torch.Tensor, H: int, W: int, focal: float
+) -> List[PerspectiveCameras]:
+    """
+    Convert camera-to-world matrices in blender convention to PyTorch3D
+    PerspectiveCameras.
+
+    Args:
+        poses: N x 3 x 4 or N x 4 x 4 camera matrices; only the top
+            3 x 4 block of each matrix is used.
+        H: image height in pixels.
+        W: image width in pixels.
+        focal: focal length in pixels.
+    """
+    pose_target_cameras = []
+    for pose_target in poses:
+        pose_target = pose_target[:3, :4]
+        mtx = torch.eye(4, dtype=pose_target.dtype)
+        mtx[:3, :3] = pose_target[:3, :3].t()
+        mtx[3, :3] = pose_target[:, 3]
+        mtx = mtx.inverse()
+
+        # flip the XZ coordinates.
+        mtx[:, [0, 2]] *= -1.0
+
+        Rpt3, Tpt3 = mtx[:, :3].split([3, 1], dim=0)
+
+        focal_length_pt3 = torch.FloatTensor([[-focal, focal]])
+        principal_point_pt3 = torch.FloatTensor([[W / 2, H / 2]])
+
+        cameras = PerspectiveCameras(
+            focal_length=focal_length_pt3,
+            principal_point=principal_point_pt3,
+            R=Rpt3[None],
+            T=Tpt3,
+        )
+        pose_target_cameras.append(cameras)
+    return pose_target_cameras
diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py
index 11e344e0..79cda30a 100644
--- a/pytorch3d/implicitron/tools/config.py
+++ b/pytorch3d/implicitron/tools/config.py
@@ -220,6 +220,7 @@ class Configurable:
 
 
 _X = TypeVar("X", bound=ReplaceableBase)
+_Y = TypeVar("Y", bound=Union[ReplaceableBase, Configurable])
 
 
 class _Registry:
@@ -307,20 +308,23 @@ class _Registry:
             It determines the namespace.
             This will typically be a direct subclass of ReplaceableBase.
         Returns:
-            list of class types
+            list of class types in alphabetical order of registered name.
         """
         if self._is_base_class(base_class_wanted):
-            return list(self._mapping[base_class_wanted].values())
+            source = self._mapping[base_class_wanted]
+            return [source[key] for key in sorted(source)]
 
         base_class = self._base_class_from_class(base_class_wanted)
         if base_class is None:
             raise ValueError(
                 f"Cannot look up {base_class_wanted}. Cannot tell what it is."
) + source = self._mapping[base_class] return [ - class_ - for class_ in self._mapping[base_class].values() - if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted + source[key] + for key in sorted(source) + if issubclass(source[key], base_class_wanted) + and source[key] is not base_class_wanted ] @staticmethod @@ -647,8 +651,8 @@ def _is_actually_dataclass(some_class) -> bool: def expand_args_fields( - some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = () -) -> Type[_X]: + some_class: Type[_Y], *, _do_not_process: Tuple[type, ...] = () +) -> Type[_Y]: """ This expands a class which inherits Configurable or ReplaceableBase classes, including dataclass processing. some_class is modified in place by this function. diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index 4e566f5e..437cbad4 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -13,6 +13,7 @@ from .blending import ( from .camera_utils import join_cameras_as_batch, rotate_on_spot from .cameras import ( # deprecated # deprecated # deprecated # deprecated camera_position_from_spherical_angles, + CamerasBase, FoVOrthographicCameras, FoVPerspectiveCameras, get_world_to_view_transform, diff --git a/tests/implicitron/data/data_source.yaml b/tests/implicitron/data/data_source.yaml index 64aa1283..12113ff5 100644 --- a/tests/implicitron/data/data_source.yaml +++ b/tests/implicitron/data/data_source.yaml @@ -1,5 +1,12 @@ dataset_map_provider_class_type: ??? data_loader_map_provider_class_type: SequenceDataLoaderMapProvider +dataset_map_provider_BlenderDatasetMapProvider_args: + base_dir: ??? + object_name: ??? + path_manager_factory_class_type: PathManagerFactory + n_known_frames_for_test: null + path_manager_factory_PathManagerFactory_args: + silence_logs: true dataset_map_provider_JsonIndexDatasetMapProvider_args: category: ??? task_str: singlesequence @@ -35,6 +42,13 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args: sort_frames: false path_manager_factory_PathManagerFactory_args: silence_logs: true +dataset_map_provider_LlffDatasetMapProvider_args: + base_dir: ??? + object_name: ??? + path_manager_factory_class_type: PathManagerFactory + n_known_frames_for_test: null + path_manager_factory_PathManagerFactory_args: + silence_logs: true data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 1 num_workers: 0 diff --git a/tests/implicitron/test_data_llff.py b/tests/implicitron/test_data_llff.py new file mode 100644 index 00000000..7dd69245 --- /dev/null +++ b/tests/implicitron/test_data_llff.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import unittest + +from pytorch3d.implicitron.dataset.blender_dataset_map_provider import ( + BlenderDatasetMapProvider, +) +from pytorch3d.implicitron.dataset.dataset_base import FrameData +from pytorch3d.implicitron.dataset.llff_dataset_map_provider import ( + LlffDatasetMapProvider, +) +from pytorch3d.implicitron.tools.config import expand_args_fields +from tests.common_testing import TestCaseMixin + + +# These tests are only run internally, where the data is available. 
+internal = os.environ.get("FB_TEST", False) +inside_re_worker = os.environ.get("INSIDE_RE_WORKER", False) +skip_tests = not internal or inside_re_worker + + +@unittest.skipIf(skip_tests, "no data") +class TestDataLlff(TestCaseMixin, unittest.TestCase): + def test_synthetic(self): + expand_args_fields(BlenderDatasetMapProvider) + + provider = BlenderDatasetMapProvider( + base_dir="manifold://co3d/tree/nerf_data/nerf_synthetic/lego", + object_name="lego", + ) + dataset_map = provider.get_dataset_map() + + for name, length in [("train", 100), ("val", 100), ("test", 200)]: + dataset = getattr(dataset_map, name) + self.assertEqual(len(dataset), length) + # try getting a value + value = dataset[0] + self.assertIsInstance(value, FrameData) + + def test_llff(self): + expand_args_fields(LlffDatasetMapProvider) + + provider = LlffDatasetMapProvider( + base_dir="manifold://co3d/tree/nerf_data/nerf_llff_data/fern", + object_name="fern", + ) + dataset_map = provider.get_dataset_map() + + for name, length, frame_type in [ + ("train", 17, "known"), + ("test", 3, "unseen"), + ("val", 3, "unseen"), + ]: + dataset = getattr(dataset_map, name) + self.assertEqual(len(dataset), length) + # try getting a value + value = dataset[0] + self.assertIsInstance(value, FrameData) + self.assertEqual(value.frame_type, frame_type) + + self.assertEqual(len(dataset_map.test.get_eval_batches()), 3) + for batch in dataset_map.test.get_eval_batches(): + self.assertEqual(len(batch), 1) + self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen") + + def test_include_known_frames(self): + expand_args_fields(LlffDatasetMapProvider) + + provider = LlffDatasetMapProvider( + base_dir="manifold://co3d/tree/nerf_data/nerf_llff_data/fern", + object_name="fern", + n_known_frames_for_test=2, + ) + dataset_map = provider.get_dataset_map() + + for name, types in [ + ("train", ["known"] * 17), + ("val", ["unseen"] * 3 + ["known"] * 17), + ("test", ["unseen"] * 3 + ["known"] * 17), + ]: + dataset = getattr(dataset_map, name) + self.assertEqual(len(dataset), len(types)) + for i, frame_type in enumerate(types): + value = dataset[i] + self.assertEqual(value.frame_type, frame_type) + + self.assertEqual(len(dataset_map.test.get_eval_batches()), 3) + for batch in dataset_map.test.get_eval_batches(): + self.assertEqual(len(batch), 3) + self.assertEqual(dataset_map.test[batch[0]].frame_type, "unseen") + for i in batch[1:]: + self.assertEqual(dataset_map.test[i].frame_type, "known") diff --git a/tests/implicitron/test_data_source.py b/tests/implicitron/test_data_source.py index d61957b0..e5823e05 100644 --- a/tests/implicitron/test_data_source.py +++ b/tests/implicitron/test_data_source.py @@ -6,6 +6,7 @@ import os import unittest +import unittest.mock from omegaconf import OmegaConf from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
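For readers without access to the internal test data, here is a self-contained usage sketch mirroring the test flow above; `base_dir` is a placeholder for a local copy of the LLFF fern scene:

```python
from pytorch3d.implicitron.dataset.llff_dataset_map_provider import (
    LlffDatasetMapProvider,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# As in the tests, the Configurable class is expanded before instantiation.
expand_args_fields(LlffDatasetMapProvider)

provider = LlffDatasetMapProvider(
    base_dir="/data/nerf_llff_data/fern",  # placeholder path
    object_name="fern",
)
dataset_map = provider.get_dataset_map()

# Every 8th frame is held out, so fern splits into 17 train / 3 val / 3 test.
for batch in dataset_map.test.get_eval_batches():
    # With n_known_frames_for_test unset, each eval batch is one unseen frame.
    print([dataset_map.test[i].frame_type for i in batch])  # ['unseen']
```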