diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py
index 117a98fe..b4b8c6af 100644
--- a/pytorch3d/renderer/cameras.py
+++ b/pytorch3d/renderer/cameras.py
@@ -70,7 +70,7 @@ class CamerasBase(TensorProperties):
             arguments to override the default values set in `__init__`.

         Return:
-            P: a `Transform3d` object which represents a batch of projection
+            a `Transform3d` object which represents a batch of projection
             matrices of shape (N, 3, 3)
         """
         raise NotImplementedError()
@@ -333,10 +333,10 @@ class FoVPerspectiveCameras(CamerasBase):
         degrees: bool = True,
         R=_R,
         T=_T,
+        K=None,
         device="cpu",
     ):
         """
-        __init__(self, znear, zfar, aspect_ratio, fov, degrees, R, T, device) -> None  # noqa

         Args:
             znear: near clipping plane of the view frustrum.
@@ -346,6 +346,8 @@ class FoVPerspectiveCameras(CamerasBase):
             degrees: bool, set to True if fov is specified in degrees.
             R: Rotation matrix of shape (N, 3, 3)
             T: Translation matrix of shape (N, 3)
+            K: (optional) A calibration matrix of shape (N, 4, 4)
+                If provided, znear, zfar, fov, aspect_ratio and degrees are ignored.
             device: torch.device or string
         """
         # The initializer formats all inputs to torch tensors and broadcasts
@@ -358,55 +360,29 @@ class FoVPerspectiveCameras(CamerasBase):
             fov=fov,
             R=R,
             T=T,
+            K=K,
         )

         # No need to convert to tensor or broadcast.
         self.degrees = degrees

-    def get_projection_transform(self, **kwargs) -> Transform3d:
+    def compute_projection_matrix(
+        self, znear, zfar, fov, aspect_ratio, degrees
+    ) -> torch.Tensor:
         """
-        Calculate the perpective projection matrix with a symmetric
-        viewing frustrum. Use column major order.
-        The viewing frustrum will be projected into ndc, s.t.
-        (max_x, max_y) -> (+1, +1)
-        (min_x, min_y) -> (-1, -1)
+        Compute the calibration matrix K of shape (N, 4, 4)

         Args:
-            **kwargs: parameters for the projection can be passed in as keyword
-                arguments to override the default values set in `__init__`.
+            znear: near clipping plane of the view frustrum.
+            zfar: far clipping plane of the view frustrum.
+            fov: field of view angle of the camera.
+            aspect_ratio: ratio of screen_width/screen_height.
+            degrees: bool, set to True if fov is specified in degrees.

-        Return:
-            P: a Transform3d object which represents a batch of projection
-            matrices of shape (N, 4, 4)
-
-        .. code-block:: python
-
-            h1 = (max_y + min_y)/(max_y - min_y)
-            w1 = (max_x + min_x)/(max_x - min_x)
-            tanhalffov = tan((fov/2))
-            s1 = 1/tanhalffov
-            s2 = 1/(tanhalffov * (aspect_ratio))
-
-            # To map z to the range [0, 1] use:
-            f1 = far / (far - near)
-            f2 = -(far * near) / (far - near)
-
-            # Projection matrix
-            P = [
-                [s1, 0, w1, 0],
-                [0, s2, h1, 0],
-                [0, 0, f1, f2],
-                [0, 0, 1, 0],
-            ]
+        Returns:
+            torch.FloatTensor of the calibration matrix with shape (N, 4, 4)
         """
-        znear = kwargs.get("znear", self.znear)  # pyre-ignore[16]
-        zfar = kwargs.get("zfar", self.zfar)  # pyre-ignore[16]
-        fov = kwargs.get("fov", self.fov)  # pyre-ignore[16]
-        # pyre-ignore[16]
-        aspect_ratio = kwargs.get("aspect_ratio", self.aspect_ratio)
-        degrees = kwargs.get("degrees", self.degrees)
-
-        P = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32)
+        K = torch.zeros((self._N, 4, 4), device=self.device, dtype=torch.float32)
         ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
         if degrees:
             fov = (np.pi / 180) * fov
@@ -427,21 +403,73 @@ class FoVPerspectiveCameras(CamerasBase):
         # so the so the z sign is 1.0.
         z_sign = 1.0

-        P[:, 0, 0] = 2.0 * znear / (max_x - min_x)
-        P[:, 1, 1] = 2.0 * znear / (max_y - min_y)
-        P[:, 0, 2] = (max_x + min_x) / (max_x - min_x)
-        P[:, 1, 2] = (max_y + min_y) / (max_y - min_y)
-        P[:, 3, 2] = z_sign * ones
+        K[:, 0, 0] = 2.0 * znear / (max_x - min_x)
+        K[:, 1, 1] = 2.0 * znear / (max_y - min_y)
+        K[:, 0, 2] = (max_x + min_x) / (max_x - min_x)
+        K[:, 1, 2] = (max_y + min_y) / (max_y - min_y)
+        K[:, 3, 2] = z_sign * ones

         # NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point
         # is at the near clipping plane and z = 1 when the point is at the far
         # clipping plane.
-        P[:, 2, 2] = z_sign * zfar / (zfar - znear)
-        P[:, 2, 3] = -(zfar * znear) / (zfar - znear)
+        K[:, 2, 2] = z_sign * zfar / (zfar - znear)
+        K[:, 2, 3] = -(zfar * znear) / (zfar - znear)
+
+        return K
+
+    def get_projection_transform(self, **kwargs) -> Transform3d:
+        """
+        Calculate the perspective projection matrix with a symmetric
+        viewing frustrum. Use column major order.
+        The viewing frustrum will be projected into ndc, s.t.
+        (max_x, max_y) -> (+1, +1)
+        (min_x, min_y) -> (-1, -1)
+
+        Args:
+            **kwargs: parameters for the projection can be passed in as keyword
+                arguments to override the default values set in `__init__`.
+
+        Return:
+            a Transform3d object which represents a batch of projection
+            matrices of shape (N, 4, 4)
+
+        .. code-block:: python
+
+            h1 = (max_y + min_y)/(max_y - min_y)
+            w1 = (max_x + min_x)/(max_x - min_x)
+            tanhalffov = tan((fov/2))
+            s1 = 1/tanhalffov
+            s2 = 1/(tanhalffov * (aspect_ratio))
+
+            # To map z to the range [0, 1] use:
+            f1 = far / (far - near)
+            f2 = -(far * near) / (far - near)
+
+            # Projection matrix
+            K = [
+                [s1, 0, w1, 0],
+                [0, s2, h1, 0],
+                [0, 0, f1, f2],
+                [0, 0, 1, 0],
+            ]
+        """
+        K = kwargs.get("K", self.K)  # pyre-ignore[16]
+        if K is not None:
+            if K.shape != (self._N, 4, 4):
+                msg = "Expected K to have shape of (%r, 4, 4)"
+                raise ValueError(msg % (self._N))
+        else:
+            K = self.compute_projection_matrix(
+                kwargs.get("znear", self.znear),  # pyre-ignore[16]
+                kwargs.get("zfar", self.zfar),  # pyre-ignore[16]
+                kwargs.get("fov", self.fov),  # pyre-ignore[16]
+                kwargs.get("aspect_ratio", self.aspect_ratio),  # pyre-ignore[16]
+                kwargs.get("degrees", self.degrees),
+            )

         # Transpose the projection matrix as PyTorch3D transforms use row vectors.
         transform = Transform3d(device=self.device)
-        transform._matrix = P.transpose(1, 2).contiguous()
+        transform._matrix = K.transpose(1, 2).contiguous()
         return transform

     def unproject_points(
@@ -473,12 +501,12 @@ class FoVPerspectiveCameras(CamerasBase):
             xy_sdepth = xy_depth
         else:
             # parse out important values from the projection matrix
-            P_matrix = self.get_projection_transform(**kwargs.copy()).get_matrix()
-            # parse out f1, f2 from P_matrix
+            K_matrix = self.get_projection_transform(**kwargs.copy()).get_matrix()
+            # parse out f1, f2 from K_matrix
             unsqueeze_shape = [1] * xy_depth.dim()
-            unsqueeze_shape[0] = P_matrix.shape[0]
-            f1 = P_matrix[:, 2, 2].reshape(unsqueeze_shape)
-            f2 = P_matrix[:, 3, 2].reshape(unsqueeze_shape)
+            unsqueeze_shape[0] = K_matrix.shape[0]
+            f1 = K_matrix[:, 2, 2].reshape(unsqueeze_shape)
+            f2 = K_matrix[:, 3, 2].reshape(unsqueeze_shape)
             # get the scaled depth
             sdepth = (f1 * xy_depth[..., 2:3] + f2) / xy_depth[..., 2:3]
             # concatenate xy + scaled depth
@@ -545,10 +573,10 @@ class FoVOrthographicCameras(CamerasBase):
         scale_xyz=((1.0, 1.0, 1.0),),  # (1, 3)
         R=_R,
         T=_T,
+        K=None,
         device="cpu",
     ):
         """
-        __init__(self, znear, zfar, max_y, min_y, max_x, min_x, scale_xyz, R, T, device) -> None  # noqa

         Args:
             znear: near clipping plane of the view frustrum.
@@ -560,6 +588,8 @@ class FoVOrthographicCameras(CamerasBase):
             scale_xyz: scale factors for each axis of shape (N, 3).
             R: Rotation matrix of shape (N, 3, 3).
             T: Translation of shape (N, 3).
+            K: (optional) A calibration matrix of shape (N, 4, 4)
+                If provided, znear, zfar, max_y, min_y, max_x, min_x and scale_xyz are ignored.
             device: torch.device or string.

         Only need to set min_x, max_x, min_y, max_y for viewing frustrums
@@ -578,8 +608,44 @@ class FoVOrthographicCameras(CamerasBase):
             scale_xyz=scale_xyz,
             R=R,
             T=T,
+            K=K,
         )

+    def compute_projection_matrix(
+        self, znear, zfar, max_x, min_x, max_y, min_y, scale_xyz
+    ) -> torch.Tensor:
+        """
+        Compute the calibration matrix K of shape (N, 4, 4)
+
+        Args:
+            znear: near clipping plane of the view frustrum.
+            zfar: far clipping plane of the view frustrum.
+            max_x: maximum x coordinate of the frustrum.
+            min_x: minimum x coordinate of the frustrum.
+            max_y: maximum y coordinate of the frustrum.
+            min_y: minimum y coordinate of the frustrum.
+            scale_xyz: scale factors for each axis of shape (N, 3).
+        """
+        K = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)
+        ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
+        # NOTE: OpenGL flips handedness of coordinate system between camera
+        # space and NDC space so z sign is -ve. In PyTorch3D we maintain a
+        # right handed coordinate system throughout.
+        z_sign = +1.0
+
+        K[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
+        K[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
+        K[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
+        K[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
+        K[:, 3, 3] = ones
+
+        # NOTE: This maps the z coordinate to the range [0, 1] and replaces
+        # the OpenGL z normalization to [-1, 1]
+        K[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
+        K[:, 2, 3] = -znear / (zfar - znear)
+
+        return K
+
     def get_projection_transform(self, **kwargs) -> Transform3d:
         """
         Calculate the orthographic projection matrix.
@@ -589,7 +655,7 @@ class FoVOrthographicCameras(CamerasBase):
            **kwargs: parameters for the projection can be passed in to
                      override the default values set in __init__.
         Return:
-            P: a Transform3d object which represents a batch of projection
+            a Transform3d object which represents a batch of projection
             matrices of shape (N, 4, 4)

         .. code-block:: python
@@ -601,41 +667,31 @@ class FoVOrthographicCameras(CamerasBase):
            mix_y = (max_y + min_y) / (max_y - min_y)
            mid_z = (far + near) / (far−near)

-            P = [
+            K = [
                 [scale_x, 0, 0, -mid_x],
                 [0, scale_y, 0, -mix_y],
                 [0, 0, -scale_z, -mid_z],
                 [0, 0, 0, 1],
             ]
         """
-        znear = kwargs.get("znear", self.znear)  # pyre-ignore[16]
-        zfar = kwargs.get("zfar", self.zfar)  # pyre-ignore[16]
-        max_x = kwargs.get("max_x", self.max_x)  # pyre-ignore[16]
-        min_x = kwargs.get("min_x", self.min_x)  # pyre-ignore[16]
-        max_y = kwargs.get("max_y", self.max_y)  # pyre-ignore[16]
-        min_y = kwargs.get("min_y", self.min_y)  # pyre-ignore[16]
-        scale_xyz = kwargs.get("scale_xyz", self.scale_xyz)  # pyre-ignore[16]
-
-        P = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device)
-        ones = torch.ones((self._N), dtype=torch.float32, device=self.device)
-        # NOTE: OpenGL flips handedness of coordinate system between camera
-        # space and NDC space so z sign is -ve. In PyTorch3D we maintain a
-        # right handed coordinate system throughout.
-        z_sign = +1.0
-
-        P[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0]
-        P[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1]
-        P[:, 0, 3] = -(max_x + min_x) / (max_x - min_x)
-        P[:, 1, 3] = -(max_y + min_y) / (max_y - min_y)
-        P[:, 3, 3] = ones
-
-        # NOTE: This maps the z coordinate to the range [0, 1] and replaces the
-        # the OpenGL z normalization to [-1, 1]
-        P[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2]
-        P[:, 2, 3] = -znear / (zfar - znear)
+        K = kwargs.get("K", self.K)  # pyre-ignore[16]
+        if K is not None:
+            if K.shape != (self._N, 4, 4):
+                msg = "Expected K to have shape of (%r, 4, 4)"
+                raise ValueError(msg % (self._N))
+        else:
+            K = self.compute_projection_matrix(
+                kwargs.get("znear", self.znear),  # pyre-ignore[16]
+                kwargs.get("zfar", self.zfar),  # pyre-ignore[16]
+                kwargs.get("max_x", self.max_x),  # pyre-ignore[16]
+                kwargs.get("min_x", self.min_x),  # pyre-ignore[16]
+                kwargs.get("max_y", self.max_y),  # pyre-ignore[16]
+                kwargs.get("min_y", self.min_y),  # pyre-ignore[16]
+                kwargs.get("scale_xyz", self.scale_xyz),  # pyre-ignore[16]
+            )

         transform = Transform3d(device=self.device)
-        transform._matrix = P.transpose(1, 2).contiguous()
+        transform._matrix = K.transpose(1, 2).contiguous()
         return transform

     def unproject_points(
@@ -666,11 +722,11 @@ class FoVOrthographicCameras(CamerasBase):
             xy_sdepth = xy_depth
         else:
             # we have to obtain the scaled depth first
-            P = self.get_projection_transform(**kwargs).get_matrix()
-            unsqueeze_shape = [1] * P.dim()
-            unsqueeze_shape[0] = P.shape[0]
-            mid_z = P[:, 3, 2].reshape(unsqueeze_shape)
-            scale_z = P[:, 2, 2].reshape(unsqueeze_shape)
+            K = self.get_projection_transform(**kwargs).get_matrix()
+            unsqueeze_shape = [1] * K.dim()
+            unsqueeze_shape[0] = K.shape[0]
+            mid_z = K[:, 3, 2].reshape(unsqueeze_shape)
+            scale_z = K[:, 2, 2].reshape(unsqueeze_shape)
             scaled_depth = scale_z * xy_depth[..., 2:3] + mid_z
             # cat xy and scaled depth
             xy_sdepth = torch.cat((xy_depth[..., :2], scaled_depth), dim=-1)
@@ -752,11 +808,11 @@ class PerspectiveCameras(CamerasBase):
         principal_point=((0.0, 0.0),),
         R=_R,
         T=_T,
+        K=None,
         device="cpu",
         image_size=((-1, -1),),
     ):
         """
-        __init__(self, focal_length, principal_point, R, T, device, image_size) -> None

         Args:
             focal_length: Focal length of the camera in world units.
@@ -767,6 +823,9 @@ class PerspectiveCameras(CamerasBase):
                 A tensor of shape (N, 2).
             R: Rotation matrix of shape (N, 3, 3)
             T: Translation matrix of shape (N, 3)
+            K: (optional) A calibration matrix of shape (N, 4, 4)
+                If provided, focal_length, principal_point and image_size are ignored.
+
             device: torch.device or string
             image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
                 is provided, the camera parameters are assumed to be in screen
@@ -782,6 +841,7 @@
             principal_point=principal_point,
             R=R,
             T=T,
+            K=K,
             image_size=image_size,
         )
@@ -795,7 +855,7 @@
             arguments to override the default values set in __init__.

         Returns:
-            P: A `Transform3d` object with a batch of `N` projection transforms.
+            A `Transform3d` object with a batch of `N` projection transforms.

         .. code-block:: python
@@ -804,35 +864,35 @@
            fx = focal_length[:, 0]
            fy = focal_length[:, 1]
            px = principal_point[:, 0]
            py = principal_point[:, 1]

-            P = [
+            K = [
                 [fx, 0, px, 0],
                 [0, fy, py, 0],
                 [0, 0, 0, 1],
                 [0, 0, 1, 0],
             ]
         """
-        # pyre-ignore[16]
-        principal_point = kwargs.get("principal_point", self.principal_point)
-        # pyre-ignore[16]
-        focal_length = kwargs.get("focal_length", self.focal_length)
-        # pyre-ignore[16]
-        image_size = kwargs.get("image_size", self.image_size)
+        K = kwargs.get("K", self.K)  # pyre-ignore[16]
+        if K is not None:
+            if K.shape != (self._N, 4, 4):
+                msg = "Expected K to have shape of (%r, 4, 4)"
+                raise ValueError(msg % (self._N))
+        else:
+            # pyre-ignore[16]
+            image_size = kwargs.get("image_size", self.image_size)
+            # if imwidth > 0, parameters are in screen space
+            image_size = image_size if image_size[0][0] > 0 else None

-        # if imwidth > 0, parameters are in screen space
-        in_screen = image_size[0][0] > 0
-        image_size = image_size if in_screen else None
-
-        P = _get_sfm_calibration_matrix(
-            self._N,
-            self.device,
-            focal_length,
-            principal_point,
-            orthographic=False,
-            image_size=image_size,
-        )
+            K = _get_sfm_calibration_matrix(
+                self._N,
+                self.device,
+                kwargs.get("focal_length", self.focal_length),  # pyre-ignore[16]
+                kwargs.get("principal_point", self.principal_point),  # pyre-ignore[16]
+                orthographic=False,
+                image_size=image_size,
+            )

         transform = Transform3d(device=self.device)
-        transform._matrix = P.transpose(1, 2).contiguous()
+        transform._matrix = K.transpose(1, 2).contiguous()
         return transform

     def unproject_points(
@@ -911,11 +971,11 @@ class OrthographicCameras(CamerasBase):
         principal_point=((0.0, 0.0),),
         R=_R,
         T=_T,
+        K=None,
         device="cpu",
         image_size=((-1, -1),),
     ):
         """
-        __init__(self, focal_length, principal_point, R, T, device, image_size) -> None

         Args:
             focal_length: Focal length of the camera in world units.
@@ -926,6 +986,8 @@ class OrthographicCameras(CamerasBase):
                 A tensor of shape (N, 2).
             R: Rotation matrix of shape (N, 3, 3)
             T: Translation matrix of shape (N, 3)
+            K: (optional) A calibration matrix of shape (N, 4, 4)
+                If provided, focal_length, principal_point and image_size are ignored.
             device: torch.device or string
             image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0
                 is provided, the camera parameters are assumed to be in screen
@@ -941,6 +1003,7 @@
             principal_point=principal_point,
             R=R,
             T=T,
+            K=K,
             image_size=image_size,
         )
@@ -954,7 +1017,7 @@
             arguments to override the default values set in __init__.

         Returns:
-            P: A `Transform3d` object with a batch of `N` projection transforms.
+            A `Transform3d` object with a batch of `N` projection transforms.

         .. code-block:: python
@@ -963,35 +1026,35 @@
            fx = focal_length[:,0]
            fy = focal_length[:,1]
            px = principal_point[:,0]
            py = principal_point[:,1]

-            P = [
+            K = [
                 [fx, 0, 0, px],
                 [0, fy, 0, py],
                 [0, 0, 1, 0],
                 [0, 0, 0, 1],
             ]
         """
-        # pyre-ignore[16]
-        principal_point = kwargs.get("principal_point", self.principal_point)
-        # pyre-ignore[16]
-        focal_length = kwargs.get("focal_length", self.focal_length)
-        # pyre-ignore[16]
-        image_size = kwargs.get("image_size", self.image_size)
+        K = kwargs.get("K", self.K)  # pyre-ignore[16]
+        if K is not None:
+            if K.shape != (self._N, 4, 4):
+                msg = "Expected K to have shape of (%r, 4, 4)"
+                raise ValueError(msg % (self._N))
+        else:
+            # pyre-ignore[16]
+            image_size = kwargs.get("image_size", self.image_size)
+            # if imwidth > 0, parameters are in screen space
+            image_size = image_size if image_size[0][0] > 0 else None

-        # if imwidth > 0, parameters are in screen space
-        in_screen = image_size[0][0] > 0
-        image_size = image_size if in_screen else None
-
-        P = _get_sfm_calibration_matrix(
-            self._N,
-            self.device,
-            focal_length,
-            principal_point,
-            orthographic=True,
-            image_size=image_size,
-        )
+            K = _get_sfm_calibration_matrix(
+                self._N,
+                self.device,
+                kwargs.get("focal_length", self.focal_length),  # pyre-ignore[16]
+                kwargs.get("principal_point", self.principal_point),  # pyre-ignore[16]
+                orthographic=True,
+                image_size=image_size,
+            )

         transform = Transform3d(device=self.device)
-        transform._matrix = P.transpose(1, 2).contiguous()
+        transform._matrix = K.transpose(1, 2).contiguous()
         return transform

     def unproject_points(
diff --git a/pytorch3d/renderer/utils.py b/pytorch3d/renderer/utils.py
index 2122845c..b940bd21 100644
--- a/pytorch3d/renderer/utils.py
+++ b/pytorch3d/renderer/utils.py
@@ -106,7 +106,7 @@ class TensorProperties(nn.Module):
         # set as attributes anything else e.g. strings, bools
         args_to_broadcast = {}
         for k, v in kwargs.items():
-            if isinstance(v, (str, bool)):
+            if v is None or isinstance(v, (str, bool)):
                 setattr(self, k, v)
             elif isinstance(v, BROADCAST_TYPES):
                 args_to_broadcast[k] = v
diff --git a/tests/test_cameras.py b/tests/test_cameras.py
index 298301d2..c1ff86bc 100644
--- a/tests/test_cameras.py
+++ b/tests/test_cameras.py
@@ -449,6 +449,20 @@ class TestCameraHelpers(TestCaseMixin, unittest.TestCase):


 class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
+    def test_K(self, batch_size=10):
+        T = torch.randn(batch_size, 3)
+        R = random_rotations(batch_size)
+        K = torch.randn(batch_size, 4, 4)
+        for cam_type in (
+            FoVOrthographicCameras,
+            FoVPerspectiveCameras,
+            OrthographicCameras,
+            PerspectiveCameras,
+        ):
+            cam = cam_type(R=R, T=T, K=K)
+            cam.get_projection_transform()
+            # Just checking that we don't crash or anything
+
     def test_view_transform_class_method(self):
         T = torch.tensor([0.0, 0.0, -1.0], requires_grad=True).view(1, -1)
         R = look_at_rotation(T)
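Usage sketch (editorial illustration, not part of the patch): the snippet below shows how the new `K` argument interacts with `get_projection_transform` after this change. It assumes `FoVPerspectiveCameras` is importable from `pytorch3d.renderer`, as in the tests above; the znear/zfar/fov values are arbitrary examples.

import torch
from pytorch3d.renderer import FoVPerspectiveCameras

# Route 1: derive the calibration matrix from the FoV parameters.
cam = FoVPerspectiveCameras(znear=0.1, zfar=100.0, fov=60.0)
K = cam.compute_projection_matrix(
    cam.znear, cam.zfar, cam.fov, cam.aspect_ratio, cam.degrees
)  # (1, 4, 4)

# Route 2: pass K explicitly; znear, zfar, fov, aspect_ratio and degrees
# are then ignored by get_projection_transform.
cam_k = FoVPerspectiveCameras(K=K)
P = cam_k.get_projection_transform().get_matrix()

# PyTorch3D transforms act on row vectors, so the stored matrix is the
# transpose of K.
assert torch.allclose(P, K.transpose(1, 2))

The same precedence (an explicit `K` wins over the individual calibration parameters) applies to `FoVOrthographicCameras`, `PerspectiveCameras` and `OrthographicCameras`, which is what `test_K` above exercises across all four camera types.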