loading llff and blender datasets

Summary: Copy code from NeRF for loading LLFF data and blender synthetic data, and create dataset objects for them

Reviewed By: shapovalov

Differential Revision: D35581039

fbshipit-source-id: af7a6f3e9a42499700693381b5b147c991f57e5d
This commit is contained in:
Jeremy Reizenstein
2022-06-16 03:09:15 -07:00
committed by Facebook GitHub Bot
parent 7978ffd1e4
commit 65f667fd2e
16 changed files with 992 additions and 67 deletions

View File

@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from pytorch3d.implicitron.tools.config import registry
from .load_blender import load_blender_data
from .single_sequence_dataset import (
_interpret_blender_cameras,
SingleSceneDatasetMapProviderBase,
)
@registry.register
class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
"""
Provides data for one scene from Blender synthetic dataset.
Uses the code in load_blender.py
Members:
base_dir: directory holding the data for the scene.
object_name: The name of the scene (e.g. "lego"). This is just used as a label.
It will typically be equal to the name of the directory self.base_dir.
path_manager_factory: Creates path manager which may be used for
interpreting paths.
n_known_frames_for_test: If set, training frames are included in the val
and test datasets, and this many random training frames are added to
each test batch. If not set, test batches each contain just a single
testing frame.
"""
def _load_data(self) -> None:
path_manager = self.path_manager_factory.get()
images, poses, _, hwf, i_split = load_blender_data(
self.base_dir,
testskip=1,
path_manager=path_manager,
)
H, W, focal = hwf
H, W = int(H), int(W)
images = torch.from_numpy(images)
# pyre-ignore[16]
self.poses = _interpret_blender_cameras(poses, H, W, focal)
# pyre-ignore[16]
self.images = images
# pyre-ignore[16]
self.i_split = i_split

View File

@@ -8,9 +8,11 @@ from typing import Tuple
from pytorch3d.implicitron.tools.config import ReplaceableBase, run_auto_creation
from . import json_index_dataset_map_provider # noqa
from .blender_dataset_map_provider import BlenderDatasetMapProvider # noqa
from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
from .json_index_dataset_map_provider import JsonIndexDatasetMapProvider # noqa
from .llff_dataset_map_provider import LlffDatasetMapProvider # noqa
class DataSourceBase(ReplaceableBase):

View File

