mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Filtering outlier input cameras in trajectory estimation
Summary: Useful for visualising colmap output where some frames are not correctly registered. Reviewed By: bottler Differential Revision: D38743191 fbshipit-source-id: e823df2997870dc41d76784e112d4349f904d311
This commit is contained in:
parent
b7c826b786
commit
d281f8efd1
@ -4,16 +4,21 @@
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.common.compat import eigh
|
||||
from pytorch3d.implicitron.tools import utils
|
||||
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
|
||||
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
|
||||
from pytorch3d.transforms import Scale
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_eval_video_cameras(
|
||||
train_cameras,
|
||||
n_eval_cams: int = 100,
|
||||
@ -27,6 +32,7 @@ def generate_eval_video_cameras(
|
||||
infer_up_as_plane_normal: bool = True,
|
||||
traj_offset: Optional[Tuple[float, float, float]] = None,
|
||||
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
|
||||
remove_outliers_rate: float = 0.0,
|
||||
) -> PerspectiveCameras:
|
||||
"""
|
||||
Generate a camera trajectory rendering a scene from multiple viewpoints.
|
||||
@ -50,9 +56,16 @@ def generate_eval_video_cameras(
|
||||
Active for the `trajectory_type="circular"`.
|
||||
scene_center: The center of the scene in world coordinates which all
|
||||
the cameras from the generated trajectory look at.
|
||||
remove_outliers_rate: the number between 0 and 1; if > 0,
|
||||
some outlier train_cameras will be removed from trajectory estimation;
|
||||
the filtering is based on camera center coordinates; top and
|
||||
bottom `remove_outliers_rate` cameras on each dimension are removed.
|
||||
Returns:
|
||||
Dictionary of camera instances which can be used as the test dataset
|
||||
"""
|
||||
if remove_outliers_rate > 0.0:
|
||||
train_cameras = _remove_outlier_cameras(train_cameras, remove_outliers_rate)
|
||||
|
||||
if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
|
||||
cam_centers = train_cameras.get_camera_center()
|
||||
# get the nearest camera center to the mean of centers
|
||||
@ -167,6 +180,20 @@ def generate_eval_video_cameras(
|
||||
return test_cameras
|
||||
|
||||
|
||||
def _remove_outlier_cameras(
|
||||
cameras: PerspectiveCameras, outlier_rate: float
|
||||
) -> PerspectiveCameras:
|
||||
keep_indices = utils.get_inlier_indicators(
|
||||
cameras.get_camera_center(), dim=0, outlier_rate=outlier_rate
|
||||
)
|
||||
clean_cameras = cameras[keep_indices]
|
||||
logger.info(
|
||||
"Filtered outlier cameras when estimating the trajectory: "
|
||||
f"{len(cameras)} → {len(clean_cameras)}"
|
||||
)
|
||||
return clean_cameras
|
||||
|
||||
|
||||
def _disambiguate_normal(normal, up):
|
||||
up_t = torch.tensor(up).to(normal)
|
||||
flip = (up_t * normal).sum().sign()
|
||||
|
@ -9,7 +9,7 @@ import collections
|
||||
import dataclasses
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Dict
|
||||
from typing import Any, Callable, Dict, Iterable, Iterator
|
||||
|
||||
import torch
|
||||
|
||||
@ -157,6 +157,26 @@ def cat_dataclass(batch, tensor_collator: Callable):
|
||||
return type(elem)(**collated)
|
||||
|
||||
|
||||
def recursive_visitor(it: Iterable[Any]) -> Iterator[Any]:
|
||||
for x in it:
|
||||
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
|
||||
yield from recursive_visitor(x)
|
||||
else:
|
||||
yield x
|
||||
|
||||
|
||||
def get_inlier_indicators(
|
||||
tensor: torch.Tensor, dim: int, outlier_rate: float
|
||||
) -> torch.Tensor:
|
||||
remove_elements = int(min(outlier_rate, 1.0) * tensor.shape[dim] / 2)
|
||||
hi = torch.topk(tensor, remove_elements, dim=dim).indices.tolist()
|
||||
lo = torch.topk(-tensor, remove_elements, dim=dim).indices.tolist()
|
||||
remove_indices = set(recursive_visitor([hi, lo]))
|
||||
keep_indices = tensor.new_ones(tensor.shape[dim : dim + 1], dtype=torch.bool)
|
||||
keep_indices[list(remove_indices)] = False
|
||||
return keep_indices
|
||||
|
||||
|
||||
class Timer:
|
||||
"""
|
||||
A simple class for timing execution.
|
||||
|
Loading…
x
Reference in New Issue
Block a user