Filtering outlier input cameras in trajectory estimation

Summary: Useful for visualising colmap output where some frames are not correctly registered.

Reviewed By: bottler

Differential Revision: D38743191

fbshipit-source-id: e823df2997870dc41d76784e112d4349f904d311
This commit is contained in:
Roman Shapovalov 2022-08-17 03:47:31 -07:00 committed by Facebook GitHub Bot
parent b7c826b786
commit d281f8efd1
2 changed files with 48 additions and 1 deletions

View File

@ -4,16 +4,21 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import math
from typing import Optional, Tuple
import torch
from pytorch3d.common.compat import eigh
from pytorch3d.implicitron.tools import utils
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
from pytorch3d.transforms import Scale
logger = logging.getLogger(__name__)
def generate_eval_video_cameras(
train_cameras,
n_eval_cams: int = 100,
@ -27,6 +32,7 @@ def generate_eval_video_cameras(
infer_up_as_plane_normal: bool = True,
traj_offset: Optional[Tuple[float, float, float]] = None,
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
remove_outliers_rate: float = 0.0,
) -> PerspectiveCameras:
"""
Generate a camera trajectory rendering a scene from multiple viewpoints.
@ -50,9 +56,16 @@ def generate_eval_video_cameras(
Active for the `trajectory_type="circular"`.
scene_center: The center of the scene in world coordinates which all
the cameras from the generated trajectory look at.
remove_outliers_rate: the number between 0 and 1; if > 0,
some outlier train_cameras will be removed from trajectory estimation;
the filtering is based on camera center coordinates; top and
bottom `remove_outliers_rate` cameras on each dimension are removed.
Returns:
Dictionary of camera instances which can be used as the test dataset
"""
if remove_outliers_rate > 0.0:
train_cameras = _remove_outlier_cameras(train_cameras, remove_outliers_rate)
if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
cam_centers = train_cameras.get_camera_center()
# get the nearest camera center to the mean of centers
@ -167,6 +180,20 @@ def generate_eval_video_cameras(
return test_cameras
def _remove_outlier_cameras(
cameras: PerspectiveCameras, outlier_rate: float
) -> PerspectiveCameras:
keep_indices = utils.get_inlier_indicators(
cameras.get_camera_center(), dim=0, outlier_rate=outlier_rate
)
clean_cameras = cameras[keep_indices]
logger.info(
"Filtered outlier cameras when estimating the trajectory: "
f"{len(cameras)}{len(clean_cameras)}"
)
return clean_cameras
def _disambiguate_normal(normal, up):
up_t = torch.tensor(up).to(normal)
flip = (up_t * normal).sum().sign()

View File

@ -9,7 +9,7 @@ import collections
import dataclasses
import time
from contextlib import contextmanager
from typing import Any, Callable, Dict
from typing import Any, Callable, Dict, Iterable, Iterator
import torch
@ -157,6 +157,26 @@ def cat_dataclass(batch, tensor_collator: Callable):
return type(elem)(**collated)
def recursive_visitor(it: Iterable[Any]) -> Iterator[Any]:
for x in it:
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
yield from recursive_visitor(x)
else:
yield x
def get_inlier_indicators(
tensor: torch.Tensor, dim: int, outlier_rate: float
) -> torch.Tensor:
remove_elements = int(min(outlier_rate, 1.0) * tensor.shape[dim] / 2)
hi = torch.topk(tensor, remove_elements, dim=dim).indices.tolist()
lo = torch.topk(-tensor, remove_elements, dim=dim).indices.tolist()
remove_indices = set(recursive_visitor([hi, lo]))
keep_indices = tensor.new_ones(tensor.shape[dim : dim + 1], dtype=torch.bool)
keep_indices[list(remove_indices)] = False
return keep_indices
class Timer:
"""
A simple class for timing execution.