mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Update cameras to accept projection matrix as input
Summary: To initialize the Cameras class currently we require the principal point, focal length and other parameters to be specified from which we calculate the intrinsic matrix. In some cases the matrix might be directly available e.g. from a dataset and the associated metadata for an image. Reviewed By: nikhilaravi Differential Revision: D24489509 fbshipit-source-id: 1b411f19c5f6c8074bcfbf613f3339d5e242c119
This commit is contained in:
parent
6f4697bc1b
commit
36fb257ef1
@ -70,7 +70,7 @@ class CamerasBase(TensorProperties):
|
||||
arguments to override the default values set in `__init__`.
|
||||
|
||||
Return:
|
||||
P: a `Transform3d` object which represents a batch of projection
|
||||
a `Transform3d` object which represents a batch of projection
|
||||
matrices of shape (N, 3, 3)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
@ -333,10 +333,10 @@ class FoVPerspectiveCameras(CamerasBase):
|
||||
degrees: bool = True,
|
||||
R=_R,
|
||||
T=_T,
|
||||
K=None,
|
||||
device="cpu",
|
||||
):
|
||||
"""
|
||||
__init__(self, znear, zfar, aspect_ratio, fov, degrees, R, T, device) -> None # noqa
|
||||
|
||||
Args:
|
||||
znear: near clipping plane of the view frustrum.
|
||||
@ -346,6 +346,8 @@ class FoVPerspectiveCameras(CamerasBase):
|
||||
degrees: bool, set to True if fov is specified in degrees.
|
||||
R: Rotation matrix of shape (N, 3, 3)
|
||||
T: Translation matrix of shape (N, 3)
|
||||
K: (optional) A calibration matrix of shape (N, 4, 4)
|
||||
If provided, don't need znear, zfar, fov, aspect_ratio, degrees
|
||||
device: torch.device or string
|
||||
"""
|
||||
# The initializer formats all inputs to torch tensors and broadcasts
|
||||
@ -358,55 +360,29 @@ class FoVPerspectiveCameras(CamerasBase):
|
||||
fov=fov,
|
||||
R=R,
|
||||
T=T,
|
||||
K=K,
|
||||
)
|
||||
|
||||
# No need to convert to tensor or broadcast.
|
||||
self.degrees = degrees
|
||||
|
||||
def get_projection_transform(self, **kwargs) -> Transform3d:
|
||||
def compute_projection_matrix(
|
||||
self, znear, zfar, fov, aspect_ratio, degrees
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the perpective projection matrix with a symmetric
|
||||
viewing frustrum. Use column major order.
|
||||
The viewing frustrum will be projected into ndc, s.t.
|
||||
(max_x, max_y) -> (+1, +1)
|
||||
(min_x, min_y) -> (-1, -1)
|
||||
Compute the calibration matrix K of shape (N, 4, 4)
|
||||
|
||||
Args:
|
||||
**kwargs: parameters for the projection can be passed in as keyword
|
||||
arguments to override the default values set in `__init__`.
|
||||
znear: near clipping plane of the view frustrum.
|
||||
zfar: far clipping plane of the view frustrum.
|
||||
fov: field of view angle of the camera.
|
||||
aspect_ratio: ratio of screen_width/screen_height.
|
||||
degrees: bool, set to True if fov is specified in degrees.
|
||||
|
||||
Return:
|
||||
P: a Transform3d object which represents a batch of projection
|
||||
matrices of shape (N, 4, 4)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
h1 = (max_y + min_y)/(max_y - min_y)
|
||||
w1 = (max_x + min_x)/(max_x - min_x)
|
||||
tanhalffov = tan((fov/2))
|
||||
s1 = 1/tanhalffov
|
||||
s2 = 1/(tanhalffov * (aspect_ratio))
|
||||
|
||||
# To map z to the range [0, 1] use:
|
||||
f1 = far / (far - near)
|
||||
f2 = -(far * near) / (far - near)
|
||||
|
||||
# Projection matrix
|
||||
P = [
|
||||
[s1, 0, w1, 0],
|
||||
[0, s2, h1, 0],
|
||||
[0, 0, f1, f2],
|
||||
[0, 0, 1, 0],
|
||||
]
|
||||
Returns:
|
||||
torch.floatTensor of the calibration matrix with shape (N, 4, 4)
|
||||
"""
|
||||
znear = kwargs.get("znear", self.znear) # pyre-ignore[16]
|
||||
zfar = kwargs.get("zfar", self.zfar) # pyre-ignore[16]
|
||||
fov = kwargs.get("fov", self.fov) # pyre-ignore[16]
|
||||
# pyre-ignore[16]
|
||||
aspect_ratio = kwargs.get("aspect_ratio", self.aspect_ratio)
|
||||
degrees = kwargs.get("degrees", self.degrees)
|
||||
|
||||
P = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32)
|
||||
K = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32)
|
||||
ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
|
||||
if degrees:
|
||||
fov = (np.pi / 180) * fov
|
||||
@ -427,21 +403,73 @@ class FoVPerspectiveCameras(CamerasBase):
|
||||
# so the so the z sign is 1.0.
|
||||
z_sign = 1.0
|
||||
|
||||
P[:, 0, 0] = 2.0 * znear / (max_x - min_x)
|
||||
P[:, 1, 1] = 2.0 * znear / (max_y - min_y)
|
||||
P[:, 0, 2] = (max_x + min_x) / (max_x - min_x)
|
||||
P[:, 1, 2] = (max_y + min_y) / (max_y - min_y)
|
||||
P[:, 3, 2] = z_sign * ones
|
||||
K[:, 0, 0] = 2.0 * znear / (max_x - min_x)
|
||||
K[:, 1, 1] = 2.0 * znear / (max_y - min_y)
|
||||
K[:, 0, 2] = (max_x + min_x) / (max_x - min_x)
|
||||
K[:, 1, 2] = (max_y + min_y) / (max_y - min_y)
|
||||
K[:, 3, 2] = z_sign * ones
|
||||
|
||||
# NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point
|
||||
# is at the near clipping plane and z = 1 when the point is at the far
|
||||
# clipping plane.
|
||||
P[:, 2, 2] = z_sign * zfar / (zfar - znear)
|
||||
P[:, 2, 3] = -(zfar * znear) / (zfar - znear)
|
||||
K[:, 2, 2] = z_sign * zfar / (zfar - znear)
|
||||
K[:, 2, 3] = -(zfar * znear) / (zfar - znear)
|
||||
|
||||
return K
|
||||
|
||||
def get_projection_transform(self, **kwargs) -> Transform3d:
|
||||
"""
|
||||
Calculate the perpective projection matrix with a symmetric
|
||||
viewing frustrum. Use column major order.
|
||||
The viewing frustrum will be projected into ndc, s.t.
|
||||
(max_x, max_y) -> (+1, +1)
|
||||
(min_x, min_y) -> (-1, -1)
|
||||
|
||||
Args:
|
||||
**kwargs: parameters for the projection can be passed in as keyword
|
||||
arguments to override the default values set in `__init__`.
|
||||
|
||||
Return:
|
||||
a Transform3d object which represents a batch of projection
|
||||
matrices of shape (N, 4, 4)
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
h1 = (max_y + min_y)/(max_y - min_y)
|
||||
w1 = (max_x + min_x)/(max_x - min_x)
|
||||
tanhalffov = tan((fov/2))
|
||||
s1 = 1/tanhalffov
|
||||
s2 = 1/(tanhalffov * (aspect_ratio))
|
||||
|
||||
# To map z to the range [0, 1] use:
|
||||
f1 = far / (far - near)
|
||||
f2 = -(far * near) / (far - near)
|
||||
|
||||
# Projection matrix
|
||||
K = [
|
||||
[s1, 0, w1, 0],
|
||||
[0, s2, h1, 0],
|
||||
[0, 0, f1, f2],
|
||||
[0, 0, 1, 0],
|
||||
]
|
||||
"""
|
||||
K = kwargs.get("K", self.K) # pyre-ignore[16]
|
||||
if K is not None:
|
||||
if K.shape != (self._N, 4, 4):
|
||||
msg = "Expected K to have shape of (%r, 4, 4)"
|
||||
raise ValueError(msg % (self._N))
|
||||
else:
|
||||
K = self.compute_projection_matrix(
|
||||
kwargs.get("znear", self.znear), # pyre-ignore[16]
|
||||
kwargs.get("zfar", self.zfar), # pyre-ignore[16]
|
||||
kwargs.get("fov", self.fov), # pyre-ignore[16]
|
||||
kwargs.get("aspect_ratio", self.aspect_ratio), # pyre-ignore[16]
|
||||
kwargs.get("degrees", self.degrees),
|
||||
)
|
||||
|
||||
# Transpose the projection matrix as PyTorch3D transforms use row vectors.
|
||||
transform = Transform3d(device=self.device)
|
||||
transform._matrix = P.transpose(1, 2).contiguous()
|
||||
transform._matrix = K.transpose(1, 2).contiguous()
|
||||
return transform
|
||||
|
||||
def unproject_points(
|
||||
@ -473,12 +501,12 @@ class FoVPerspectiveCameras(CamerasBase):
|
||||
xy_sdepth = xy_depth
|
||||
else:
|
||||
# parse out important values from the projection matrix
|
||||
P_matrix = self.get_projection_transform(**kwargs.copy()).get_matrix()
|
||||
# parse out f1, f2 from P_matrix
|
||||
K_matrix = self.get_projection_transform(**kwargs.copy()).get_matrix()
|
||||
# parse out f1, f2 from K_matrix
|
||||
unsqueeze_shape = [1] * xy_depth.dim()
|
||||
unsqueeze_shape[0] = P_matrix.shape[0]
|
||||
f1 = P_matrix[:, 2, 2].reshape(unsqueeze_shape)
|
||||
f2 = P_matrix[:, 3, 2].reshape(unsqueeze_shape)
|
||||
unsqueeze_shape[0] = K_matrix.shape[0]
|
||||
f1 = K_matrix[:, 2, 2].reshape(unsqueeze_shape)
|
||||
f2 = K_matrix[:, 3, 2].reshape(unsqueeze_shape)
|
||||
# get the scaled depth
|
||||
sdepth = (f1 * xy_depth[..., 2:3] + f2) / xy_depth[..., 2:3]
|
||||
# concatenate xy + scaled depth
|
||||
@ -545,10 +573,10 @@ class FoVOrthographicCameras(CamerasBase):
|
||||
scale_xyz=((1.0, 1.0, 1.0),), # (1, 3)
|
||||
R=_R,
|
||||
T=_T,
|
||||
K=None,
|
||||
device="cpu",
|
||||
):
|
||||
"""
|
||||
__init__(self, znear, zfar, max_y, min_y, max_x, min_x, scale_xyz, R, T, device) -> None # noqa
|
||||
|
||||
Args:
|
||||
znear: near clipping plane of the view frustrum.
|
||||
@ -560,6 +588,8 @@ class FoVOrthographicCameras(CamerasBase):
|
||||
scale_xyz: scale factors for each axis of shape (N, 3).
|
||||
R: Rotation matrix of shape (N, 3, 3).
|
||||
T: Translation of shape (N, 3).
|
||||
K: (optional) A calibration matrix of shape (N, 4, 4)
|
||||
If provided, don't need znear, zfar, max_y, min_y, max_x, min_x, scale_xyz
|
||||
device: torch.device or string.
|
||||
|
||||
Only need to set min_x, max_x, min_y, max_y for viewing frustrums
|
||||
@ -578,8 +608,44 @@ class FoVOrthographicCameras(CamerasBase):
|
||||
scale_xyz=scale_xyz,
|
||||
R=R,
|
||||
T=T,
|
||||
K=K,
|
||||
)
|
||||
|
||||
def compute_projection_matrix(
|
||||
self, znear, zfar, max_x, min_x, max_y, min_y, scale_xyz
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the calibration matrix K of shape (N, 4, 4)
|
||||
|
||||
Args:
|
||||
znear: near clipping plane of the view frustrum.
|
||||
zfar: far clipping plane of the view frustrum.
|
||||
max_x: maximum x coordinate of the frustrum.
|
||||
min_x: minumum x coordinage of the frustrum
|
||||
max_y: maximum y coordinate of the frustrum.
|
||||
min_y: minimum y coordinate of the frustrum.
|
||||
scale_xyz: scale factors for each axis of shape (N, 3).
|
||||
"""
|
||||
K = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)
|
||||
ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
|
||||
# NOTE: OpenGL flips handedness of coordinate system between camera
|
||||
# space and NDC space so z sign is -ve. In PyTorch3D we maintain a
|
||||
# right handed coordinate system throughout.
|
||||
z_sign = +1.0
|
||||
|
||||
K[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
|
||||
K[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
|
||||
K[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
|
||||
K[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
|
||||
K[:, 3, 3] = ones
|
||||
|
||||
# NOTE: This maps the z coordinate to the range [0, 1] and replaces the
|
||||
# the OpenGL z normalization to [-1, 1]
|
||||
K[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
|
||||
K[:, 2, 3] = -znear / (zfar - znear)
|
||||
|
||||
return K
|
||||
|
||||
def get_projection_transform(self, **kwargs) -> Transform3d:
|
||||
"""
|
||||
Calculate the orthographic projection matrix.
|
||||
@ -589,7 +655,7 @@ class FoVOrthographicCameras(CamerasBase):
|
||||
**kwargs: parameters for the projection can be passed in to
|
||||
override the default values set in __init__.
|
||||
Return:
|
||||
P: a Transform3d object which represents a batch of projection
|
||||
a Transform3d object which represents a batch of projection
|
||||
matrices of shape (N, 4, 4)
|
||||
|
||||
.. code-block:: python
|
||||
@ -601,41 +667,31 @@ class FoVOrthographicCameras(CamerasBase):
|
||||
mix_y = (max_y + min_y) / (max_y - min_y)
|
||||
mid_z = (far + near) / (far−near)
|
||||
|
||||
P = [
|
||||
K = [
|
||||
[scale_x, 0, 0, -mid_x],
|
||||
[0, scale_y, 0, -mix_y],
|
||||
[0, 0, -scale_z, -mid_z],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
"""
|
||||
znear = kwargs.get("znear", self.znear) # pyre-ignore[16]
|
||||
zfar = kwargs.get("zfar", self.zfar) # pyre-ignore[16]
|
||||
max_x = kwargs.get("max_x", self.max_x) # pyre-ignore[16]
|
||||
min_x = kwargs.get("min_x", self.min_x) # pyre-ignore[16]
|
||||
max_y = kwargs.get("max_y", self.max_y) # pyre-ignore[16]
|
||||
min_y = kwargs.get("min_y", self.min_y) # pyre-ignore[16]
|
||||
scale_xyz = kwargs.get("scale_xyz", self.scale_xyz) # pyre-ignore[16]
|
||||
|
||||
P = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)
|
||||
ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
|
||||
# NOTE: OpenGL flips handedness of coordinate system between camera
|
||||
# space and NDC space so z sign is -ve. In PyTorch3D we maintain a
|
||||
# right handed coordinate system throughout.
|
||||
z_sign = +1.0
|
||||
|
||||
P[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
|
||||
P[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
|
||||
P[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
|
||||
P[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
|
||||
P[:, 3, 3] = ones
|
||||
|
||||
# NOTE: This maps the z coordinate to the range [0, 1] and replaces the
|
||||
# the OpenGL z normalization to [-1, 1]
|
||||
P[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
|
||||
P[:, 2, 3] = -znear / (zfar - znear)
|
||||
K = kwargs.get("K", self.K) # pyre-ignore[16]
|
||||
if K is not None:
|
||||
if K.shape != (self._N, 4, 4):
|
||||
msg = "Expected K to have shape of (%r, 4, 4)"
|
||||
raise ValueError(msg % (self._N))
|
||||
else:
|
||||
K = self.compute_projection_matrix(
|
||||
kwargs.get("znear", self.znear), # pyre-ignore[16]
|
||||
kwargs.get("zfar", self.zfar), # pyre-ignore[16]
|
||||
kwargs.get("max_x", self.max_x), # pyre-ignore[16]
|
||||
kwargs.get("min_x", self.min_x), # pyre-ignore[16]
|
||||
kwargs.get("max_y", self.max_y), # pyre-ignore[16]
|
||||
kwargs.get("min_y", self.min_y), # pyre-ignore[16]
|
||||
kwargs.get("scale_xyz", self.scale_xyz), # pyre-ignore[16]
|
||||
)
|
||||
|
||||
transform = Transform3d(device=self.device)
|
||||
transform._matrix = P.transpose(1, 2).contiguous()
|
||||
transform._matrix = K.transpose(1, 2).contiguous()
|
||||
return transform
|
||||
|
||||
def unproject_points(
|
||||
@ -666,11 +722,11 @@ class FoVOrthographicCameras(CamerasBase):
|
||||
xy_sdepth = xy_depth
|
||||
else:
|
||||
# we have to obtain the scaled depth first
|
||||
P = self.get_projection_transform(**kwargs).get_matrix()
|
||||
unsqueeze_shape = [1] * P.dim()
|
||||
unsqueeze_shape[0] = P.shape[0]
|
||||
mid_z = P[:, 3, 2].reshape(unsqueeze_shape)
|
||||
scale_z = P[:, 2, 2].reshape(unsqueeze_shape)
|
||||
K = self.get_projection_transform(**kwargs).get_matrix()
|
||||
unsqueeze_shape = [1] * K.dim()
|
||||
unsqueeze_shape[0] = K.shape[0]
|
||||
mid_z = K[:, 3, 2].reshape(unsqueeze_shape)
|
||||
scale_z = K[:, 2, 2].reshape(unsqueeze_shape)
|
||||
scaled_depth = scale_z * xy_depth[..., 2:3] + mid_z
|
||||
# cat xy and scaled depth
|
||||
xy_sdepth = torch.cat((xy_depth[..., :2], scaled_depth), dim=-1)
|
||||
@ -752,11 +808,11 @@ class PerspectiveCameras(CamerasBase):
|
||||
principal_point=((0.0, 0.0),),
|
||||
R=_R,
|
||||
T=_T,
|
||||
K=None,
|
||||
device="cpu",
|
||||
image_size=((-1, -1),),
|
||||
):
|
||||
"""
|
||||
__init__(self, focal_length, principal_point, R, T, device, image_size) -> None
|
||||
|
||||
Args:
|
||||
focal_length: Focal length of the camera in world units.
|
||||
@ -767,6 +823,9 @@ class PerspectiveCameras(CamerasBase):
|
||||
A tensor of shape (N, 2).
|
||||
R: Rotation matrix of shape (N, 3, 3)
|
||||
T: Translation matrix of shape (N, 3)
|
||||
K: (optional) A calibration matrix of shape (N, 4, 4)
|
||||
If provided, don't need focal_length, principal_point, image_size
|
||||
|
||||
device: torch.device or string
|
||||
image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
|
||||
is provided, the camera parameters are assumed to be in screen
|
||||
@ -782,6 +841,7 @@ class PerspectiveCameras(CamerasBase):
|
||||
principal_point=principal_point,
|
||||
R=R,
|
||||
T=T,
|
||||
K=K,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
@ -795,7 +855,7 @@ class PerspectiveCameras(CamerasBase):
|
||||
arguments to override the default values set in __init__.
|
||||
|
||||
Returns:
|
||||
P: A `Transform3d` object with a batch of `N` projection transforms.
|
||||
A `Transform3d` object with a batch of `N` projection transforms.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@ -804,35 +864,35 @@ class PerspectiveCameras(CamerasBase):
|
||||
px = principal_point[:, 0]
|
||||
py = principal_point[:, 1]
|
||||
|
||||
P = [
|
||||
K = [
|
||||
[fx, 0, px, 0],
|
||||
[0, fy, py, 0],
|
||||
[0, 0, 0, 1],
|
||||
[0, 0, 1, 0],
|
||||
]
|
||||
"""
|
||||
# pyre-ignore[16]
|
||||
principal_point = kwargs.get("principal_point", self.principal_point)
|
||||
# pyre-ignore[16]
|
||||
focal_length = kwargs.get("focal_length", self.focal_length)
|
||||
# pyre-ignore[16]
|
||||
image_size = kwargs.get("image_size", self.image_size)
|
||||
K = kwargs.get("K", self.K) # pyre-ignore[16]
|
||||
if K is not None:
|
||||
if K.shape != (self._N, 4, 4):
|
||||
msg = "Expected K to have shape of (%r, 4, 4)"
|
||||
raise ValueError(msg % (self._N))
|
||||
else:
|
||||
# pyre-ignore[16]
|
||||
image_size = kwargs.get("image_size", self.image_size)
|
||||
# if imwidth > 0, parameters are in screen space
|
||||
image_size = image_size if image_size[0][0] > 0 else None
|
||||
|
||||
# if imwidth > 0, parameters are in screen space
|
||||
in_screen = image_size[0][0] > 0
|
||||
image_size = image_size if in_screen else None
|
||||
|
||||
P = _get_sfm_calibration_matrix(
|
||||
self._N,
|
||||
self.device,
|
||||
focal_length,
|
||||
principal_point,
|
||||
orthographic=False,
|
||||
image_size=image_size,
|
||||
)
|
||||
K = _get_sfm_calibration_matrix(
|
||||
self._N,
|
||||
self.device,
|
||||
kwargs.get("focal_length", self.focal_length), # pyre-ignore[16]
|
||||
kwargs.get("principal_point", self.principal_point), # pyre-ignore[16]
|
||||
orthographic=False,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
transform = Transform3d(device=self.device)
|
||||
transform._matrix = P.transpose(1, 2).contiguous()
|
||||
transform._matrix = K.transpose(1, 2).contiguous()
|
||||
return transform
|
||||
|
||||
def unproject_points(
|
||||
@ -911,11 +971,11 @@ class OrthographicCameras(CamerasBase):
|
||||
principal_point=((0.0, 0.0),),
|
||||
R=_R,
|
||||
T=_T,
|
||||
K=None,
|
||||
device="cpu",
|
||||
image_size=((-1, -1),),
|
||||
):
|
||||
"""
|
||||
__init__(self, focal_length, principal_point, R, T, device, image_size) -> None
|
||||
|
||||
Args:
|
||||
focal_length: Focal length of the camera in world units.
|
||||
@ -926,6 +986,8 @@ class OrthographicCameras(CamerasBase):
|
||||
A tensor of shape (N, 2).
|
||||
R: Rotation matrix of shape (N, 3, 3)
|
||||
T: Translation matrix of shape (N, 3)
|
||||
K: (optional) A calibration matrix of shape (N, 4, 4)
|
||||
If provided, don't need focal_length, principal_point, image_size
|
||||
device: torch.device or string
|
||||
image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
|
||||
is provided, the camera parameters are assumed to be in screen
|
||||
@ -941,6 +1003,7 @@ class OrthographicCameras(CamerasBase):
|
||||
principal_point=principal_point,
|
||||
R=R,
|
||||
T=T,
|
||||
K=K,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
@ -954,7 +1017,7 @@ class OrthographicCameras(CamerasBase):
|
||||
arguments to override the default values set in __init__.
|
||||
|
||||
Returns:
|
||||
P: A `Transform3d` object with a batch of `N` projection transforms.
|
||||
A `Transform3d` object with a batch of `N` projection transforms.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@ -963,35 +1026,35 @@ class OrthographicCameras(CamerasBase):
|
||||
px = principal_point[:,0]
|
||||
py = principal_point[:,1]
|
||||
|
||||
P = [
|
||||
K = [
|
||||
[fx, 0, 0, px],
|
||||
[0, fy, 0, py],
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 0, 1],
|
||||
]
|
||||
"""
|
||||
# pyre-ignore[16]
|
||||
principal_point = kwargs.get("principal_point", self.principal_point)
|
||||
# pyre-ignore[16]
|
||||
focal_length = kwargs.get("focal_length", self.focal_length)
|
||||
# pyre-ignore[16]
|
||||
image_size = kwargs.get("image_size", self.image_size)
|
||||
K = kwargs.get("K", self.K) # pyre-ignore[16]
|
||||
if K is not None:
|
||||
if K.shape != (self._N, 4, 4):
|
||||
msg = "Expected K to have shape of (%r, 4, 4)"
|
||||
raise ValueError(msg % (self._N))
|
||||
else:
|
||||
# pyre-ignore[16]
|
||||
image_size = kwargs.get("image_size", self.image_size)
|
||||
# if imwidth > 0, parameters are in screen space
|
||||
image_size = image_size if image_size[0][0] > 0 else None
|
||||
|
||||
# if imwidth > 0, parameters are in screen space
|
||||
in_screen = image_size[0][0] > 0
|
||||
image_size = image_size if in_screen else None
|
||||
|
||||
P = _get_sfm_calibration_matrix(
|
||||
self._N,
|
||||
self.device,
|
||||
focal_length,
|
||||
principal_point,
|
||||
orthographic=True,
|
||||
image_size=image_size,
|
||||
)
|
||||
K = _get_sfm_calibration_matrix(
|
||||
self._N,
|
||||
self.device,
|
||||
kwargs.get("focal_length", self.focal_length), # pyre-ignore[16]
|
||||
kwargs.get("principal_point", self.principal_point), # pyre-ignore[16]
|
||||
orthographic=True,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
transform = Transform3d(device=self.device)
|
||||
transform._matrix = P.transpose(1, 2).contiguous()
|
||||
transform._matrix = K.transpose(1, 2).contiguous()
|
||||
return transform
|
||||
|
||||
def unproject_points(
|
||||
|
@ -106,7 +106,7 @@ class TensorProperties(nn.Module):
|
||||
# set as attributes anything else e.g. strings, bools
|
||||
args_to_broadcast = {}
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, (str, bool)):
|
||||
if v is None or isinstance(v, (str, bool)):
|
||||
setattr(self, k, v)
|
||||
elif isinstance(v, BROADCAST_TYPES):
|
||||
args_to_broadcast[k] = v
|
||||
|
@ -449,6 +449,20 @@ class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
|
||||
class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
|
||||
def test_K(self, batch_size=10):
|
||||
T = torch.randn(batch_size, 3)
|
||||
R = random_rotations(batch_size)
|
||||
K = torch.randn(batch_size, 4, 4)
|
||||
for cam_type in (
|
||||
FoVOrthographicCameras,
|
||||
FoVPerspectiveCameras,
|
||||
OrthographicCameras,
|
||||
PerspectiveCameras,
|
||||
):
|
||||
cam = cam_type(R=R, T=T, K=K)
|
||||
cam.get_projection_transform()
|
||||
# Just checking that we don't crash or anything
|
||||
|
||||
def test_view_transform_class_method(self):
|
||||
T = torch.tensor([0.0, 0.0, -1.0], requires_grad=True).view(1, -1)
|
||||
R = look_at_rotation(T)
|
||||
|
Loading…
x
Reference in New Issue
Block a user