Fix circular import

Summary: This fixes a recently introduced circular import: the problem went unnoticed because `pytorch3d.renderer` happened to be imported first...
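A rough sketch of how the cycle arose and how this change breaks it (the module paths are read off the relative imports in the diff below; the snippet is illustrative only and not part of the commit):

# Before: pytorch3d.utils.camera_conversions imported pytorch3d.renderer for
# PerspectiveCameras, while pytorch3d.renderer.points.pulsar.unified imported
# pulsar_from_cameras_projection back from pytorch3d.utils.
import pytorch3d.utils      # importing utils first could hit a partially initialized package
import pytorch3d.renderer   # importing renderer first masked the cycle

# After: the conversion implementations move to private helpers in
# pytorch3d.renderer.camera_conversions, and pytorch3d.utils.camera_conversions keeps
# the public functions as thin wrappers, so both import orders work.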

Reviewed By: bottler

Differential Revision: D29686235

fbshipit-source-id: 4b9f2faecec2cc8347ee259cfc359dc9e4f67784
Patrick Labatut 2021-07-30 03:05:13 -07:00 committed by Facebook GitHub Bot
parent 5eec5e289e
commit 9a14f54e8b
4 changed files with 189 additions and 139 deletions

pytorch3d/renderer/camera_conversions.py (new file)

@@ -0,0 +1,176 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Tuple

import torch

from ..transforms import matrix_to_rotation_6d
from .cameras import PerspectiveCameras

LOGGER = logging.getLogger(__name__)


def _cameras_from_opencv_projection(
    R: torch.Tensor,
    tvec: torch.Tensor,
    camera_matrix: torch.Tensor,
    image_size: torch.Tensor,
) -> PerspectiveCameras:
    focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
    principal_point = camera_matrix[:, :2, 2]

    # Retype the image_size correctly and flip to width, height.
    image_size_wh = image_size.to(R).flip(dims=(1,))

    # Get the PyTorch3D focal length and principal point.
    focal_pytorch3d = focal_length / (0.5 * image_size_wh)
    p0_pytorch3d = -(principal_point / (0.5 * image_size_wh) - 1)

    # For R, T we flip x, y axes (opencv screen space has an opposite
    # orientation of screen axes).
    # We also transpose R (opencv multiplies points from the opposite=left side).
    R_pytorch3d = R.clone().permute(0, 2, 1)
    T_pytorch3d = tvec.clone()
    R_pytorch3d[:, :, :2] *= -1
    T_pytorch3d[:, :2] *= -1

    return PerspectiveCameras(
        R=R_pytorch3d,
        T=T_pytorch3d,
        focal_length=focal_pytorch3d,
        principal_point=p0_pytorch3d,
    )


def _opencv_from_cameras_projection(
    cameras: PerspectiveCameras,
    image_size: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    R_pytorch3d = cameras.R.clone()  # pyre-ignore
    T_pytorch3d = cameras.T.clone()  # pyre-ignore
    focal_pytorch3d = cameras.focal_length
    p0_pytorch3d = cameras.principal_point
    T_pytorch3d[:, :2] *= -1
    R_pytorch3d[:, :, :2] *= -1
    tvec = T_pytorch3d
    R = R_pytorch3d.permute(0, 2, 1)

    # Retype the image_size correctly and flip to width, height.
    image_size_wh = image_size.to(R).flip(dims=(1,))

    principal_point = (-p0_pytorch3d + 1.0) * (0.5 * image_size_wh)  # pyre-ignore
    focal_length = focal_pytorch3d * (0.5 * image_size_wh)

    camera_matrix = torch.zeros_like(R)
    camera_matrix[:, :2, 2] = principal_point
    camera_matrix[:, 2, 2] = 1.0
    camera_matrix[:, 0, 0] = focal_length[:, 0]
    camera_matrix[:, 1, 1] = focal_length[:, 1]
    return R, tvec, camera_matrix


def _pulsar_from_opencv_projection(
    R: torch.Tensor,
    tvec: torch.Tensor,
    camera_matrix: torch.Tensor,
    image_size: torch.Tensor,
    znear: float = 0.1,
) -> torch.Tensor:
    assert len(camera_matrix.size()) == 3, "This function requires batched inputs!"
    assert len(R.size()) == 3, "This function requires batched inputs!"
    assert len(tvec.size()) in (2, 3), "This function requires batched inputs!"
    # Validate parameters.
    image_size_wh = image_size.to(R).flip(dims=(1,))
    assert torch.all(
        image_size_wh > 0
    ), "height and width must be positive but min is: %s" % (
        str(image_size_wh.min().item())
    )
    assert (
        camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3
    ), "Incorrect camera matrix shape: expected 3x3 but got %dx%d" % (
        camera_matrix.size(1),
        camera_matrix.size(2),
    )
    assert (
        R.size(1) == 3 and R.size(2) == 3
    ), "Incorrect R shape: expected 3x3 but got %dx%d" % (
        R.size(1),
        R.size(2),
    )
    if len(tvec.size()) == 2:
        tvec = tvec.unsqueeze(2)
    assert (
        tvec.size(1) == 3 and tvec.size(2) == 1
    ), "Incorrect tvec shape: expected 3x1 but got %dx%d" % (
        tvec.size(1),
        tvec.size(2),
    )
    # Check batch size.
    batch_size = camera_matrix.size(0)
    assert R.size(0) == batch_size, "Expected R to have batch size %d. Has size %d." % (
        batch_size,
        R.size(0),
    )
    assert (
        tvec.size(0) == batch_size
    ), "Expected tvec to have batch size %d. Has size %d." % (
        batch_size,
        tvec.size(0),
    )
    # Check image sizes.
    image_w = image_size_wh[0, 0]
    image_h = image_size_wh[0, 1]
    assert torch.all(
        image_size_wh[:, 0] == image_w
    ), "All images in a batch must have the same width!"
    assert torch.all(
        image_size_wh[:, 1] == image_h
    ), "All images in a batch must have the same height!"
    # Focal length.
    fx = camera_matrix[:, 0, 0].unsqueeze(1)
    fy = camera_matrix[:, 1, 1].unsqueeze(1)
    # Check that we introduce less than 1% error by averaging the focal lengths.
    fx_y = fx / fy
    if torch.any(fx_y > 1.01) or torch.any(fx_y < 0.99):
        LOGGER.warning(
            "Pulsar only supports a single focal length. For converting OpenCV "
            "focal lengths, we average them for x and y directions. "
            "The focal lengths for x and y you provided differ by more than 1%, "
            "which means this could introduce a noticeable error."
        )
    f = (fx + fy) / 2
    # Normalize f into normalized device coordinates.
    focal_length_px = f / image_w
    # Transfer into focal_length and sensor_width.
    focal_length = torch.tensor([znear - 1e-5], dtype=torch.float32, device=R.device)
    focal_length = focal_length[None, :].repeat(batch_size, 1)
    sensor_width = focal_length / focal_length_px
    # Principal point.
    cx = camera_matrix[:, 0, 2].unsqueeze(1)
    cy = camera_matrix[:, 1, 2].unsqueeze(1)
    # Transfer principal point offset into centered offset.
    cx = -(cx - image_w / 2)
    cy = cy - image_h / 2
    # Concatenate to final vector.
    param = torch.cat([focal_length, sensor_width, cx, cy], dim=1)
    R_trans = R.permute(0, 2, 1)
    cam_pos = -torch.bmm(R_trans, tvec).squeeze(2)
    cam_rot = matrix_to_rotation_6d(R_trans)
    cam_params = torch.cat([cam_pos, cam_rot, param], dim=1)
    return cam_params


def _pulsar_from_cameras_projection(
    cameras: PerspectiveCameras,
    image_size: torch.Tensor,
) -> torch.Tensor:
    opencv_R, opencv_T, opencv_K = _opencv_from_cameras_projection(cameras, image_size)
    return _pulsar_from_opencv_projection(opencv_R, opencv_T, opencv_K, image_size)
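A minimal usage sketch for these conversions (illustrative only: dummy tensors, and the entry points shown are the public wrappers from pytorch3d.utils in the last file of this diff):

import torch
from pytorch3d.utils import cameras_from_opencv_projection, opencv_from_cameras_projection

N = 2
R = torch.eye(3)[None].repeat(N, 1, 1)  # batch of OpenCV rotations
tvec = torch.zeros(N, 3)  # batch of OpenCV translations
camera_matrix = torch.tensor(
    [[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]]
)[None].repeat(N, 1, 1)
image_size = torch.tensor([[480, 640]] * N)  # (height, width) per camera

cameras = cameras_from_opencv_projection(R, tvec, camera_matrix, image_size)
R_cv, tvec_cv, K_cv = opencv_from_cameras_projection(cameras, image_size)
# The round trip should recover the OpenCV parameters up to float precision.
assert torch.allclose(K_cv, camera_matrix, atol=1e-4)
assert torch.allclose(R_cv, R, atol=1e-5) and torch.allclose(tvec_cv, tvec, atol=1e-5)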

pytorch3d/renderer/points/pulsar/unified.py

@@ -11,7 +11,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from ....utils import pulsar_from_cameras_projection
+from ...camera_conversions import _pulsar_from_cameras_projection
 from ...cameras import (
     FoVOrthographicCameras,
     FoVPerspectiveCameras,
@@ -378,7 +378,7 @@ class PulsarPointsRenderer(nn.Module):
         size_tensor = torch.tensor(
             [[self.renderer._renderer.height, self.renderer._renderer.width]]
         )
-        pulsar_cam = pulsar_from_cameras_projection(tmp_cams, size_tensor)
+        pulsar_cam = _pulsar_from_cameras_projection(tmp_cams, size_tensor)
         cam_pos = pulsar_cam[0, :3]
         cam_rot = pulsar_cam[0, 3:9]
         return cam_pos, cam_rot

pytorch3d/utils/__init__.py

@@ -7,8 +7,8 @@
 from .camera_conversions import (
     cameras_from_opencv_projection,
     opencv_from_cameras_projection,
-    pulsar_from_opencv_projection,
     pulsar_from_cameras_projection,
+    pulsar_from_opencv_projection,
 )
 from .ico_sphere import ico_sphere
 from .torus import torus

pytorch3d/utils/camera_conversions.py

@@ -4,16 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import logging
 from typing import Tuple
 
 import torch
 
 from ..renderer import PerspectiveCameras
-from ..transforms import matrix_to_rotation_6d
-
-
-LOGGER = logging.getLogger(__name__)
+from ..renderer.camera_conversions import (
+    _cameras_from_opencv_projection,
+    _opencv_from_cameras_projection,
+    _pulsar_from_cameras_projection,
+    _pulsar_from_opencv_projection,
+)
 
 
 def cameras_from_opencv_projection(
@@ -58,30 +59,7 @@ def cameras_from_opencv_projection(
     Returns:
         cameras_pytorch3d: A batch of `N` cameras in the PyTorch3D convention.
     """
-    focal_length = torch.stack([camera_matrix[:, 0, 0], camera_matrix[:, 1, 1]], dim=-1)
-    principal_point = camera_matrix[:, :2, 2]
-
-    # Retype the image_size correctly and flip to width, height.
-    image_size_wh = image_size.to(R).flip(dims=(1,))
-
-    # Get the PyTorch3D focal length and principal point.
-    focal_pytorch3d = focal_length / (0.5 * image_size_wh)
-    p0_pytorch3d = -(principal_point / (0.5 * image_size_wh) - 1)
-
-    # For R, T we flip x, y axes (opencv screen space has an opposite
-    # orientation of screen axes).
-    # We also transpose R (opencv multiplies points from the opposite=left side).
-    R_pytorch3d = R.clone().permute(0, 2, 1)
-    T_pytorch3d = tvec.clone()
-    R_pytorch3d[:, :, :2] *= -1
-    T_pytorch3d[:, :2] *= -1
-
-    return PerspectiveCameras(
-        R=R_pytorch3d,
-        T=T_pytorch3d,
-        focal_length=focal_pytorch3d,
-        principal_point=p0_pytorch3d,
-    )
+    return _cameras_from_opencv_projection(R, tvec, camera_matrix, image_size)
 
 
 def opencv_from_cameras_projection(
@@ -114,27 +92,7 @@ def opencv_from_cameras_projection(
         tvec: A batch of translation vectors of shape `(N, 3)`.
         camera_matrix: A batch of camera calibration matrices of shape `(N, 3, 3)`.
     """
-    R_pytorch3d = cameras.R.clone()  # pyre-ignore
-    T_pytorch3d = cameras.T.clone()  # pyre-ignore
-    focal_pytorch3d = cameras.focal_length
-    p0_pytorch3d = cameras.principal_point
-    T_pytorch3d[:, :2] *= -1
-    R_pytorch3d[:, :, :2] *= -1
-    tvec = T_pytorch3d
-    R = R_pytorch3d.permute(0, 2, 1)
-
-    # Retype the image_size correctly and flip to width, height.
-    image_size_wh = image_size.to(R).flip(dims=(1,))
-
-    principal_point = (-p0_pytorch3d + 1.0) * (0.5 * image_size_wh)  # pyre-ignore
-    focal_length = focal_pytorch3d * (0.5 * image_size_wh)
-
-    camera_matrix = torch.zeros_like(R)
-    camera_matrix[:, :2, 2] = principal_point
-    camera_matrix[:, 2, 2] = 1.0
-    camera_matrix[:, 0, 0] = focal_length[:, 0]
-    camera_matrix[:, 1, 1] = focal_length[:, 1]
-    return R, tvec, camera_matrix
+    return _opencv_from_cameras_projection(cameras, image_size)
 
 
 def pulsar_from_opencv_projection(
@@ -170,90 +128,7 @@ def pulsar_from_opencv_projection(
             convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width,
             c_x, c_y).
     """
-    assert len(camera_matrix.size()) == 3, "This function requires batched inputs!"
-    assert len(R.size()) == 3, "This function requires batched inputs!"
-    assert len(tvec.size()) in (2, 3), "This function reuqires batched inputs!"
-    # Validate parameters.
-    image_size_wh = image_size.to(R).flip(dims=(1,))
-    assert torch.all(
-        image_size_wh > 0
-    ), "height and width must be positive but min is: %s" % (
-        str(image_size_wh.min().item())
-    )
-    assert (
-        camera_matrix.size(1) == 3 and camera_matrix.size(2) == 3
-    ), "Incorrect camera matrix shape: expected 3x3 but got %dx%d" % (
-        camera_matrix.size(1),
-        camera_matrix.size(2),
-    )
-    assert (
-        R.size(1) == 3 and R.size(2) == 3
-    ), "Incorrect R shape: expected 3x3 but got %dx%d" % (
-        R.size(1),
-        R.size(2),
-    )
-    if len(tvec.size()) == 2:
-        tvec = tvec.unsqueeze(2)
-    assert (
-        tvec.size(1) == 3 and tvec.size(2) == 1
-    ), "Incorrect tvec shape: expected 3x1 but got %dx%d" % (
-        tvec.size(1),
-        tvec.size(2),
-    )
-    # Check batch size.
-    batch_size = camera_matrix.size(0)
-    assert R.size(0) == batch_size, "Expected R to have batch size %d. Has size %d." % (
-        batch_size,
-        R.size(0),
-    )
-    assert (
-        tvec.size(0) == batch_size
-    ), "Expected tvec to have batch size %d. Has size %d." % (
-        batch_size,
-        tvec.size(0),
-    )
-    # Check image sizes.
-    image_w = image_size_wh[0, 0]
-    image_h = image_size_wh[0, 1]
-    assert torch.all(
-        image_size_wh[:, 0] == image_w
-    ), "All images in a batch must have the same width!"
-    assert torch.all(
-        image_size_wh[:, 1] == image_h
-    ), "All images in a batch must have the same height!"
-    # Focal length.
-    fx = camera_matrix[:, 0, 0].unsqueeze(1)
-    fy = camera_matrix[:, 1, 1].unsqueeze(1)
-    # Check that we introduce less than 1% error by averaging the focal lengths.
-    fx_y = fx / fy
-    if torch.any(fx_y > 1.01) or torch.any(fx_y < 0.99):
-        LOGGER.warning(
-            "Pulsar only supports a single focal lengths. For converting OpenCV "
-            "focal lengths, we average them for x and y directions. "
-            "The focal lengths for x and y you provided differ by more than 1%, "
-            "which means this could introduce a noticeable error."
-        )
-    f = (fx + fy) / 2
-    # Normalize f into normalized device coordinates.
-    focal_length_px = f / image_w
-    # Transfer into focal_length and sensor_width.
-    focal_length = torch.tensor([znear - 1e-5], dtype=torch.float32, device=R.device)
-    focal_length = focal_length[None, :].repeat(batch_size, 1)
-    sensor_width = focal_length / focal_length_px
-    # Principal point.
-    cx = camera_matrix[:, 0, 2].unsqueeze(1)
-    cy = camera_matrix[:, 1, 2].unsqueeze(1)
-    # Transfer principal point offset into centered offset.
-    cx = -(cx - image_w / 2)
-    cy = cy - image_h / 2
-    # Concatenate to final vector.
-    param = torch.cat([focal_length, sensor_width, cx, cy], dim=1)
-    R_trans = R.permute(0, 2, 1)
-    cam_pos = -torch.bmm(R_trans, tvec).squeeze(2)
-    cam_rot = matrix_to_rotation_6d(R_trans)
-    cam_params = torch.cat([cam_pos, cam_rot, param], dim=1)
-    return cam_params
+    return _pulsar_from_opencv_projection(R, tvec, camera_matrix, image_size, znear)
 
 
 def pulsar_from_cameras_projection(
@@ -281,5 +156,4 @@ def pulsar_from_cameras_projection(
            convention `(N, 13)` (3 translation, 6 rotation, focal_length, sensor_width,
            c_x, c_y).
    """
-    opencv_R, opencv_T, opencv_K = opencv_from_cameras_projection(cameras, image_size)
-    return pulsar_from_opencv_projection(opencv_R, opencv_T, opencv_K, image_size)
+    return _pulsar_from_cameras_projection(cameras, image_size)
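As a usage note (illustrative, with dummy camera data): the public pulsar conversion keeps its signature and output layout; only its body now delegates to the private helper.

import torch
from pytorch3d.utils import pulsar_from_opencv_projection

R = torch.eye(3)[None]  # (1, 3, 3) OpenCV rotation
tvec = torch.tensor([[0.0, 0.0, 5.0]])  # (1, 3) OpenCV translation
camera_matrix = torch.tensor(
    [[[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]]]
)
image_size = torch.tensor([[480, 640]])  # (height, width)

# Returns (N, 13): 3 translation, 6 rotation (6d), focal_length, sensor_width, c_x, c_y.
cam_params = pulsar_from_opencv_projection(R, tvec, camera_matrix, image_size, znear=0.1)
print(cam_params.shape)  # torch.Size([1, 13])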