mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Filtering outlier input cameras in trajectory estimation
Summary: Useful for visualising colmap output where some frames are not correctly registered. Reviewed By: bottler Differential Revision: D38743191 fbshipit-source-id: e823df2997870dc41d76784e112d4349f904d311
This commit is contained in:
parent
b7c826b786
commit
d281f8efd1
@ -4,16 +4,21 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.common.compat import eigh
|
from pytorch3d.common.compat import eigh
|
||||||
|
from pytorch3d.implicitron.tools import utils
|
||||||
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
|
from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
|
||||||
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
|
from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
|
||||||
from pytorch3d.transforms import Scale
|
from pytorch3d.transforms import Scale
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def generate_eval_video_cameras(
|
def generate_eval_video_cameras(
|
||||||
train_cameras,
|
train_cameras,
|
||||||
n_eval_cams: int = 100,
|
n_eval_cams: int = 100,
|
||||||
@ -27,6 +32,7 @@ def generate_eval_video_cameras(
|
|||||||
infer_up_as_plane_normal: bool = True,
|
infer_up_as_plane_normal: bool = True,
|
||||||
traj_offset: Optional[Tuple[float, float, float]] = None,
|
traj_offset: Optional[Tuple[float, float, float]] = None,
|
||||||
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
|
traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
|
||||||
|
remove_outliers_rate: float = 0.0,
|
||||||
) -> PerspectiveCameras:
|
) -> PerspectiveCameras:
|
||||||
"""
|
"""
|
||||||
Generate a camera trajectory rendering a scene from multiple viewpoints.
|
Generate a camera trajectory rendering a scene from multiple viewpoints.
|
||||||
@ -50,9 +56,16 @@ def generate_eval_video_cameras(
|
|||||||
Active for the `trajectory_type="circular"`.
|
Active for the `trajectory_type="circular"`.
|
||||||
scene_center: The center of the scene in world coordinates which all
|
scene_center: The center of the scene in world coordinates which all
|
||||||
the cameras from the generated trajectory look at.
|
the cameras from the generated trajectory look at.
|
||||||
|
remove_outliers_rate: the number between 0 and 1; if > 0,
|
||||||
|
some outlier train_cameras will be removed from trajectory estimation;
|
||||||
|
the filtering is based on camera center coordinates; top and
|
||||||
|
bottom `remove_outliers_rate` cameras on each dimension are removed.
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary of camera instances which can be used as the test dataset
|
Dictionary of camera instances which can be used as the test dataset
|
||||||
"""
|
"""
|
||||||
|
if remove_outliers_rate > 0.0:
|
||||||
|
train_cameras = _remove_outlier_cameras(train_cameras, remove_outliers_rate)
|
||||||
|
|
||||||
if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
|
if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
|
||||||
cam_centers = train_cameras.get_camera_center()
|
cam_centers = train_cameras.get_camera_center()
|
||||||
# get the nearest camera center to the mean of centers
|
# get the nearest camera center to the mean of centers
|
||||||
@ -167,6 +180,20 @@ def generate_eval_video_cameras(
|
|||||||
return test_cameras
|
return test_cameras
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_outlier_cameras(
|
||||||
|
cameras: PerspectiveCameras, outlier_rate: float
|
||||||
|
) -> PerspectiveCameras:
|
||||||
|
keep_indices = utils.get_inlier_indicators(
|
||||||
|
cameras.get_camera_center(), dim=0, outlier_rate=outlier_rate
|
||||||
|
)
|
||||||
|
clean_cameras = cameras[keep_indices]
|
||||||
|
logger.info(
|
||||||
|
"Filtered outlier cameras when estimating the trajectory: "
|
||||||
|
f"{len(cameras)} → {len(clean_cameras)}"
|
||||||
|
)
|
||||||
|
return clean_cameras
|
||||||
|
|
||||||
|
|
||||||
def _disambiguate_normal(normal, up):
|
def _disambiguate_normal(normal, up):
|
||||||
up_t = torch.tensor(up).to(normal)
|
up_t = torch.tensor(up).to(normal)
|
||||||
flip = (up_t * normal).sum().sign()
|
flip = (up_t * normal).sum().sign()
|
||||||
|
@ -9,7 +9,7 @@ import collections
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, Dict
|
from typing import Any, Callable, Dict, Iterable, Iterator
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -157,6 +157,26 @@ def cat_dataclass(batch, tensor_collator: Callable):
|
|||||||
return type(elem)(**collated)
|
return type(elem)(**collated)
|
||||||
|
|
||||||
|
|
||||||
|
def recursive_visitor(it: Iterable[Any]) -> Iterator[Any]:
|
||||||
|
for x in it:
|
||||||
|
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
|
||||||
|
yield from recursive_visitor(x)
|
||||||
|
else:
|
||||||
|
yield x
|
||||||
|
|
||||||
|
|
||||||
|
def get_inlier_indicators(
|
||||||
|
tensor: torch.Tensor, dim: int, outlier_rate: float
|
||||||
|
) -> torch.Tensor:
|
||||||
|
remove_elements = int(min(outlier_rate, 1.0) * tensor.shape[dim] / 2)
|
||||||
|
hi = torch.topk(tensor, remove_elements, dim=dim).indices.tolist()
|
||||||
|
lo = torch.topk(-tensor, remove_elements, dim=dim).indices.tolist()
|
||||||
|
remove_indices = set(recursive_visitor([hi, lo]))
|
||||||
|
keep_indices = tensor.new_ones(tensor.shape[dim : dim + 1], dtype=torch.bool)
|
||||||
|
keep_indices[list(remove_indices)] = False
|
||||||
|
return keep_indices
|
||||||
|
|
||||||
|
|
||||||
class Timer:
|
class Timer:
|
||||||
"""
|
"""
|
||||||
A simple class for timing execution.
|
A simple class for timing execution.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user