@@ -36,10 +36,11 @@ class FrameData(Mapping[str, Any]):
Args:
frame_number: The number of the frame within its sequence.
0-based continuous integers.
frame_timestamp: The time elapsed since the start of a sequence in sec.
sequence_name: The unique name of the frame's sequence.
sequence_category: The object category of the sequence.
image_size_hw: The size of the image in pixels; (height, width) tuple.
frame_timestamp: The time elapsed since the start of a sequence in sec.
image_size_hw: The size of the image in pixels; (height, width) tensor
of shape (2,).
image_path: The qualified path to the loaded image (with dataset_root).
image_rgb: A Tensor of shape `(3, H, W)` holding the RGB image
of the frame; elements are floats in [0, 1].
@@ -81,9 +82,9 @@ class FrameData(Mapping[str, Any]):
"""
frame_number: Optional[torch.LongTensor]
frame_timestamp: Optional[torch.Tensor]
sequence_name: Union[str, List[str]]
sequence_category: Union[str, List[str]]
frame_timestamp: Optional[torch.Tensor] = None
image_size_hw: Optional[torch.Tensor] = None
image_path: Union[str, List[str], None] = None
image_rgb: Optional[torch.Tensor] = None
@@ -101,7 +102,7 @@ class FrameData(Mapping[str, Any]):
sequence_point_cloud_path: Union[str, List[str], None] = None
sequence_point_cloud: Optional[Pointclouds] = None
sequence_point_cloud_idx: Optional[torch.Tensor] = None
frame_type: Union[str, List[str], None] = None # seen | unseen
frame_type: Union[str, List[str], None] = None # known | unseen
meta: dict = field(default_factory=lambda: {})
def to(self, *args, **kwargs):

View File

@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
from pytorch3d.implicitron.tools.config import registry
from .load_llff import load_llff_data
from .single_sequence_dataset import (
_interpret_blender_cameras,
SingleSceneDatasetMapProviderBase,
)
@registry.register
class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
"""
Provides data for one scene from the LLFF dataset.
Members:
base_dir: directory holding the data for the scene.
object_name: The name of the scene (e.g. "fern"). This is just used as a label.
It will typically be equal to the name of the directory self.base_dir.
path_manager_factory: Creates path manager which may be used for
interpreting paths.
n_known_frames_for_test: If set, training frames are included in the val
and test datasets, and this many random training frames are added to
each test batch. If not set, test batches each contain just a single
testing frame.
"""
def _load_data(self) -> None:
path_manager = self.path_manager_factory.get()
images, poses, _ = load_llff_data(
self.base_dir, factor=8, path_manager=path_manager
)
hwf = poses[0, :3, -1]
poses = poses[:, :3, :4]
i_test = np.arange(images.shape[0])[::8]
i_test_index = set(i_test.tolist())
i_train = np.array(
[i for i in np.arange(images.shape[0]) if i not in i_test_index]
)
i_split = (i_train, i_test, i_test)
H, W, focal = hwf
H, W = int(H), int(W)
images = torch.from_numpy(images)
poses = torch.from_numpy(poses)
# pyre-ignore[16]
self.poses = _interpret_blender_cameras(poses, H, W, focal)
# pyre-ignore[16]
self.images = images
# pyre-ignore[16]
self.i_split = i_split

View File

@@ -0,0 +1,131 @@
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/bmild/nerf/blob/master/load_blender.py
# Copyright (c) 2020 bmild
import json
import os
import numpy as np
import torch
from PIL import Image
def translate_by_t_along_z(t):
tform = np.eye(4).astype(np.float32)
tform[2][3] = t
return tform
def rotate_by_phi_along_x(phi):
tform = np.eye(4).astype(np.float32)
tform[1, 1] = tform[2, 2] = np.cos(phi)
tform[1, 2] = -np.sin(phi)
tform[2, 1] = -tform[1, 2]
return tform
def rotate_by_theta_along_y(theta):
tform = np.eye(4).astype(np.float32)
tform[0, 0] = tform[2, 2] = np.cos(theta)
tform[0, 2] = -np.sin(theta)
tform[2, 0] = -tform[0, 2]
return tform
def pose_spherical(theta, phi, radius):
c2w = translate_by_t_along_z(radius)
c2w = rotate_by_phi_along_x(phi / 180.0 * np.pi) @ c2w
c2w = rotate_by_theta_along_y(theta / 180 * np.pi) @ c2w
c2w = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]) @ c2w
return c2w
def _local_path(path_manager, path):
if path_manager is None:
return path
return path_manager.get_local_path(path)
def load_blender_data(
basedir, half_res=False, testskip=1, debug=False, path_manager=None
):
splits = ["train", "val", "test"]
metas = {}
for s in splits:
path = os.path.join(basedir, f"transforms_{s}.json")
with open(_local_path(path_manager, path)) as fp:
metas[s] = json.load(fp)
all_imgs = []
all_poses = []
counts = [0]
for s in splits:
meta = metas[s]
imgs = []
poses = []
if s == "train" or testskip == 0:
skip = 1
else:
skip = testskip
for frame in meta["frames"][::skip]:
fname = os.path.join(basedir, frame["file_path"] + ".png")
imgs.append(np.array(Image.open(_local_path(path_manager, fname))))
poses.append(np.array(frame["transform_matrix"]))
imgs = (np.array(imgs) / 255.0).astype(np.float32)
poses = np.array(poses).astype(np.float32)
counts.append(counts[-1] + imgs.shape[0])
all_imgs.append(imgs)
all_poses.append(poses)
i_split = [np.arange(counts[i], counts[i + 1]) for i in range(3)]
imgs = np.concatenate(all_imgs, 0)
poses = np.concatenate(all_poses, 0)
H, W = imgs[0].shape[:2]
camera_angle_x = float(meta["camera_angle_x"])
focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
render_poses = torch.stack(
[
torch.from_numpy(pose_spherical(angle, -30.0, 4.0))
for angle in np.linspace(-180, 180, 40 + 1)[:-1]
],
0,
)
# In debug mode, return extremely tiny images
if debug:
import cv2
H = H // 32
W = W // 32
focal = focal / 32.0
imgs = [
torch.from_numpy(
cv2.resize(imgs[i], dsize=(25, 25), interpolation=cv2.INTER_AREA)
)
for i in range(imgs.shape[0])
]
imgs = torch.stack(imgs, 0)
poses = torch.from_numpy(poses)
return imgs, poses, render_poses, [H, W, focal], i_split
if half_res:
import cv2
# TODO: resize images using INTER_AREA (cv2)
H = H // 2
W = W // 2
focal = focal / 2.0
imgs = [
torch.from_numpy(
cv2.resize(imgs[i], dsize=(400, 400), interpolation=cv2.INTER_AREA)
)
for i in range(imgs.shape[0])
]
imgs = torch.stack(imgs, 0)
poses = torch.from_numpy(poses)
return imgs, poses, render_poses, [H, W, focal], i_split

View File

@@ -0,0 +1,343 @@
# @lint-ignore-every LICENSELINT
# Adapted from https://github.com/bmild/nerf/blob/master/load_llff.py
# Copyright (c) 2020 bmild
import logging
import os
import warnings
import numpy as np
from PIL import Image
# Slightly modified version of LLFF data loading code
# see https://github.com/Fyusion/LLFF for original
logger = logging.getLogger(__name__)
def _minify(basedir, path_manager, factors=(), resolutions=()):
needtoload = False
for r in factors:
imgdir = os.path.join(basedir, "images_{}".format(r))
if not _exists(path_manager, imgdir):
needtoload = True
for r in resolutions:
imgdir = os.path.join(basedir, "images_{}x{}".format(r[1], r[0]))
if not _exists(path_manager, imgdir):
needtoload = True
if not needtoload:
return
assert path_manager is None
from subprocess import check_output
imgdir = os.path.join(basedir, "images")
imgs = [os.path.join(imgdir, f) for f in sorted(_ls(path_manager, imgdir))]
imgs = [
f
for f in imgs
if any([f.endswith(ex) for ex in ["JPG", "jpg", "png", "jpeg", "PNG"]])
]
imgdir_orig = imgdir
wd = os.getcwd()
for r in factors + resolutions:
if isinstance(r, int):
name = "images_{}".format(r)
resizearg = "{}%".format(100.0 / r)
else:
name = "images_{}x{}".format(r[1], r[0])
resizearg = "{}x{}".format(r[1], r[0])
imgdir = os.path.join(basedir, name)
if os.path.exists(imgdir):
continue
logger.info(f"Minifying {r}, {basedir}")
os.makedirs(imgdir)
check_output("cp {}/* {}".format(imgdir_orig, imgdir), shell=True)
ext = imgs[0].split(".")[-1]
args = " ".join(
["mogrify", "-resize", resizearg, "-format", "png", "*.{}".format(ext)]
)
logger.info(args)
os.chdir(imgdir)
check_output(args, shell=True)
os.chdir(wd)
if ext != "png":
check_output("rm {}/*.{}".format(imgdir, ext), shell=True)
logger.info("Removed duplicates")
logger.info("Done")
def _load_data(
basedir, factor=None, width=None, height=None, load_imgs=True, path_manager=None
):
poses_arr = np.load(
_local_path(path_manager, os.path.join(basedir, "poses_bounds.npy"))
)
poses = poses_arr[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])
bds = poses_arr[:, -2:].transpose([1, 0])
img0 = [
os.path.join(basedir, "images", f)
for f in sorted(_ls(path_manager, os.path.join(basedir, "images")))
if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
][0]
def imread(f):
return np.array(Image.open(f))
sh = imread(_local_path(path_manager, img0)).shape
sfx = ""
if factor is not None:
sfx = "_{}".format(factor)
_minify(basedir, path_manager, factors=[factor])
factor = factor
elif height is not None:
factor = sh[0] / float(height)
width = int(sh[1] / factor)
_minify(basedir, path_manager, resolutions=[[height, width]])
sfx = "_{}x{}".format(width, height)
elif width is not None:
factor = sh[1] / float(width)
height = int(sh[0] / factor)
_minify(basedir, path_manager, resolutions=[[height, width]])
sfx = "_{}x{}".format(width, height)
else:
factor = 1
imgdir = os.path.join(basedir, "images" + sfx)
if not _exists(path_manager, imgdir):
raise ValueError(f"{imgdir} does not exist, returning")
imgfiles = [
_local_path(path_manager, os.path.join(imgdir, f))
for f in sorted(_ls(path_manager, imgdir))
if f.endswith("JPG") or f.endswith("jpg") or f.endswith("png")
]
if poses.shape[-1] != len(imgfiles):
raise ValueError(
"Mismatch between imgs {} and poses {} !!!!".format(
len(imgfiles), poses.shape[-1]
)
)
sh = imread(imgfiles[0]).shape
poses[:2, 4, :] = np.array(sh[:2]).reshape([2, 1])
poses[2, 4, :] = poses[2, 4, :] * 1.0 / factor
if not load_imgs:
return poses, bds
imgs = imgs = [imread(f)[..., :3] / 255.0 for f in imgfiles]
imgs = np.stack(imgs, -1)
logger.info(f"Loaded image data, shape {imgs.shape}")
return poses, bds, imgs
def normalize(x):
denom = np.linalg.norm(x)
if denom < 0.001:
warnings.warn("unsafe normalize()")
return x / denom
def viewmatrix(z, up, pos):
vec2 = normalize(z)
vec1_avg = up
vec0 = normalize(np.cross(vec1_avg, vec2))
vec1 = normalize(np.cross(vec2, vec0))
m = np.stack([vec0, vec1, vec2, pos], 1)
return m
def ptstocam(pts, c2w):
tt = np.matmul(c2w[:3, :3].T, (pts - c2w[:3, 3])[..., np.newaxis])[..., 0]
return tt
def poses_avg(poses):
hwf = poses[0, :3, -1:]
center = poses[:, :3, 3].mean(0)
vec2 = normalize(poses[:, :3, 2].sum(0))
up = poses[:, :3, 1].sum(0)
c2w = np.concatenate([viewmatrix(vec2, up, center), hwf], 1)
return c2w
def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, rots, N):
render_poses = []
rads = np.array(list(rads) + [1.0])
hwf = c2w[:, 4:5]
for theta in np.linspace(0.0, 2.0 * np.pi * rots, N + 1)[:-1]:
c = np.dot(
c2w[:3, :4],
np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0])
* rads,
)
z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.0])))
render_poses.append(np.concatenate([viewmatrix(z, up, c), hwf], 1))
return render_poses
def recenter_poses(poses):
poses_ = poses + 0
bottom = np.reshape([0, 0, 0, 1.0], [1, 4])
c2w = poses_avg(poses)
c2w = np.concatenate([c2w[:3, :4], bottom], -2)
bottom = np.tile(np.reshape(bottom, [1, 1, 4]), [poses.shape[0], 1, 1])
poses = np.concatenate([poses[:, :3, :4], bottom], -2)
poses = np.linalg.inv(c2w) @ poses
poses_[:, :3, :4] = poses[:, :3, :4]
poses = poses_
return poses
def spherify_poses(poses, bds):
def add_row_to_homogenize_transform(p):
r"""Add the last row to homogenize 3 x 4 transformation matrices."""
return np.concatenate(
[p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
)
# p34_to_44 = lambda p: np.concatenate(
# [p, np.tile(np.reshape(np.eye(4)[-1, :], [1, 1, 4]), [p.shape[0], 1, 1])], 1
# )
p34_to_44 = add_row_to_homogenize_transform
rays_d = poses[:, :3, 2:3]
rays_o = poses[:, :3, 3:4]
def min_line_dist(rays_o, rays_d):
A_i = np.eye(3) - rays_d * np.transpose(rays_d, [0, 2, 1])
b_i = -A_i @ rays_o
pt_mindist = np.squeeze(
-np.linalg.inv((np.transpose(A_i, [0, 2, 1]) @ A_i).mean(0)) @ (b_i).mean(0)
)
return pt_mindist
pt_mindist = min_line_dist(rays_o, rays_d)
center = pt_mindist
up = (poses[:, :3, 3] - center).mean(0)
vec0 = normalize(up)
vec1 = normalize(np.cross([0.1, 0.2, 0.3], vec0))
vec2 = normalize(np.cross(vec0, vec1))
pos = center
c2w = np.stack([vec1, vec2, vec0, pos], 1)
poses_reset = np.linalg.inv(p34_to_44(c2w[None])) @ p34_to_44(poses[:, :3, :4])
rad = np.sqrt(np.mean(np.sum(np.square(poses_reset[:, :3, 3]), -1)))
sc = 1.0 / rad
poses_reset[:, :3, 3] *= sc
bds *= sc
rad *= sc
centroid = np.mean(poses_reset[:, :3, 3], 0)
zh = centroid[2]
radcircle = np.sqrt(rad**2 - zh**2)
new_poses = []
for th in np.linspace(0.0, 2.0 * np.pi, 120):
camorigin = np.array([radcircle * np.cos(th), radcircle * np.sin(th), zh])
up = np.array([0, 0, -1.0])
vec2 = normalize(camorigin)
vec0 = normalize(np.cross(vec2, up))
vec1 = normalize(np.cross(vec2, vec0))
pos = camorigin
p = np.stack([vec0, vec1, vec2, pos], 1)
new_poses.append(p)
new_poses = np.stack(new_poses, 0)
new_poses = np.concatenate(
[new_poses, np.broadcast_to(poses[0, :3, -1:], new_poses[:, :3, -1:].shape)], -1
)
poses_reset = np.concatenate(
[
poses_reset[:, :3, :4],
np.broadcast_to(poses[0, :3, -1:], poses_reset[:, :3, -1:].shape),
],
-1,
)
return poses_reset, new_poses, bds
def _local_path(path_manager, path):
if path_manager is None:
return path
return path_manager.get_local_path(path)
def _ls(path_manager, path):
if path_manager is None:
return os.path.listdir(path)
return path_manager.ls(path)
def _exists(path_manager, path):
if path_manager is None:
return os.path.exists(path)
return path_manager.exists(path)
def load_llff_data(
basedir,
factor=8,
recenter=True,
bd_factor=0.75,
spherify=False,
path_zflat=False,
path_manager=None,
):
poses, bds, imgs = _load_data(
basedir, factor=factor, path_manager=path_manager
) # factor=8 downsamples original imgs by 8x
logger.info(f"Loaded {basedir}, {bds.min()}, {bds.max()}")
# Correct rotation matrix ordering and move variable dim to axis 0
poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
poses = np.moveaxis(poses, -1, 0).astype(np.float32)
imgs = np.moveaxis(imgs, -1, 0).astype(np.float32)
images = imgs
bds = np.moveaxis(bds, -1, 0).astype(np.float32)
# Rescale if bd_factor is provided
sc = 1.0 if bd_factor is None else 1.0 / (bds.min() * bd_factor)
poses[:, :3, 3] *= sc
bds *= sc
if recenter:
poses = recenter_poses(poses)
if spherify:
poses, render_poses, bds = spherify_poses(poses, bds)
images = images.astype(np.float32)
poses = poses.astype(np.float32)
return images, poses, bds

View File

@@ -0,0 +1,181 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# This file defines a base class for dataset map providers which
# provide data for a single scene.
from dataclasses import field
from typing import Iterable, List, Optional
import numpy as np
import torch
from pytorch3d.implicitron.tools.config import (
Configurable,
expand_args_fields,
run_auto_creation,
)
from pytorch3d.renderer import PerspectiveCameras
from .dataset_base import DatasetBase, FrameData
from .dataset_map_provider import (
DatasetMap,
DatasetMapProviderBase,
PathManagerFactory,
Task,
)
from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
class SingleSceneDataset(DatasetBase, Configurable):
"""
A dataset from images from a single scene.
"""
images: List[torch.Tensor] = field()
poses: List[PerspectiveCameras] = field()
object_name: str = field()
frame_types: List[str] = field()
eval_batches: Optional[List[List[int]]] = field()
def sequence_names(self) -> Iterable[str]:
return [_SINGLE_SEQUENCE_NAME]
def __len__(self) -> int:
return len(self.poses)
def __getitem__(self, index) -> FrameData:
if index >= len(self):
raise IndexError(f"index {index} out of range {len(self)}")
image = self.images[index]
pose = self.poses[index]
frame_type = self.frame_types[index]
frame_data = FrameData(
frame_number=index,
sequence_name=_SINGLE_SEQUENCE_NAME,
sequence_category=self.object_name,
camera=pose,
image_size_hw=torch.tensor(image.shape[1:]),
image_rgb=image,
frame_type=frame_type,
)
return frame_data
def get_eval_batches(self) -> Optional[List[List[int]]]:
return self.eval_batches
# pyre-fixme[13]: Uninitialized attribute
class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
"""
Base for provider of data for one scene from LLFF or blender datasets.
Members:
base_dir: directory holding the data for the scene.
object_name: The name of the scene (e.g. "lego"). This is just used as a label.
It will typically be equal to the name of the directory self.base_dir.
path_manager_factory: Creates path manager which may be used for
interpreting paths.
n_known_frames_for_test: If set, training frames are included in the val
and test datasets, and this many random training frames are added to
each test batch. If not set, test batches each contain just a single
testing frame.
"""
base_dir: str
object_name: str
path_manager_factory: PathManagerFactory
path_manager_factory_class_type: str = "PathManagerFactory"
n_known_frames_for_test: Optional[int] = None
def __post_init__(self) -> None:
run_auto_creation(self)
self._load_data()
def _load_data(self) -> None:
# This must be defined by each subclass,
# and should set poses, images and i_split on self.
raise NotImplementedError
def _get_dataset(
self, split_idx: int, frame_type: str, set_eval_batches: bool = False
) -> SingleSceneDataset:
expand_args_fields(SingleSceneDataset)
# pyre-ignore[16]
split = self.i_split[split_idx]
frame_types = [frame_type] * len(split)
eval_batches = [[i] for i in range(len(split))]
if split_idx != 0 and self.n_known_frames_for_test is not None:
train_split = self.i_split[0]
if set_eval_batches:
generator = np.random.default_rng(seed=0)
for batch in eval_batches:
to_add = generator.choice(
len(train_split), self.n_known_frames_for_test
)
batch.extend((to_add + len(split)).tolist())
split = np.concatenate([split, train_split])
frame_types.extend([DATASET_TYPE_KNOWN] * len(train_split))
# pyre-ignore[28]
return SingleSceneDataset(
object_name=self.object_name,
# pyre-ignore[16]
images=self.images[split],
# pyre-ignore[16]
poses=[self.poses[i] for i in split],
frame_types=frame_types,
eval_batches=eval_batches if set_eval_batches else None,
)
def get_dataset_map(self) -> DatasetMap:
return DatasetMap(
train=self._get_dataset(0, DATASET_TYPE_KNOWN),
val=self._get_dataset(1, DATASET_TYPE_UNKNOWN),
test=self._get_dataset(2, DATASET_TYPE_UNKNOWN, True),
)
def get_task(self) -> Task:
return Task.SINGLE_SEQUENCE
def _interpret_blender_cameras(
poses: torch.Tensor, H: int, W: int, focal: float
) -> List[PerspectiveCameras]:
"""
Convert 4x4 matrices representing cameras in blender format
to PyTorch3D format.
Args:
poses: N x 3 x 4 camera matrices
"""
pose_target_cameras = []
for pose_target in poses:
pose_target = pose_target[:3, :4]
mtx = torch.eye(4, dtype=pose_target.dtype)
mtx[:3, :3] = pose_target[:3, :3].t()
mtx[3, :3] = pose_target[:, 3]
mtx = mtx.inverse()
# flip the XZ coordinates.
mtx[:, [0, 2]] *= -1.0
Rpt3, Tpt3 = mtx[:, :3].split([3, 1], dim=0)
focal_length_pt3 = torch.FloatTensor([[-focal, focal]])
principal_point_pt3 = torch.FloatTensor([[W / 2, H / 2]])
cameras = PerspectiveCameras(
focal_length=focal_length_pt3,
principal_point=principal_point_pt3,
R=Rpt3[None],
T=Tpt3,
)
pose_target_cameras.append(cameras)
return pose_target_cameras

View File

@@ -220,6 +220,7 @@ class Configurable:
_X = TypeVar("X", bound=ReplaceableBase)
_Y = TypeVar("Y", bound=Union[ReplaceableBase, Configurable])
class _Registry:
@@ -307,20 +308,23 @@ class _Registry:
It determines the namespace.
This will typically be a direct subclass of ReplaceableBase.
Returns:
list of class types
list of class types in alphabetical order of registered name.
"""
if self._is_base_class(base_class_wanted):
return list(self._mapping[base_class_wanted].values())
source = self._mapping[base_class_wanted]
return [source[key] for key in sorted(source)]
base_class = self._base_class_from_class(base_class_wanted)
if base_class is None:
raise ValueError(
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
)
source = self._mapping[base_class]
return [
class_
for class_ in self._mapping[base_class].values()
if issubclass(class_, base_class_wanted) and class_ is not base_class_wanted
source[key]
for key in sorted(source)
if issubclass(source[key], base_class_wanted)
and source[key] is not base_class_wanted
]
@staticmethod
@@ -647,8 +651,8 @@ def _is_actually_dataclass(some_class) -> bool:
def expand_args_fields(
some_class: Type[_X], *, _do_not_process: Tuple[type, ...] = ()
) -> Type[_X]:
some_class: Type[_Y], *, _do_not_process: Tuple[type, ...] = ()
) -> Type[_Y]:
"""
This expands a class which inherits Configurable or ReplaceableBase classes,
including dataclass processing. some_class is modified in place by this function.

View File

@@ -13,6 +13,7 @@ from .blending import (
from .camera_utils import join_cameras_as_batch, rotate_on_spot
from .cameras import ( # deprecated # deprecated # deprecated # deprecated
camera_position_from_spherical_angles,
CamerasBase,
FoVOrthographicCameras,
FoVPerspectiveCameras,
get_world_to_view_transform,