Filtering outlier input cameras in trajectory estimation

Summary: Useful for visualising colmap output where some frames are not correctly registered.

Reviewed By: bottler

Differential Revision: D38743191

fbshipit-source-id: e823df2997870dc41d76784e112d4349f904d311
parent b7c826b786
commit d281f8efd1
@@ -4,16 +4,21 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import logging
 import math
 from typing import Optional, Tuple
 
 import torch
 from pytorch3d.common.compat import eigh
+from pytorch3d.implicitron.tools import utils
 from pytorch3d.implicitron.tools.circle_fitting import fit_circle_in_3d
 from pytorch3d.renderer import look_at_view_transform, PerspectiveCameras
 from pytorch3d.transforms import Scale
 
 
+logger = logging.getLogger(__name__)
+
+
 def generate_eval_video_cameras(
     train_cameras,
     n_eval_cams: int = 100,
@@ -27,6 +32,7 @@ def generate_eval_video_cameras(
     infer_up_as_plane_normal: bool = True,
     traj_offset: Optional[Tuple[float, float, float]] = None,
     traj_offset_canonical: Optional[Tuple[float, float, float]] = None,
+    remove_outliers_rate: float = 0.0,
 ) -> PerspectiveCameras:
     """
     Generate a camera trajectory rendering a scene from multiple viewpoints.
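The hunk above adds the new `remove_outliers_rate` argument. Below is a minimal usage sketch, not part of the commit: it assumes `train_cameras` is a PerspectiveCameras batch recovered from a COLMAP reconstruction and that `generate_eval_video_cameras` lives in pytorch3d.implicitron.tools.eval_video_trajectory (the diff view does not show file paths).

# Hedged usage sketch, not part of the commit. Assumes `train_cameras` is a
# PerspectiveCameras batch recovered from a COLMAP reconstruction, and that the
# function lives in pytorch3d.implicitron.tools.eval_video_trajectory.
from pytorch3d.implicitron.tools.eval_video_trajectory import (
    generate_eval_video_cameras,
)

test_cameras = generate_eval_video_cameras(
    train_cameras,
    n_eval_cams=100,
    trajectory_type="circular",
    # New in this commit: trim the most extreme camera centers on each axis
    # before the trajectory is estimated, so misregistered frames do not skew it.
    remove_outliers_rate=0.05,
)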
@@ -50,9 +56,16 @@ def generate_eval_video_cameras(
             Active for the `trajectory_type="circular"`.
         scene_center: The center of the scene in world coordinates which all
             the cameras from the generated trajectory look at.
+        remove_outliers_rate: the number between 0 and 1; if > 0,
+            some outlier train_cameras will be removed from trajectory estimation;
+            the filtering is based on camera center coordinates; top and
+            bottom `remove_outliers_rate` cameras on each dimension are removed.
     Returns:
         Dictionary of camera instances which can be used as the test dataset
     """
+    if remove_outliers_rate > 0.0:
+        train_cameras = _remove_outlier_cameras(train_cameras, remove_outliers_rate)
+
     if trajectory_type in ("figure_eight", "trefoil_knot", "figure_eight_knot"):
         cam_centers = train_cameras.get_camera_center()
         # get the nearest camera center to the mean of centers
@@ -167,6 +180,20 @@ def generate_eval_video_cameras(
     return test_cameras
 
 
+def _remove_outlier_cameras(
+    cameras: PerspectiveCameras, outlier_rate: float
+) -> PerspectiveCameras:
+    keep_indices = utils.get_inlier_indicators(
+        cameras.get_camera_center(), dim=0, outlier_rate=outlier_rate
+    )
+    clean_cameras = cameras[keep_indices]
+    logger.info(
+        "Filtered outlier cameras when estimating the trajectory: "
+        f"{len(cameras)} → {len(clean_cameras)}"
+    )
+    return clean_cameras
+
+
 def _disambiguate_normal(normal, up):
     up_t = torch.tensor(up).to(normal)
     flip = (up_t * normal).sum().sign()
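`_remove_outlier_cameras` delegates to `utils.get_inlier_indicators`, which the remaining hunks add (apparently to pytorch3d/implicitron/tools/utils.py, given the `from pytorch3d.implicitron.tools import utils` import above). A small worked example of the trimming arithmetic, not part of the commit, to make the rate semantics concrete: the rate is split between the low and high end of every coordinate axis.

# Sketch of the trimming arithmetic used by get_inlier_indicators
# (illustrative numbers only). With 200 registered frames and
# remove_outliers_rate=0.1, the 10 largest and 10 smallest camera centers
# along each axis are excluded from the trajectory fit.
n_cameras, remove_outliers_rate = 200, 0.1
per_side = int(min(remove_outliers_rate, 1.0) * n_cameras / 2)
assert per_side == 10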
@@ -9,7 +9,7 @@ import collections
 import dataclasses
 import time
 from contextlib import contextmanager
-from typing import Any, Callable, Dict
+from typing import Any, Callable, Dict, Iterable, Iterator
 
 import torch
 
@@ -157,6 +157,26 @@ def cat_dataclass(batch, tensor_collator: Callable):
     return type(elem)(**collated)
 
 
+def recursive_visitor(it: Iterable[Any]) -> Iterator[Any]:
+    for x in it:
+        if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
+            yield from recursive_visitor(x)
+        else:
+            yield x
+
+
+def get_inlier_indicators(
+    tensor: torch.Tensor, dim: int, outlier_rate: float
+) -> torch.Tensor:
+    remove_elements = int(min(outlier_rate, 1.0) * tensor.shape[dim] / 2)
+    hi = torch.topk(tensor, remove_elements, dim=dim).indices.tolist()
+    lo = torch.topk(-tensor, remove_elements, dim=dim).indices.tolist()
+    remove_indices = set(recursive_visitor([hi, lo]))
+    keep_indices = tensor.new_ones(tensor.shape[dim : dim + 1], dtype=torch.bool)
+    keep_indices[list(remove_indices)] = False
+    return keep_indices
+
+
 class Timer:
     """
     A simple class for timing execution.
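Finally, a self-contained toy example of the new helper, not part of the commit; it assumes a pytorch3d build that already includes this change so `get_inlier_indicators` can be imported from `pytorch3d.implicitron.tools.utils`. The fabricated camera centers below contain one grossly misregistered frame.

# Toy demonstration of get_inlier_indicators (sketch, not part of the commit).
# Requires a pytorch3d build that already contains this change.
import torch

from pytorch3d.implicitron.tools.utils import get_inlier_indicators

# Ten fabricated camera centers; row 7 is a grossly misregistered frame (x = 9.0).
centers = torch.tensor(
    [
        [0.00, 0.10, 0.10],
        [0.10, 0.20, 0.15],
        [0.20, 0.15, 0.20],
        [0.30, 0.25, 0.12],
        [0.25, 0.40, 0.18],
        [0.15, 0.05, 0.30],
        [0.22, 0.18, 0.05],
        [9.00, 0.12, 0.14],
        [0.12, 0.30, 0.22],
        [0.18, 0.11, 0.16],
    ]
)

# With 10 cameras and outlier_rate=0.2, int(0.2 * 10 / 2) = 1 camera is trimmed
# from each end of every axis: rows 0 and 7 (x extremes), 5 and 4 (y extremes),
# 6 and 5 (z extremes). The returned boolean mask keeps rows 1, 2, 3, 8 and 9.
keep = get_inlier_indicators(centers, dim=0, outlier_rate=0.2)
inlier_centers = centers[keep]  # (5, 3) tensor of the remaining centers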