Fix circular import

Summary: This fixes a recently introduced circular import: the problem went unnoticed by having `pytorch3d.renderer` imported first...

Reviewed By: bottler

Differential Revision: D29686235

fbshipit-source-id: 4b9f2faecec2cc8347ee259cfc359dc9e4f67784
This commit is contained in:
Patrick Labatut 2021-07-30 03:05:13 -07:00 committed by Facebook GitHub Bot
parent 5eec5e289e
commit 9a14f54e8b
4 changed files with 189 additions and 139 deletions

View File

@ -0,0 +1,176 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Tuple
import torch
from ..transforms import matrix_to_rotation_6d
from .cameras import PerspectiveCameras
LOGGER = logging.getLogger(__name__)
def _cameras_from_opencv_projection(
R: torch.Tensor,
tvec: torch.Tensor,
camera_matrix: torch.Tensor,
image_size: torch.Tensor,
) -> PerspectiveCameras:
focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
principal_point = camera_matrix[:, :2, 2]
# Retype the image_size correctly and flip to width, height.
image_size_wh = image_size.to(R).flip(dims=(1,))
# Get the PyTorch3D focal length and principal point.
focal_pytorch3d = focal_length / (0.5 * image_size_wh)
p0_pytorch3d = -(principal_point / (0.5 * image_size_wh) - 1)
# For R, T we flip x, y axes (opencv screen space has an opposite
# orientation of screen axes).
# We also transpose R (opencv multiplies points from the opposite=left side).
R_pytorch3d = R.clone().permute(0, 2, 1)
T_pytorch3d = tvec.clone()
R_pytorch3d[:, :, :2] *= -1
T_pytorch3d[:, :2] *= -1
return PerspectiveCameras(
R=R_pytorch3d,
T=T_pytorch3d,
focal_length=focal_pytorch3d,
principal_point=p0_pytorch3d,
)
def _opencv_from_cameras_projection(
cameras: PerspectiveCameras,
image_size: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
R_pytorch3d = cameras.R.clone() # pyre-ignore
T_pytorch3d = cameras.T.clone() # pyre-ignore
focal_pytorch3d = cameras.focal_length
p0_pytorch3d = cameras.principal_point
T_pytorch3d[:, :2] *= -1
R_pytorch3d[:, :, :2] *= -1
tvec = T_pytorch3d
R = R_pytorch3d.permute(0, 2, 1)
# Retype the image_size correctly and flip to width, height.
image_size_wh = image_size.to(R).flip(dims=(1,))
principal_point = (-p0_pytorch3d + 1.0) * (0.5 * image_size_wh) # pyre-ignore
focal_length = focal_pytorch3d * (0.5 * image_size_wh)
camera_matrix = torch.zeros_like(R)
camera_matrix[:, :2, 2] = principal_point
camera_matrix[:, 2, 2] = 1.0
camera_matrix[:, 0, 0] = focal_length[:, 0]
camera_matrix[:, 1, 1] = focal_length[:, 1]
return R, tvec, camera_matrix
def _pulsar_from_opencv_projection(
R: torch.Tensor,
tvec: torch.Tensor,
camera_matrix: torch.Tensor,
image_size: torch.Tensor,
znear: float = 0.1,
) -> torch.Tensor:
assert len(camera_matrix.size()) == 3, "This function requires batched inputs!"
assert len(R.size()) == 3, "This function requires batched inputs!"
assert len(tvec.size()) in (2, 3), "This function reuqires batched inputs!"
# Validate parameters.
image_size_wh = image_size.to(R).flip(dims=(1,))
assert torch.all(
image_size_wh > 0
), "height and width must be positive but min is: %s" % (
str(image_size_wh.min().item())
)
assert (
camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3
), "Incorrect camera matrix shape: expected 3x3 but got %dx%d" % (
camera_matrix.size(1),
camera_matrix.size(2),
)
assert (
R.size(1) == 3 and R.size(2) == 3
), "Incorrect R shape: expected 3x3 but got %dx%d" % (
R.size(1),
R.size(2),
)
if len(tvec.size()) == 2:
tvec = tvec.unsqueeze(2)
assert (
tvec.size(1) == 3 and tvec.size(2) == 1
), "Incorrect tvec shape: expected 3x1 but got %dx%d" % (
tvec.size(1),
tvec.size(2),
)
# Check batch size.
batch_size = camera_matrix.size(0)
assert R.size(0) == batch_size, "Expected R to have batch size %d. Has size %d." % (
batch_size,
R.size(0),
)
assert (
tvec.size(0) == batch_size
), "Expected tvec to have batch size %d. Has size %d." % (
batch_size,
tvec.size(0),
)
# Check image sizes.
image_w = image_size_wh[0, 0]
image_h = image_size_wh[0, 1]
assert torch.all(
image_size_wh[:, 0] == image_w
), "All images in a batch must have the same width!"
assert torch.all(
image_size_wh[:, 1] == image_h
), "All images in a batch must have the same height!"
# Focal length.
fx = camera_matrix[:, 0, 0].unsqueeze(1)
fy = camera_matrix[:, 1, 1].unsqueeze(1)
# Check that we introduce less than 1% error by averaging the focal lengths.
fx_y = fx / fy
if torch.any(fx_y > 1.01) or torch.any(fx_y < 0.99):
LOGGER.warning(
"Pulsar only supports a single focal lengths. For converting OpenCV "
"focal lengths, we average them for x and y directions. "
"The focal lengths for x and y you provided differ by more than 1%, "
"which means this could introduce a noticeable error."
)
f = (fx + fy) / 2
# Normalize f into normalized device coordinates.
focal_length_px = f / image_w
# Transfer into focal_length and sensor_width.
focal_length = torch.tensor([znear - 1e-5], dtype=torch.float32, device=R.device)
focal_length = focal_length[None, :].repeat(batch_size, 1)
sensor_width = focal_length / focal_length_px
# Principal point.
cx = camera_matrix[:, 0, 2].unsqueeze(1)
cy = camera_matrix[:, 1, 2].unsqueeze(1)
# Transfer principal point offset into centered offset.
cx = -(cx - image_w / 2)
cy = cy - image_h / 2
# Concatenate to final vector.
param = torch.cat([focal_length, sensor_width, cx, cy], dim=1)
R_trans = R.permute(0, 2, 1)
cam_pos = -torch.bmm(R_trans, tvec).squeeze(2)
cam_rot = matrix_to_rotation_6d(R_trans)
cam_params = torch.cat([cam_pos, cam_rot, param], dim=1)
return cam_params
def _pulsar_from_cameras_projection(
cameras: PerspectiveCameras,
image_size: torch.Tensor,
) -> torch.Tensor:
opencv_R, opencv_T, opencv_K = _opencv_from_cameras_projection(cameras, image_size)
return _pulsar_from_opencv_projection(opencv_R, opencv_T, opencv_K, image_size)

View File

@ -11,7 +11,7 @@ from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
from ....utils import pulsar_from_cameras_projection
from ...camera_conversions import _pulsar_from_cameras_projection
from ...cameras import (
FoVOrthographicCameras,
FoVPerspectiveCameras,
@ -378,7 +378,7 @@ class PulsarPointsRenderer(nn.Module):
size_tensor = torch.tensor(
[[self.renderer._renderer.height, self.renderer._renderer.width]]
)
pulsar_cam = pulsar_from_cameras_projection(tmp_cams, size_tensor)
pulsar_cam = _pulsar_from_cameras_projection(tmp_cams, size_tensor)
cam_pos = pulsar_cam[0, :3]
cam_rot = pulsar_cam[0, 3:9]
return cam_pos, cam_rot

View File

@ -7,8 +7,8 @@
from .camera_conversions import (
cameras_from_opencv_projection,
opencv_from_cameras_projection,
pulsar_from_opencv_projection,
pulsar_from_cameras_projection,
pulsar_from_opencv_projection,
)
from .ico_sphere import ico_sphere
from .torus import torus

View File

@ -4,16 +4,17 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Tuple
import torch
from ..renderer import PerspectiveCameras
from ..transforms import matrix_to_rotation_6d
LOGGER = logging.getLogger(__name__)
from ..renderer.camera_conversions import (
_cameras_from_opencv_projection,
_opencv_from_cameras_projection,
_pulsar_from_cameras_projection,
_pulsar_from_opencv_projection,
)
def cameras_from_opencv_projection(
@ -58,30 +59,7 @@ def cameras_from_opencv_projection(
Returns:
cameras_pytorch3d: A batch of `N` cameras in the PyTorch3D convention.
"""
focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
principal_point = camera_matrix[:, :2, 2]
# Retype the image_size correctly and flip to width, height.
image_size_wh = image_size.to(R).flip(dims=(1,))
# Get the PyTorch3D focal length and principal point.
focal_pytorch3d = focal_length / (0.5 * image_size_wh)
p0_pytorch3d = -(principal_point / (0.5 * image_size_wh) - 1)
# For R, T we flip x, y axes (opencv screen space has an opposite
# orientation of screen axes).
# We also transpose R (opencv multiplies points from the opposite=left side).
R_pytorch3d = R.clone().permute(0, 2, 1)
T_pytorch3d = tvec.clone()
R_pytorch3d[:, :, :2] *= -1
T_pytorch3d[:, :2] *= -1
return PerspectiveCameras(
R=R_pytorch3d,
T=T_pytorch3d,
focal_length=focal_pytorch3d,
principal_point=p0_pytorch3d,
)
return _cameras_from_opencv_projection(R, tvec, camera_matrix, image_size)
def opencv_from_cameras_projection(
@ -114,27 +92,7 @@ def opencv_from_cameras_projection(
tvec: A batch of translation vectors of shape `(N, 3)`.
camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
"""
R_pytorch3d = cameras.R.clone() # pyre-ignore
T_pytorch3d = cameras.T.clone() # pyre-ignore
focal_pytorch3d = cameras.focal_length
p0_pytorch3d = cameras.principal_point
T_pytorch3d[:, :2] *= -1
R_pytorch3d[:, :, :2] *= -1
tvec = T_pytorch3d
R = R_pytorch3d.permute(0, 2, 1)
# Retype the image_size correctly and flip to width, height.
image_size_wh = image_size.to(R).flip(dims=(1,))
principal_point = (-p0_pytorch3d + 1.0) * (0.5 * image_size_wh) # pyre-ignore
focal_length = focal_pytorch3d * (0.5 * image_size_wh)
camera_matrix = torch.zeros_like(R)
camera_matrix[:, :2, 2] = principal_point
camera_matrix[:, 2, 2] = 1.0
camera_matrix[:, 0, 0] = focal_length[:, 0]
camera_matrix[:, 1, 1] = focal_length[:, 1]
return R, tvec, camera_matrix
return _opencv_from_cameras_projection(cameras, image_size)
def pulsar_from_opencv_projection(
@ -170,90 +128,7 @@ def pulsar_from_opencv_projection(
convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width,
c_x, c_y).
"""
assert len(camera_matrix.size()) == 3, "This function requires batched inputs!"
assert len(R.size()) == 3, "This function requires batched inputs!"
assert len(tvec.size()) in (2, 3), "This function reuqires batched inputs!"
# Validate parameters.
image_size_wh = image_size.to(R).flip(dims=(1,))
assert torch.all(
image_size_wh > 0
), "height and width must be positive but min is: %s" % (
str(image_size_wh.min().item())
)
assert (
camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3
), "Incorrect camera matrix shape: expected 3x3 but got %dx%d" % (
camera_matrix.size(1),
camera_matrix.size(2),
)
assert (
R.size(1) == 3 and R.size(2) == 3
), "Incorrect R shape: expected 3x3 but got %dx%d" % (
R.size(1),
R.size(2),
)
if len(tvec.size()) == 2:
tvec = tvec.unsqueeze(2)
assert (
tvec.size(1) == 3 and tvec.size(2) == 1
), "Incorrect tvec shape: expected 3x1 but got %dx%d" % (
tvec.size(1),
tvec.size(2),
)
# Check batch size.
batch_size = camera_matrix.size(0)
assert R.size(0) == batch_size, "Expected R to have batch size %d. Has size %d." % (
batch_size,
R.size(0),
)
assert (
tvec.size(0) == batch_size
), "Expected tvec to have batch size %d. Has size %d." % (
batch_size,
tvec.size(0),
)
# Check image sizes.
image_w = image_size_wh[0, 0]
image_h = image_size_wh[0, 1]
assert torch.all(
image_size_wh[:, 0] == image_w
), "All images in a batch must have the same width!"
assert torch.all(
image_size_wh[:, 1] == image_h
), "All images in a batch must have the same height!"
# Focal length.
fx = camera_matrix[:, 0, 0].unsqueeze(1)
fy = camera_matrix[:, 1, 1].unsqueeze(1)
# Check that we introduce less than 1% error by averaging the focal lengths.
fx_y = fx / fy
if torch.any(fx_y > 1.01) or torch.any(fx_y < 0.99):
LOGGER.warning(
"Pulsar only supports a single focal lengths. For converting OpenCV "
"focal lengths, we average them for x and y directions. "
"The focal lengths for x and y you provided differ by more than 1%, "
"which means this could introduce a noticeable error."
)
f = (fx + fy) / 2
# Normalize f into normalized device coordinates.
focal_length_px = f / image_w
# Transfer into focal_length and sensor_width.
focal_length = torch.tensor([znear - 1e-5], dtype=torch.float32, device=R.device)
focal_length = focal_length[None, :].repeat(batch_size, 1)
sensor_width = focal_length / focal_length_px
# Principal point.
cx = camera_matrix[:, 0, 2].unsqueeze(1)
cy = camera_matrix[:, 1, 2].unsqueeze(1)
# Transfer principal point offset into centered offset.
cx = -(cx - image_w / 2)
cy = cy - image_h / 2
# Concatenate to final vector.
param = torch.cat([focal_length, sensor_width, cx, cy], dim=1)
R_trans = R.permute(0, 2, 1)
cam_pos = -torch.bmm(R_trans, tvec).squeeze(2)
cam_rot = matrix_to_rotation_6d(R_trans)
cam_params = torch.cat([cam_pos, cam_rot, param], dim=1)
return cam_params
return _pulsar_from_opencv_projection(R, tvec, camera_matrix, image_size, znear)
def pulsar_from_cameras_projection(
@ -281,5 +156,4 @@ def pulsar_from_cameras_projection(
convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width,
c_x, c_y).
"""
opencv_R, opencv_T, opencv_K = opencv_from_cameras_projection(cameras, image_size)
return pulsar_from_opencv_projection(opencv_R, opencv_T, opencv_K, image_size)
return _pulsar_from_cameras_projection(cameras, image_size)