diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py
index 9251f83b..0e8b44e9 100644
--- a/pytorch3d/renderer/__init__.py
+++ b/pytorch3d/renderer/__init__.py
@@ -9,6 +9,8 @@ from .blending import (
 from .cameras import (
     OpenGLOrthographicCameras,
     OpenGLPerspectiveCameras,
+    SfMOrthographicCameras,
+    SfMPerspectiveCameras,
     camera_position_from_spherical_angles,
     get_world_to_view_transform,
     look_at_rotation,
diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py
index 6444cbae..dffc4b6b 100644
--- a/pytorch3d/renderer/cameras.py
+++ b/pytorch3d/renderer/cameras.py
@@ -16,7 +16,202 @@ r = np.expand_dims(np.eye(3), axis=0)  # (1, 3, 3)
 t = np.expand_dims(np.zeros(3), axis=0)  # (1, 3)
 
-class OpenGLPerspectiveCameras(TensorProperties):
+class CamerasBase(TensorProperties):
+    """
+    `CamerasBase` implements a base class for all cameras.
+
+    It defines methods that are common to all camera models:
+        - `get_camera_center` that returns the optical center of the camera in
+            world coordinates
+        - `get_world_to_view_transform` which returns a 3D transform from
+            world coordinates to the camera coordinates
+        - `get_full_projection_transform` which composes the projection
+            transform with the world-to-view transform
+        - `transform_points` which takes a set of input points and
+            projects them onto a 2D camera plane.
+
+    For each new camera, one should implement the `get_projection_transform`
+    routine that returns the mapping from camera coordinates in world units
+    to the screen coordinates.
+
+    Another useful function that is specific to each camera model is
+    `unproject_points` which sends points from screen coordinates back to
+    camera or world coordinates depending on the `world_coordinates`
+    boolean argument of the function.
+    """
+
+    def get_projection_transform(self):
+        """
+        Calculate the projective transformation matrix.
+
+        Args:
+            **kwargs: parameters for the projection can be passed in as keyword
+                arguments to override the default values set in `__init__`.
+
+        Returns:
+            P: a `Transform3d` object which represents a batch of projection
+            matrices of shape (N, 4, 4)
+        """
+        raise NotImplementedError()
+
+    def unproject_points(self):
+        """
+        Transform input points in screen coordinates
+        to the world / camera coordinates.
+
+        Each of the input points `xy_depth` of shape (..., 3) is
+        a concatenation of the x, y location and its depth.
+
+        For instance, for an input 2D tensor of shape `(num_points, 3)`
+        `xy_depth` takes the following form:
+        `xy_depth[i] = [x[i], y[i], depth[i]]`,
+        for each point at an index `i`.
+
+        The following example demonstrates the relationship between
+        `transform_points` and `unproject_points`:
+
+        .. code-block:: python
+
+            cameras = # camera object derived from CamerasBase
+            xyz = # 3D points of shape (batch_size, num_points, 3)
+            # transform xyz to the camera coordinates
+            xyz_cam = cameras.get_world_to_view_transform().transform_points(xyz)
+            # extract the depth of each point as the 3rd coord of xyz_cam
+            depth = xyz_cam[:, :, 2:]
+            # project the points xyz to the camera
+            xy = cameras.transform_points(xyz)[:, :, :2]
+            # append depth to xy
+            xy_depth = torch.cat((xy, depth), dim=2)
+            # unproject to the world coordinates
+            xyz_unproj_world = cameras.unproject_points(xy_depth, world_coordinates=True)
+            print(torch.allclose(xyz, xyz_unproj_world))  # True
+            # unproject to the camera coordinates
+            xyz_unproj = cameras.unproject_points(xy_depth, world_coordinates=False)
+            print(torch.allclose(xyz_cam, xyz_unproj))  # True
+
+        Args:
+            xy_depth: torch tensor of shape (..., 3).
+            world_coordinates: If `True`, unprojects the points back to world
+                coordinates using the camera extrinsics `R` and `T`.
+                `False` ignores `R` and `T` and unprojects to
+                the camera coordinates.
+
+        Returns:
+            new_points: unprojected points with the same shape as `xy_depth`.
+        """
+        raise NotImplementedError()
+
+    def get_camera_center(self, **kwargs) -> torch.Tensor:
+        """
+        Return the 3D location of the camera optical center
+        in the world coordinates.
+
+        Args:
+            **kwargs: parameters for the camera extrinsics can be passed in
+                as keyword arguments to override the default values
+                set in __init__.
+
+        Setting T here will update the values set in init as this
+        value may be needed later on in the rendering pipeline e.g. for
+        lighting calculations.
+
+        Returns:
+            C: a batch of 3D locations of shape (N, 3) denoting
+            the locations of the center of each camera in the batch.
+        """
+        w2v_trans = self.get_world_to_view_transform(**kwargs)
+        P = w2v_trans.inverse().get_matrix()
+        # the camera center is the translation component (the first 3 elements
+        # of the last row) of the inverted world-to-view
+        # transform (4x4 RT matrix)
+        C = P[:, 3, :3]
+        return C
+
+    def get_world_to_view_transform(self, **kwargs) -> Transform3d:
+        """
+        Return the world-to-view transform.
+
+        Args:
+            **kwargs: parameters for the camera extrinsics can be passed in
+                as keyword arguments to override the default values
+                set in __init__.
+
+        Setting R and T here will update the values set in init as these
+        values may be needed later on in the rendering pipeline e.g. for
+        lighting calculations.
+
+        Returns:
+            T: a Transform3d object which represents a batch of transforms
+            of shape (N, 4, 4)
+        """
+        self.R = kwargs.get("R", self.R)  # pyre-ignore[16]
+        self.T = kwargs.get("T", self.T)  # pyre-ignore[16]
+        world_to_view_transform = get_world_to_view_transform(R=self.R, T=self.T)
+        return world_to_view_transform
+
+    def get_full_projection_transform(self, **kwargs) -> Transform3d:
+        """
+        Return the full world-to-screen transform composing the
+        world-to-view and view-to-screen transforms.
+
+        Args:
+            **kwargs: parameters for the projection transforms can be passed in
+                as keyword arguments to override the default values
+                set in __init__.
+
+        Setting R and T here will update the values set in init as these
+        values may be needed later on in the rendering pipeline e.g. for
+        lighting calculations.
+
+        Returns:
+            T: a Transform3d object which represents a batch of transforms
+            of shape (N, 4, 4)
+        """
+        self.R = kwargs.get("R", self.R)  # pyre-ignore[16]
+        self.T = kwargs.get("T", self.T)  # pyre-ignore[16]
+        world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
+        view_to_screen_transform = self.get_projection_transform(**kwargs)
+        return world_to_view_transform.compose(view_to_screen_transform)
+
+    def transform_points(
+        self, points, eps: Optional[float] = None, **kwargs
+    ) -> torch.Tensor:
+        """
+        Transform input points from world to screen space.
+
+        Args:
+            points: torch tensor of shape (..., 3).
+            eps: If eps!=None, the argument is used to clamp the
+                divisor in the homogeneous normalization of the points
+                transformed to the screen space. Please see
+                `transforms.Transform3D.transform_points` for details.
+
+                For `CamerasBase.transform_points`, setting `eps > 0`
+                stabilizes gradients since it leads to avoiding division
+                by excessively low numbers for points close to the
+                camera plane.
+
+        Returns:
+            new_points: transformed points with the same shape as the input.
+        """
+        world_to_screen_transform = self.get_full_projection_transform(**kwargs)
+        return world_to_screen_transform.transform_points(points, eps=eps)
+
+    def clone(self):
+        """
+        Returns a copy of `self`.
+        """
+        cam_type = type(self)
+        other = cam_type(device=self.device)
+        return super().clone(other)
+
+
+########################
+# Specific camera classes
+########################
+
+
+class OpenGLPerspectiveCameras(CamerasBase):
     """
     A class which stores a batch of parameters to generate a batch of
     projection matrices using the OpenGL convention for a perspective camera.
@@ -97,7 +292,7 @@ class OpenGLPerspectiveCameras(TensorProperties):
             [s1, 0, w1, 0],
             [0, s2, h1, 0],
             [0, 0, f1, f2],
-            [0, 0, -1, 0],
+            [0, 0, 1, 0],
         ]
         """
         znear = kwargs.get("znear", self.znear)  # pyre-ignore[16]
@@ -154,97 +349,52 @@ class OpenGLPerspectiveCameras(TensorProperties):
         transform._matrix = P.transpose(1, 2).contiguous()
         return transform
 
-    def clone(self):
-        other = OpenGLPerspectiveCameras(device=self.device)
-        return super().clone(other)
-
-    def get_camera_center(self, **kwargs):
-        """
-        Return the 3D location of the camera optical center
-        in the world coordinates.
+    def unproject_points(
+        self,
+        xy_depth: torch.Tensor,
+        world_coordinates: bool = True,
+        scaled_depth_input: bool = False,
+        **kwargs
+    ) -> torch.Tensor:
+        """
+        OpenGL cameras further allow for passing depth in world units
+        (`scaled_depth_input=False`) or in the [0, 1]-normalized units
+        (`scaled_depth_input=True`).
 
         Args:
-            **kwargs: parameters for the camera extrinsics can be passed in
-                as keyword arguments to override the default values
-                set in __init__.
-
-        Setting T here will update the values set in init as this
-        value may be needed later on in the rendering pipeline e.g. for
-        lighting calculations.
-
-        Returns:
-            C: a batch of 3D locations of shape (N, 3) denoting
-            the locations of the center of each camera in the batch.
+            scaled_depth_input: If `True`, assumes the input depth is in
+                the [0, 1]-normalized units. If `False`, the input depth is
+                in world units.
""" - w2v_trans = self.get_world_to_view_transform(**kwargs) - P = w2v_trans.inverse().get_matrix() - # the camera center is the translation component (the first 3 elements - # of the last row) of the inverted world-to-view - # transform (4x4 RT matrix) - C = P[:, 3, :3] - return C - def get_world_to_view_transform(self, **kwargs) -> Transform3d: - """ - Return the world-to-view transform. + # obtain the relevant transformation to screen + if world_coordinates: + to_screen_transform = self.get_full_projection_transform() + else: + to_screen_transform = self.get_projection_transform() - Args: - **kwargs: parameters for the camera extrinsics can be passed in - as keyword arguments to override the default values - set in __init__. + if scaled_depth_input: + # the input is scaled depth, so we don't have to do anything + xy_sdepth = xy_depth + else: + # parse out important values from the projection matrix + P_matrix = self.get_projection_transform(**kwargs.copy()).get_matrix() + # parse out f1, f2 from P_matrix + unsqueeze_shape = [1] * xy_depth.dim() + unsqueeze_shape[0] = P_matrix.shape[0] + f1 = P_matrix[:, 2, 2].reshape(unsqueeze_shape) + f2 = P_matrix[:, 3, 2].reshape(unsqueeze_shape) + # get the scaled depth + sdepth = (f1 * xy_depth[..., 2:3] + f2) / xy_depth[..., 2:3] + # concatenate xy + scaled depth + xy_sdepth = torch.cat((xy_depth[..., 0:2], sdepth), dim=-1) - Setting R and T here will update the values set in init as these - values may be needed later on in the rendering pipeline e.g. for - lighting calculations. - - Returns: - T: a Transform3d object which represents a batch of transforms - of shape (N, 3, 3) - """ - self.R = kwargs.get("R", self.R) # pyre-ignore[16] - self.T = kwargs.get("T", self.T) # pyre-ignore[16] - world_to_view_transform = get_world_to_view_transform(R=self.R, T=self.T) - return world_to_view_transform - - def get_full_projection_transform(self, **kwargs) -> Transform3d: - """ - Return the full world-to-screen transform composing the - world-to-view and view-to-screen transforms. - - Args: - **kwargs: parameters for the projection transforms can be passed in - as keyword arguments to override the default values - set in __init__. - - Setting R and T here will update the values set in init as these - values may be needed later on in the rendering pipeline e.g. for - lighting calculations. - - Returns: - T: a Transform3d object which represents a batch of transforms - of shape (N, 3, 3) - """ - self.R = kwargs.get("R", self.R) # pyre-ignore[16] - self.T = kwargs.get("T", self.T) # pyre-ignore[16] - world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T) - view_to_screen_transform = self.get_projection_transform(**kwargs) - return world_to_view_transform.compose(view_to_screen_transform) - - def transform_points(self, points, **kwargs) -> torch.Tensor: - """ - Transform input points from world to screen space. - - Args: - points: torch tensor of shape (..., 3). - - Returns - new_points: transformed points with the same shape as the input. 
-        """
-        world_to_screen_transform = self.get_full_projection_transform(**kwargs)
-        return world_to_screen_transform.transform_points(points)
+        # unproject with inverse of the projection
+        unprojection_transform = to_screen_transform.inverse()
+        return unprojection_transform.transform_points(xy_sdepth)
 
 
-class OpenGLOrthographicCameras(TensorProperties):
+class OpenGLOrthographicCameras(CamerasBase):
     """
     A class which stores a batch of parameters to generate a batch of
     transformation matrices using the OpenGL convention for orthographic camera.
@@ -360,98 +510,48 @@
         transform._matrix = P.transpose(1, 2).contiguous()
         return transform
 
-    def clone(self):
-        other = OpenGLOrthographicCameras(device=self.device)
-        return super().clone(other)
-
-    def get_camera_center(self, **kwargs):
-        """
-        Return the 3D location of the camera optical center
-        in the world coordinates.
+    def unproject_points(
+        self,
+        xy_depth: torch.Tensor,
+        world_coordinates: bool = True,
+        scaled_depth_input: bool = False,
+        **kwargs
+    ) -> torch.Tensor:
+        """
+        OpenGL cameras further allow for passing depth in world units
+        (`scaled_depth_input=False`) or in the [0, 1]-normalized units
+        (`scaled_depth_input=True`).
 
         Args:
-            **kwargs: parameters for the camera extrinsics can be passed in
-                as keyword arguments to override the default values
-                set in __init__.
-
-        Setting T here will update the values set in init as this
-        value may be needed later on in the rendering pipeline e.g. for
-        lighting calculations.
-
-
-        Returns:
-            C: a batch of 3D locations of shape (N, 3) denoting
-            the locations of the center of each camera in the batch.
+            scaled_depth_input: If `True`, assumes the input depth is in
+                the [0, 1]-normalized units. If `False`, the input depth is
+                in world units.
         """
-        w2v_trans = self.get_world_to_view_transform(**kwargs)
-        P = w2v_trans.inverse().get_matrix()
-        # The camera center is the translation component (the first 3 elements
-        # of the last row) of the inverted world-to-view
-        # transform (4x4 RT matrix).
-        C = P[:, 3, :3]
-        return C
 
-    def get_world_to_view_transform(self, **kwargs) -> Transform3d:
-        """
-        Return the world-to-view transform.
+        if world_coordinates:
+            to_screen_transform = self.get_full_projection_transform(**kwargs.copy())
+        else:
+            to_screen_transform = self.get_projection_transform(**kwargs.copy())
 
-        Args:
-            **kwargs: parameters for the camera extrinsics can be passed in
-                as keyword arguments to override the default values
-                set in __init__.
-
-        Setting R and T here will update the values set in init as these
-        values may be needed later on in the rendering pipeline e.g. for
-        lighting calculations.
- - Returns: - T: a Transform3d object which represents a batch of transforms - of shape (N, 3, 3) - """ - self.R = kwargs.get("R", self.R) # pyre-ignore[16] - self.T = kwargs.get("T", self.T) # pyre-ignore[16] - world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T) - view_to_screen_transform = self.get_projection_transform(**kwargs) - return world_to_view_transform.compose(view_to_screen_transform) - - def transform_points(self, points, **kwargs) -> torch.Tensor: - """ - Transform input points from world to screen space. - - Args: - points: torch tensor of shape (..., 3). - - Returns - new_points: transformed points with the same shape as the input. - """ - world_to_screen_transform = self.get_full_projection_transform(**kwargs) - return world_to_screen_transform.transform_points(points) + if scaled_depth_input: + # the input depth is already scaled + xy_sdepth = xy_depth + else: + # we have to obtain the scaled depth first + P = self.get_projection_transform(**kwargs).get_matrix() + unsqueeze_shape = [1] * P.dim() + unsqueeze_shape[0] = P.shape[0] + mid_z = P[:, 3, 2].reshape(unsqueeze_shape) + scale_z = P[:, 2, 2].reshape(unsqueeze_shape) + scaled_depth = scale_z * xy_depth[..., 2:3] + mid_z + # cat xy and scaled depth + xy_sdepth = torch.cat((xy_depth[..., :2], scaled_depth), dim=-1) + # finally invert the transform + unprojection_transform = to_screen_transform.inverse() + return unprojection_transform.transform_points(xy_sdepth) -class SfMPerspectiveCameras(TensorProperties): +class SfMPerspectiveCameras(CamerasBase): """ A class which stores a batch of parameters to generate a batch of transformation matrices using the multi-view geometry convention for @@ -495,14 +595,14 @@ class SfMPerspectiveCameras(TensorProperties): arguments to override the default values set in __init__. Returns: - P: a batch of projection matrices of shape (N, 4, 4) + P: A `Transform3d` object with a batch of `N` projection transforms. .. code-block:: python - fx = focal_length[:,0] - fy = focal_length[:,1] - px = principal_point[:,0] - py = principal_point[:,1] + fx = focal_length[:, 0] + fy = focal_length[:, 1] + px = principal_point[:, 0] + py = principal_point[:, 1] P = [ [fx, 0, px, 0], @@ -524,93 +624,22 @@ class SfMPerspectiveCameras(TensorProperties): transform._matrix = P.transpose(1, 2).contiguous() return transform - def clone(self): - other = SfMPerspectiveCameras(device=self.device) - return super().clone(other) + def unproject_points( + self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs + ) -> torch.Tensor: + if world_coordinates: + to_screen_transform = self.get_full_projection_transform(**kwargs) + else: + to_screen_transform = self.get_projection_transform(**kwargs) - def get_camera_center(self, **kwargs): - """ - Return the 3D location of the camera optical center - in the world coordinates. - - Args: - **kwargs: parameters for the camera extrinsics can be passed in - as keyword arguments to override the default values - set in __init__. - - Setting T here will update the values set in init as this - value may be needed later on in the rendering pipeline e.g. for - lighting calculations. - - Returns: - C: a batch of 3D locations of shape (N, 3) denoting - the locations of the center of each camera in the batch. 
-        """
-        w2v_trans = self.get_world_to_view_transform(**kwargs)
-        P = w2v_trans.inverse().get_matrix()
-        # the camera center is the translation component (the first 3 elements
-        # of the last row) of the inverted world-to-view
-        # transform (4x4 RT matrix)
-        C = P[:, 3, :3]
-        return C
-
-    def get_world_to_view_transform(self, **kwargs) -> Transform3d:
-        """
-        Return the world-to-view transform.
-
-        Args:
-            **kwargs: parameters for the camera extrinsics can be passed in
-                as keyword arguments to override the default values
-                set in __init__.
-
-        Setting R and T here will update the values set in init as these
-        values may be needed later on in the rendering pipeline e.g. for
-        lighting calculations.
-
-        Returns:
-            T: a Transform3d object which represents a batch of transforms
-            of shape (N, 3, 3)
-        """
-        self.R = kwargs.get("R", self.R)  # pyre-ignore[16]
-        self.T = kwargs.get("T", self.T)  # pyre-ignore[16]
-        world_to_view_transform = get_world_to_view_transform(R=self.R, T=self.T)
-        return world_to_view_transform
-
-    def get_full_projection_transform(self, **kwargs) -> Transform3d:
-        """
-        Return the full world-to-screen transform composing the
-        world-to-view and view-to-screen transforms.
-
-        Args:
-            **kwargs: parameters for the projection transforms can be passed in
-                as keyword arguments to override the default values
-                set in __init__.
-
-        Setting R and T here will update the values set in init as these
-        values may be needed later on in the rendering pipeline e.g. for
-        lighting calculations.
-        """
-        self.R = kwargs.get("R", self.R)  # pyre-ignore[16]
-        self.T = kwargs.get("T", self.T)  # pyre-ignore[16]
-        world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
-        view_to_screen_transform = self.get_projection_transform(**kwargs)
-        return world_to_view_transform.compose(view_to_screen_transform)
-
-    def transform_points(self, points, **kwargs) -> torch.Tensor:
-        """
-        Transform input points from world to screen space.
-
-        Args:
-            points: torch tensor of shape (..., 3).
-
-        Returns
-            new_points: transformed points with the same shape as the input.
-        """
-        world_to_screen_transform = self.get_full_projection_transform(**kwargs)
-        return world_to_screen_transform.transform_points(points)
+        unprojection_transform = to_screen_transform.inverse()
+        # the perspective projection outputs (x, y, 1/depth), so the depth
+        # channel is inverted before applying the inverse transform
+        xy_inv_depth = torch.cat(
+            (xy_depth[..., :2], 1.0 / xy_depth[..., 2:3]), dim=-1  # type: ignore
+        )
+        return unprojection_transform.transform_points(xy_inv_depth)
 
 
-class SfMOrthographicCameras(TensorProperties):
+class SfMOrthographicCameras(CamerasBase):
     """
     A class which stores a batch of parameters to generate a batch of
     transformation matrices using the multi-view geometry convention for
@@ -653,8 +682,8 @@
         **kwargs: parameters for the projection can be passed in as keyword
             arguments to override the default values set in __init__.
 
-        Return:
-            P: a batch of projection matrices of shape (N, 4, 4)
+        Returns:
+            P: A `Transform3d` object with a batch of `N` projection transforms.
 
         .. code-block:: python
@@ -683,90 +712,16 @@ class SfMOrthographicCameras(TensorProperties):
         transform._matrix = P.transpose(1, 2).contiguous()
         return transform
 
-    def clone(self):
-        other = SfMOrthographicCameras(device=self.device)
-        return super().clone(other)
+    def unproject_points(
+        self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs
+    ) -> torch.Tensor:
+        if world_coordinates:
+            to_screen_transform = self.get_full_projection_transform(**kwargs)
+        else:
+            to_screen_transform = self.get_projection_transform(**kwargs)
 
-    def get_camera_center(self, **kwargs):
-        """
-        Return the 3D location of the camera optical center
-        in the world coordinates.
-
-        Args:
-            **kwargs: parameters for the camera extrinsics can be passed in
-                as keyword arguments to override the default values
-                set in __init__.
-
-        Setting T here will update the values set in init as this
-        value may be needed later on in the rendering pipeline e.g. for
-        lighting calculations.
-
-        Returns:
-            C: a batch of 3D locations of shape (N, 3) denoting
-            the locations of the center of each camera in the batch.
-        """
-        w2v_trans = self.get_world_to_view_transform(**kwargs)
-        P = w2v_trans.inverse().get_matrix()
-        # the camera center is the translation component (the first 3 elements
-        # of the last row) of the inverted world-to-view
-        # transform (4x4 RT matrix)
-        C = P[:, 3, :3]
-        return C
-
-    def get_world_to_view_transform(self, **kwargs) -> Transform3d:
-        """
-        Return the world-to-view transform.
-
-        Args:
-            **kwargs: parameters for the camera extrinsics can be passed in
-                as keyword arguments to override the default values
-                set in __init__.
-
-        Setting R and T here will update the values set in init as these
-        values may be needed later on in the rendering pipeline e.g. for
-        lighting calculations.
-
-        Returns:
-            T: a Transform3d object which represents a batch of transforms
-            of shape (N, 3, 3)
-        """
-        self.R = kwargs.get("R", self.R)  # pyre-ignore[16]
-        self.T = kwargs.get("T", self.T)  # pyre-ignore[16]
-        world_to_view_transform = get_world_to_view_transform(R=self.R, T=self.T)
-        return world_to_view_transform
-
-    def get_full_projection_transform(self, **kwargs) -> Transform3d:
-        """
-        Return the full world-to-screen transform composing the
-        world-to-view and view-to-screen transforms.
-
-        Args:
-            **kwargs: parameters for the projection transforms can be passed in
-                as keyword arguments to override the default values
-                set in `__init__`.
-
-        Setting R and T here will update the values set in init as these
-        values may be needed later on in the rendering pipeline e.g. for
-        lighting calculations.
-        """
-        self.R = kwargs.get("R", self.R)  # pyre-ignore[16]
-        self.T = kwargs.get("T", self.T)  # pyre-ignore[16]
-        world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
-        view_to_screen_transform = self.get_projection_transform(**kwargs)
-        return world_to_view_transform.compose(view_to_screen_transform)
-
-    def transform_points(self, points, **kwargs) -> torch.Tensor:
-        """
-        Transform input points from world to screen space.
-
-        Args:
-            points: torch tensor of shape (..., 3).
-
-        Returns
-            new_points: transformed points with the same shape as the input.
-        """
-        world_to_screen_transform = self.get_full_projection_transform(**kwargs)
-        return world_to_screen_transform.transform_points(points)
+        unprojection_transform = to_screen_transform.inverse()
+        # the orthographic projection keeps depth as the third channel,
+        # so (x, y, depth) can be unprojected directly
+        return unprojection_transform.transform_points(xy_depth)
 
 
 # SfMCameras helper
diff --git a/pytorch3d/renderer/utils.py b/pytorch3d/renderer/utils.py
index 3afd6003..76cbf78c 100644
--- a/pytorch3d/renderer/utils.py
+++ b/pytorch3d/renderer/utils.py
@@ -1,6 +1,8 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
+import copy
+import inspect
 import warnings
 from typing import Any, Union
 
@@ -168,10 +170,13 @@ class TensorProperties(object):
         """
         for k in dir(self):
             v = getattr(self, k)
-            if k == "device":
-                setattr(self, k, v)
+            if inspect.ismethod(v) or k.startswith("__"):
+                continue
             if torch.is_tensor(v):
-                setattr(other, k, v.clone())
+                v_clone = v.clone()
+            else:
+                v_clone = copy.deepcopy(v)
+            setattr(other, k, v_clone)
         return other
 
     def gather_props(self, batch_idx):
diff --git a/tests/test_cameras.py b/tests/test_cameras.py
index da35b4e7..34bf6410 100644
--- a/tests/test_cameras.py
+++ b/tests/test_cameras.py
@@ -32,6 +32,7 @@ import numpy as np
 import torch
 from common_testing import TestCaseMixin
 from pytorch3d.renderer.cameras import (
+    CamerasBase,
     OpenGLOrthographicCameras,
     OpenGLPerspectiveCameras,
     SfMOrthographicCameras,
@@ -347,6 +348,8 @@ class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
         RT = get_world_to_view_transform(R=R, T=T)
         self.assertTrue(isinstance(RT, Transform3d))
 
+
+class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
     def test_view_transform_class_method(self):
         T = torch.tensor([0.0, 0.0, -1.0], requires_grad=True).view(1, -1)
         R = look_at_rotation(T)
@@ -377,6 +380,108 @@
         C_ = -torch.bmm(R, T[:, :, None])[:, :, 0]
         self.assertTrue(torch.allclose(C, C_, atol=1e-05))
 
+    @staticmethod
+    def init_random_cameras(cam_type: CamerasBase, batch_size: int):
+        T = torch.randn(batch_size, 3) * 0.03
+        T[:, 2] = 4
+        R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0)
+        cam_params = {"R": R, "T": T}
+        if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras):
+            cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
+            cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
+            if cam_type == OpenGLPerspectiveCameras:
+                cam_params["fov"] = torch.rand(batch_size) * 60 + 30
+                cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
+            else:
+                cam_params["top"] = torch.rand(batch_size) * 0.2 + 0.9
+                cam_params["bottom"] = -torch.rand(batch_size) * 0.2 - 0.9
+                cam_params["left"] = -torch.rand(batch_size) * 0.2 - 0.9
+                cam_params["right"] = torch.rand(batch_size) * 0.2 + 0.9
+        elif cam_type in (SfMOrthographicCameras, SfMPerspectiveCameras):
+            cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1
+            cam_params["principal_point"] = torch.randn((batch_size, 2))
+        else:
+            raise ValueError(str(cam_type))
+        return cam_type(**cam_params)
+
+    def test_unproject_points(self, batch_size=50, num_points=100):
+        """
+        Checks that unprojecting a randomly projected point cloud
+        recovers the original points.
+        """
+
+        for cam_type in (
+            SfMOrthographicCameras,
+            OpenGLPerspectiveCameras,
+            OpenGLOrthographicCameras,
+            SfMPerspectiveCameras,
+        ):
+            # init the cameras
+            cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size)
+            # xyz - the ground truth point cloud
+            xyz = torch.randn(batch_size, num_points, 3) * 0.3
+            # xyz in camera coordinates
+            xyz_cam = cameras.get_world_to_view_transform().transform_points(xyz)
+            # depth = z-component of xyz_cam
+            depth = xyz_cam[:, :, 2:]
+            # project xyz
+            xyz_proj = cameras.transform_points(xyz)
+            xy, cam_depth = xyz_proj.split(2, dim=2)
+            # input to the unprojection function
+            xy_depth = torch.cat((xy, depth), dim=2)
+
+            for to_world in (False, True):
+                if to_world:
+                    matching_xyz = xyz
+                else:
+                    matching_xyz = xyz_cam
+
+                # if we have OpenGL cameras
+                # test for scaled_depth_input=True/False
+                if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras):
+                    for scaled_depth_input in (True, False):
+                        if scaled_depth_input:
+                            xy_depth_ = xyz_proj
+                        else:
+                            xy_depth_ = xy_depth
+                        xyz_unproj = cameras.unproject_points(
+                            xy_depth_,
+                            world_coordinates=to_world,
+                            scaled_depth_input=scaled_depth_input,
+                        )
+                        self.assertTrue(
+                            torch.allclose(xyz_unproj, matching_xyz, atol=1e-4)
+                        )
+                else:
+                    xyz_unproj = cameras.unproject_points(
+                        xy_depth, world_coordinates=to_world
+                    )
+                    self.assertTrue(torch.allclose(xyz_unproj, matching_xyz, atol=1e-4))
+
+    def test_clone(self, batch_size: int = 10):
+        """
+        Checks the clone function of the cameras.
+        """
+        for cam_type in (
+            SfMOrthographicCameras,
+            OpenGLPerspectiveCameras,
+            OpenGLOrthographicCameras,
+            SfMPerspectiveCameras,
+        ):
+            cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size)
+            cameras = cameras.to(torch.device("cpu"))
+            cameras_clone = cameras.clone()
+
+            for var in cameras.__dict__.keys():
+                val = getattr(cameras, var)
+                val_clone = getattr(cameras_clone, var)
+                if torch.is_tensor(val):
+                    self.assertClose(val, val_clone)
+                    self.assertSeparate(val, val_clone)
+                else:
+                    self.assertTrue(val == val_clone)
+
 
 class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
     def test_perspective(self):
@@ -679,4 +784,4 @@ class TestSfMPerspectiveProjection(TestCaseMixin, unittest.TestCase):
         vertices = torch.randn([3, 4, 3], dtype=torch.float32)
         v1 = P.transform_points(vertices)
         v2 = sfm_perspective_project_naive(vertices, fx=2.0, fy=2.0, p0x=2.5, p0y=3.5)
-        self.assertClose(v1, v2)
+        self.assertClose(v1, v2, atol=1e-6)