diff --git a/docs/notes/cameras.md b/docs/notes/cameras.md new file mode 100644 index 00000000..c96e3451 --- /dev/null +++ b/docs/notes/cameras.md @@ -0,0 +1,63 @@ +# Cameras + +## Camera Coordinate Systems + +When working with 3D data, there are 4 coordinate systems users need to know: +* **World coordinate system** +This is the system in which the object/scene lives - the world. +* **Camera view coordinate system** +This is the system that has its origin on the image plane and the `Z`-axis perpendicular to the image plane. In PyTorch3D, we assume that `+X` points left, `+Y` points up and `+Z` points out from the image plane. The transformation from world to view is performed by applying a rotation (`R`) and a translation (`T`). +* **NDC coordinate system** +This is the normalized coordinate system that confines the rendered part of the object/scene to a volume; it is also known as the view volume. Under the PyTorch3D convention, `(+1, +1, znear)` is the top left near corner, and `(-1, -1, zfar)` is the bottom right far corner of the volume. The transformation from view to NDC is performed by applying the camera projection matrix (`P`). +* **Screen coordinate system** +This is another representation of the view volume with the `XY` coordinates defined in pixel space instead of a normalized space. + +An illustration of the 4 coordinate systems is shown below: ![cameras](https://user-images.githubusercontent.com/4369065/90317960-d9b8db80-dee1-11ea-8088-39c414b1e2fa.png) + +## Defining Cameras in PyTorch3D + +Cameras in PyTorch3D transform an object/scene from world to NDC by first transforming the object/scene to view (via the transforms `R` and `T`) and then projecting the 3D object/scene to NDC (via the projection matrix `P`, also known as the camera matrix). Thus, the camera parameters in `P` are assumed to be in NDC space. If the user has camera parameters in screen space, which is a common use case, the parameters should be transformed to NDC (see below for an example). + +We describe the camera types in PyTorch3D and the conventions for the camera parameters provided at construction time. + +### Camera Types + +All cameras inherit from `CamerasBase`. PyTorch3D provides four different camera types, and `CamerasBase` defines the methods that are common to all camera models: +* `get_camera_center`, which returns the optical center of the camera in world coordinates +* `get_world_to_view_transform`, which returns a 3D transform from world coordinates to the camera view coordinates (R, T) +* `get_full_projection_transform`, which composes the projection transform (P) with the world-to-view transform (R, T) +* `transform_points`, which takes a set of input points in world coordinates and projects them to NDC coordinates ranging from [-1, -1, znear] to [+1, +1, zfar] +* `transform_points_screen`, which takes a set of input points in world coordinates and projects them to screen coordinates ranging from [0, 0, znear] to [W-1, H-1, zfar] + +Users can easily customize their own cameras. For each new camera, users should implement the `get_projection_transform` routine that returns the mapping `P` from camera view coordinates to NDC coordinates. + +#### FoVPerspectiveCameras, FoVOrthographicCameras +These two cameras follow the OpenGL convention for perspective and orthographic cameras respectively. The user provides the near (`znear`) and far (`zfar`) clipping planes, which confine the view volume along the `Z` axis. The view volume in the `XY` plane is defined by the field-of-view angle (`fov`) in the case of `FoVPerspectiveCameras` and by `min_x, min_y, max_x, max_y` in the case of `FoVOrthographicCameras`.
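+
+Below is a minimal construction sketch for these two camera types (the clipping-plane, field-of-view and bound values here are illustrative, not required defaults):
+
+```
+from pytorch3d.renderer import FoVPerspectiveCameras, FoVOrthographicCameras
+
+# Perspective: the view volume is a frustum defined by znear/zfar and fov.
+cameras = FoVPerspectiveCameras(znear=0.1, zfar=100.0, fov=60.0)
+
+# Orthographic: the view volume is a box defined by znear/zfar and XY bounds.
+cameras = FoVOrthographicCameras(
+    znear=0.1, zfar=100.0, min_x=-1.0, max_x=1.0, min_y=-1.0, max_y=1.0
+)
+```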
+ +#### PerspectiveCameras, OrthographicCameras +These two cameras follow the Multi-View Geometry convention for cameras. The user provides the focal length (`fx`, `fy`) and the principal point (`px`, `py`). For example, `camera = PerspectiveCameras(focal_length=((fx, fy),), principal_point=((px, py),))` + +As mentioned above, the focal length and principal point are used to convert a point `(X, Y, Z)` from view coordinates to NDC coordinates, as follows: + +``` +# for perspective +x_ndc = fx * X / Z + px +y_ndc = fy * Y / Z + py +z_ndc = 1 / Z + +# for orthographic +x_ndc = fx * X + px +y_ndc = fy * Y + py +z_ndc = Z +``` + +Commonly, users have access to the focal length (`fx_screen`, `fy_screen`) and the principal point (`px_screen`, `py_screen`) in screen space. In that case, to construct the camera the user additionally needs to provide `image_size = ((image_width, image_height),)`. More precisely, `camera = PerspectiveCameras(focal_length=((fx_screen, fy_screen),), principal_point=((px_screen, py_screen),), image_size=((image_width, image_height),))`. Internally, the camera parameters are converted from screen to NDC as follows: + +``` +fx = fx_screen * 2.0 / image_width +fy = fy_screen * 2.0 / image_height + +px = - (px_screen - image_width / 2.0) * 2.0 / image_width +py = - (py_screen - image_height / 2.0) * 2.0 / image_height +``` diff --git a/docs/notes/renderer_getting_started.md b/docs/notes/renderer_getting_started.md index f65351df..3d35f1e2 100644 --- a/docs/notes/renderer_getting_started.md +++ b/docs/notes/renderer_getting_started.md @@ -39,7 +39,7 @@ Rendering requires transformations between several different coordinate frames: -For example, given a teapot mesh, the world coordinate frame, camera coordiante frame and image are show in the figure below. Note that the world and camera coordinate frames have the +z direction pointing in to the page. +For example, given a teapot mesh, the world coordinate frame, camera coordinate frame and image are shown in the figure below. Note that the world and camera coordinate frames have the +z direction pointing in to the page. @@ -47,8 +47,8 @@ For example, given a teapot mesh, the world coordinate frame, camera coordiante **NOTE: PyTorch3D vs OpenGL** -While we tried to emulate several aspects of OpenGL, there are differences in the coordinate frame conventions. -- The default world coordinate frame in PyTorch3D has +Z pointing in to the screen whereas in OpenGL, +Z is pointing out of the screen. Both are right handed. +While we tried to emulate several aspects of OpenGL, there are differences in the coordinate frame conventions. +- The default world coordinate frame in PyTorch3D has +Z pointing in to the screen whereas in OpenGL, +Z is pointing out of the screen. Both are right handed. - The NDC coordinate system in PyTorch3D is **right-handed** compared with a **left-handed** NDC coordinate system in OpenGL (the projection matrix switches the handedness). @@ -61,14 +61,14 @@ A renderer in PyTorch3D is composed of a **rasterizer** and a **shader**. Create ``` # Imports from pytorch3d.renderer import ( - OpenGLPerspectiveCameras, look_at_view_transform, + FoVPerspectiveCameras, look_at_view_transform, RasterizationSettings, BlendParams, MeshRenderer, MeshRasterizer, HardPhongShader ) # Initialize an OpenGL perspective camera.
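+# look_at_view_transform(dist, elev, azim) places the camera on a sphere of
+# radius dist, at elevation elev and azimuth azim (in degrees by default),
+# looking at the origin; it returns the (R, T) extrinsics consumed below.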
R, T = look_at_view_transform(2.7, 10, 20) -cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) +cameras = FoVPerspectiveCameras(device=device, R=R, T=T) # Define the settings for rasterization and shading. Here we set the output image to be of size # 512x512. As we are rendering images for visualization purposes only we will set faces_per_pixel=1 diff --git a/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb b/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb index 5bafce91..2a953db4 100644 --- a/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb +++ b/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb @@ -102,7 +102,7 @@ "\n", "# rendering components\n", "from pytorch3d.renderer import (\n", - " OpenGLPerspectiveCameras, look_at_view_transform, look_at_rotation, \n", + " FoVPerspectiveCameras, look_at_view_transform, look_at_rotation, \n", " RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,\n", " SoftSilhouetteShader, HardPhongShader, PointLights\n", ")" @@ -217,8 +217,8 @@ }, "outputs": [], "source": [ - "# Initialize an OpenGL perspective camera.\n", - "cameras = OpenGLPerspectiveCameras(device=device)\n", + "# Initialize a perspective camera.\n", + "cameras = FoVPerspectiveCameras(device=device)\n", "\n", "# To blend the 100 faces we set a few parameters which control the opacity and the sharpness of \n", "# edges. Refer to blending.py for more details. \n", diff --git a/docs/tutorials/fit_textured_mesh.ipynb b/docs/tutorials/fit_textured_mesh.ipynb index 3301812b..2b5add74 100644 --- a/docs/tutorials/fit_textured_mesh.ipynb +++ b/docs/tutorials/fit_textured_mesh.ipynb @@ -129,7 +129,7 @@ "from pytorch3d.structures import Meshes, Textures\n", "from pytorch3d.renderer import (\n", " look_at_view_transform,\n", - " OpenGLPerspectiveCameras, \n", + " FoVPerspectiveCameras, \n", " PointLights, \n", " DirectionalLights, \n", " Materials, \n", @@ -309,16 +309,16 @@ "# the cow is facing the -z direction. \n", "lights = PointLights(device=device, location=[[0.0, 0.0, -3.0]])\n", "\n", - "# Initialize an OpenGL perspective camera that represents a batch of different \n", + "# Initialize a camera that represents a batch of different \n", "# viewing angles. All the cameras helper methods support mixed type inputs and \n", "# broadcasting. So we can view the camera from the a distance of dist=2.7, and \n", "# then specify elevation and azimuth angles for each viewpoint as tensors. \n", "R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)\n", - "cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)\n", + "cameras = FoVPerspectiveCameras(device=device, R=R, T=T)\n", "\n", "# We arbitrarily choose one particular view that will be used to visualize \n", "# results\n", - "camera = OpenGLPerspectiveCameras(device=device, R=R[None, 1, ...], \n", + "camera = FoVPerspectiveCameras(device=device, R=R[None, 1, ...], \n", " T=T[None, 1, ...]) \n", "\n", "# Define the settings for rasterization and shading. 
Here we set the output \n", @@ -361,7 +361,7 @@ "# Our multi-view cow dataset will be represented by these 2 lists of tensors,\n", "# each of length num_views.\n", "target_rgb = [target_images[i, ..., :3] for i in range(num_views)]\n", - "target_cameras = [OpenGLPerspectiveCameras(device=device, R=R[None, i, ...], \n", + "target_cameras = [FoVPerspectiveCameras(device=device, R=R[None, i, ...], \n", " T=T[None, i, ...]) for i in range(num_views)]" ], "execution_count": null, @@ -925,4 +925,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/docs/tutorials/render_colored_points.ipynb b/docs/tutorials/render_colored_points.ipynb index 318baebd..0adbd8d7 100644 --- a/docs/tutorials/render_colored_points.ipynb +++ b/docs/tutorials/render_colored_points.ipynb @@ -64,7 +64,7 @@ "from pytorch3d.structures import Pointclouds\n", "from pytorch3d.renderer import (\n", " look_at_view_transform,\n", - " OpenGLOrthographicCameras, \n", + " FoVOrthographicCameras, \n", " PointsRasterizationSettings,\n", " PointsRenderer,\n", " PointsRasterizer,\n", @@ -147,9 +147,9 @@ "metadata": {}, "outputs": [], "source": [ - "# Initialize an OpenGL perspective camera.\n", + "# Initialize a camera.\n", "R, T = look_at_view_transform(20, 10, 0)\n", - "cameras = OpenGLOrthographicCameras(device=device, R=R, T=T, znear=0.01)\n", + "cameras = FoVOrthographicCameras(device=device, R=R, T=T, znear=0.01)\n", "\n", "# Define the settings for rasterization and shading. Here we set the output image to be of size\n", "# 512x512. As we are rendering images for visualization purposes only we will set faces_per_pixel=1\n", @@ -195,9 +195,9 @@ "metadata": {}, "outputs": [], "source": [ - "# Initialize an OpenGL perspective camera.\n", + "# Initialize a camera.\n", "R, T = look_at_view_transform(20, 10, 0)\n", - "cameras = OpenGLOrthographicCameras(device=device, R=R, T=T, znear=0.01)\n", + "cameras = FoVOrthographicCameras(device=device, R=R, T=T, znear=0.01)\n", "\n", "# Define the settings for rasterization and shading. Here we set the output image to be of size\n", "# 512x512. As we are rendering images for visualization purposes only we will set faces_per_pixel=1\n", diff --git a/docs/tutorials/render_textured_meshes.ipynb b/docs/tutorials/render_textured_meshes.ipynb index 1f63307a..24e8548d 100644 --- a/docs/tutorials/render_textured_meshes.ipynb +++ b/docs/tutorials/render_textured_meshes.ipynb @@ -90,7 +90,7 @@ "from pytorch3d.structures import Meshes, Textures\n", "from pytorch3d.renderer import (\n", " look_at_view_transform,\n", - " OpenGLPerspectiveCameras, \n", + " FoVPerspectiveCameras, \n", " PointLights, \n", " DirectionalLights, \n", " Materials, \n", @@ -286,11 +286,11 @@ }, "outputs": [], "source": [ - "# Initialize an OpenGL perspective camera.\n", + "# Initialize a camera.\n", "# With world coordinates +Y up, +X left and +Z in, the front of the cow is facing the -Z direction. \n", "# So we move the camera by 180 in the azimuth direction so it is facing the front of the cow. \n", "R, T = look_at_view_transform(2.7, 0, 180) \n", - "cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)\n", + "cameras = FoVPerspectiveCameras(device=device, R=R, T=T)\n", "\n", "# Define the settings for rasterization and shading. Here we set the output image to be of size\n", "# 512x512. 
As we are rendering images for visualization purposes only we will set faces_per_pixel=1\n", @@ -444,7 +444,7 @@ "source": [ "# Rotate the object by increasing the elevation and azimuth angles\n", "R, T = look_at_view_transform(dist=2.7, elev=10, azim=-150)\n", - "cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)\n", + "cameras = FoVPerspectiveCameras(device=device, R=R, T=T)\n", "\n", "# Move the light location so the light is shining on the cow's face. \n", "lights.location = torch.tensor([[2.0, 2.0, -2.0]], device=device)\n", @@ -519,7 +519,7 @@ "# view the camera from the same distance and specify dist=2.7 as a float,\n", "# and then specify elevation and azimuth angles for each viewpoint as tensors. \n", "R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)\n", - "cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)\n", + "cameras = FoVPerspectiveCameras(device=device, R=R, T=T)\n", "\n", "# Move the light back in front of the cow which is facing the -z direction.\n", "lights.location = torch.tensor([[0.0, 0.0, -3.0]], device=device)" diff --git a/pytorch3d/datasets/shapenet_base.py b/pytorch3d/datasets/shapenet_base.py index 735a8414..dea66c15 100644 --- a/pytorch3d/datasets/shapenet_base.py +++ b/pytorch3d/datasets/shapenet_base.py @@ -10,7 +10,7 @@ from pytorch3d.renderer import ( HardPhongShader, MeshRasterizer, MeshRenderer, - OpenGLPerspectiveCameras, + FoVPerspectiveCameras, PointLights, RasterizationSettings, TexturesVertex, @@ -125,7 +125,7 @@ class ShapeNetBase(torch.utils.data.Dataset): meshes.textures = TexturesVertex( verts_features=torch.ones_like(meshes.verts_padded(), device=device) ) - cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device) + cameras = kwargs.get("cameras", FoVPerspectiveCameras()).to(device) if len(cameras) != 1 and len(cameras) % len(meshes) != 0: raise ValueError("Mismatch between batch dims of cameras and meshes.") if len(cameras) > 1: diff --git a/pytorch3d/renderer/__init__.py b/pytorch3d/renderer/__init__.py index 91b2a47f..72a01122 100644 --- a/pytorch3d/renderer/__init__.py +++ b/pytorch3d/renderer/__init__.py @@ -6,11 +6,15 @@ from .blending import ( sigmoid_alpha_blend, softmax_rgb_blend, ) +from .cameras import OpenGLOrthographicCameras # deprecated +from .cameras import OpenGLPerspectiveCameras # deprecated +from .cameras import SfMOrthographicCameras # deprecated +from .cameras import SfMPerspectiveCameras # deprecated from .cameras import ( - OpenGLOrthographicCameras, - OpenGLPerspectiveCameras, - SfMOrthographicCameras, - SfMPerspectiveCameras, + FoVOrthographicCameras, + FoVPerspectiveCameras, + OrthographicCameras, + PerspectiveCameras, camera_position_from_spherical_angles, get_world_to_view_transform, look_at_rotation, diff --git a/pytorch3d/renderer/cameras.py b/pytorch3d/renderer/cameras.py index f2de3d5a..94683d06 100644 --- a/pytorch3d/renderer/cameras.py +++ b/pytorch3d/renderer/cameras.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import math +import warnings from typing import Optional, Sequence, Tuple import numpy as np @@ -20,23 +21,43 @@ class CamerasBase(TensorProperties): """ `CamerasBase` implements a base class for all cameras. + For cameras, there are four different coordinate systems (or spaces): + - World coordinate system: This is the system in which the object lives - the world. + - Camera view coordinate system: This is the system that has its origin on the image plane + and the Z-axis perpendicular to the image plane.
+ In PyTorch3D, we assume that +X points left, +Y points up and + +Z points out from the image plane. + The transformation from world -> view is performed by applying a rotation (R) + and a translation (T). + - NDC coordinate system: This is the normalized coordinate system that confines + the rendered part of the object or scene in a volume. Also known as the view volume. + Given the PyTorch3D convention, (+1, +1, znear) is the top left near corner, + and (-1, -1, zfar) is the bottom right far corner of the volume. + The transformation from view -> NDC is performed by applying the camera + projection matrix (P). + - Screen coordinate system: This is another representation of the view volume with + the XY coordinates defined in pixel space instead of a normalized space. + + A better illustration of the coordinate systems can be found in pytorch3d/docs/notes/cameras.md. + It defines methods that are common to all camera models: - `get_camera_center` that returns the optical center of the camera in world coordinates - `get_world_to_view_transform` which returns a 3D transform from - world coordinates to the camera coordinates + world coordinates to the camera view coordinates (R, T) - `get_full_projection_transform` which composes the projection - transform with the world-to-view transform - - `transform_points` which takes a set of input points and - projects them onto a 2D camera plane. + transform (P) with the world-to-view transform (R, T) + - `transform_points` which takes a set of input points in world coordinates and + projects them to NDC coordinates ranging from [-1, -1, znear] to [+1, +1, zfar]. + - `transform_points_screen` which takes a set of input points in world coordinates and + projects them to screen coordinates ranging from [0, 0, znear] to [W-1, H-1, zfar] For each new camera, one should implement the `get_projection_transform` - routine that returns the mapping from camera coordinates in world units - to the screen coordinates. + routine that returns the mapping from camera view coordinates to NDC coordinates. Another useful function that is specific to each camera model is - `unproject_points` which sends points from screen coordinates back to - camera or world coordinates depending on the `world_coordinates` + `unproject_points` which sends points from NDC coordinates back to + camera view or world coordinates depending on the `world_coordinates` boolean argument of the function. """ @@ -56,7 +77,7 @@ class CamerasBase(TensorProperties): def unproject_points(self): """ - Transform input points in screen coodinates + Transform input points from NDC coordinates to the world / camera coordinates. Each of the input points `xy_depth` of shape (..., 3) is @@ -74,7 +95,7 @@ class CamerasBase(TensorProperties): cameras = # camera object derived from CamerasBase xyz = # 3D points of shape (batch_size, num_points, 3) - # transform xyz to the camera coordinates + # transform xyz to the camera view coordinates xyz_cam = cameras.get_world_to_view_transform().transform_points(xyz) # extract the depth of each point as the 3rd coord of xyz_cam depth = xyz_cam[:, :, 2:] @@ -94,7 +115,7 @@ class CamerasBase(TensorProperties): world_coordinates: If `True`, unprojects the points back to world coordinates using the camera extrinsics `R` and `T`. `False` ignores `R` and `T` and unprojects to - the camera coordinates. + the camera view coordinates. Returns new_points: unprojected points with the same shape as `xy_depth`.
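A minimal round-trip sketch of the project/unproject conventions described above (values illustrative, default extrinsics assumed):

```
import torch
from pytorch3d.renderer import FoVPerspectiveCameras

cameras = FoVPerspectiveCameras()
xyz = torch.tensor([[[0.5, 0.5, 2.0]]])  # one point in world coordinates
xyz_ndc = cameras.transform_points(xyz)  # world -> NDC
# unproject back: xy in NDC plus depth in world units (scaled_depth_input=False)
xy_depth = torch.cat([xyz_ndc[..., :2], xyz[..., 2:]], dim=-1)
xyz_back = cameras.unproject_points(xy_depth, scaled_depth_input=False)
assert torch.allclose(xyz, xyz_back, atol=1e-5)
```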
@@ -141,7 +162,7 @@ class CamerasBase(TensorProperties): lighting calculations. Returns: - T: a Transform3d object which represents a batch of transforms + A Transform3d object which represents a batch of transforms of shape (N, 3, 3) """ self.R = kwargs.get("R", self.R) # pyre-ignore[16] @@ -151,8 +172,8 @@ class CamerasBase(TensorProperties): def get_full_projection_transform(self, **kwargs) -> Transform3d: """ - Return the full world-to-screen transform composing the - world-to-view and view-to-screen transforms. + Return the full world-to-NDC transform composing the + world-to-view and view-to-NDC transforms. Args: **kwargs: parameters for the projection transforms can be passed in @@ -164,26 +185,26 @@ lighting calculations. Returns: - T: a Transform3d object which represents a batch of transforms + a Transform3d object which represents a batch of transforms of shape (N, 3, 3) """ self.R = kwargs.get("R", self.R) # pyre-ignore[16] self.T = kwargs.get("T", self.T) # pyre-ignore[16] world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T) - view_to_screen_transform = self.get_projection_transform(**kwargs) - return world_to_view_transform.compose(view_to_screen_transform) + view_to_ndc_transform = self.get_projection_transform(**kwargs) + return world_to_view_transform.compose(view_to_ndc_transform) def transform_points( self, points, eps: Optional[float] = None, **kwargs ) -> torch.Tensor: """ - Transform input points from world to screen space. + Transform input points from world to NDC space. Args: points: torch tensor of shape (..., 3). eps: If eps!=None, the argument is used to clamp the divisor in the homogeneous normalization of the points - transformed to the screen space. Plese see + transformed to the NDC space. Please see `transforms.Transform3D.transform_points` for details. For `CamerasBase.transform_points`, setting `eps > 0` @@ -194,8 +215,50 @@ Returns new_points: transformed points with the same shape as the input. """ - world_to_screen_transform = self.get_full_projection_transform(**kwargs) - return world_to_screen_transform.transform_points(points, eps=eps) + world_to_ndc_transform = self.get_full_projection_transform(**kwargs) + return world_to_ndc_transform.transform_points(points, eps=eps) +
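+ # A worked example of the NDC -> screen mapping implemented below (values
+ # illustrative): with image_size=(64, 64), screen_x = (64 - 1) / 2.0 * (1 - x_ndc),
+ # so x_ndc=+1 (the NDC left edge) maps to screen_x=0 and x_ndc=-1 maps to screen_x=63.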
+ """ + + ndc_points = self.transform_points(points, eps=eps, **kwargs) + + if not torch.is_tensor(image_size): + image_size = torch.tensor( + image_size, dtype=torch.int64, device=points.device + ) + if (image_size < 1).any(): + raise ValueError("Provided image size is invalid.") + + image_width, image_height = image_size.unbind(1) + image_width = image_width.view(-1, 1) # (N, 1) + image_height = image_height.view(-1, 1) # (N, 1) + + ndc_z = ndc_points[..., 2] + screen_x = (image_width - 1.0) / 2.0 * (1.0 - ndc_points[..., 0]) + screen_y = (image_height - 1.0) / 2.0 * (1.0 - ndc_points[..., 1]) + + return torch.stack((screen_x, screen_y, ndc_z), dim=2) def clone(self): """ @@ -206,21 +269,56 @@ class CamerasBase(TensorProperties): return super().clone(other) -######################## -# Specific camera classes -######################## +############################################################ +# Field of View Camera Classes # +############################################################ -class OpenGLPerspectiveCameras(CamerasBase): +def OpenGLPerspectiveCameras( + znear=1.0, + zfar=100.0, + aspect_ratio=1.0, + fov=60.0, + degrees: bool = True, + R=r, + T=t, + device="cpu", +): + """ + OpenGLPerspectiveCameras has been DEPRECATED. Use FoVPerspectiveCameras instead. + Preserving OpenGLPerspectiveCameras for backward compatibility. + """ + + warnings.warn( + """OpenGLPerspectiveCameras is deprecated, + Use FoVPerspectiveCameras instead. + OpenGLPerspectiveCameras will be removed in future releases.""", + PendingDeprecationWarning, + ) + + return FoVPerspectiveCameras( + znear=znear, + zfar=zfar, + aspect_ratio=aspect_ratio, + fov=fov, + degrees=degrees, + R=R, + T=T, + device=device, + ) + + +class FoVPerspectiveCameras(CamerasBase): """ A class which stores a batch of parameters to generate a batch of - projection matrices using the OpenGL convention for a perspective camera. + projection matrices by specifiying the field of view. + The definition of the parameters follow the OpenGL perspective camera. The extrinsics of the camera (R and T matrices) can also be set in the initializer or passed in to `get_full_projection_transform` to get - the full transformation from world -> screen. + the full transformation from world -> ndc. - The `transform_points` method calculates the full world -> screen transform + The `transform_points` method calculates the full world -> ndc transform and then applies it to the input points. The transforms can also be returned separately as Transform3d objects. @@ -267,8 +365,11 @@ class OpenGLPerspectiveCameras(CamerasBase): def get_projection_transform(self, **kwargs) -> Transform3d: """ - Calculate the OpenGL perpective projection matrix with a symmetric + Calculate the perpective projection matrix with a symmetric viewing frustrum. Use column major order. + The viewing frustrum will be projected into ndc, s.t. + (max_x, max_y) -> (+1, +1) + (min_x, min_y) -> (-1, -1) Args: **kwargs: parameters for the projection can be passed in as keyword @@ -276,14 +377,14 @@ class OpenGLPerspectiveCameras(CamerasBase): Return: P: a Transform3d object which represents a batch of projection - matrices of shape (N, 3, 3) + matrices of shape (N, 4, 4) .. 
code-block:: python f1 = -(far + near)/(far−near) f2 = -2*far*near/(far-near) - h1 = (top + bottom)/(top - bottom) - w1 = (right + left)/(right - left) + h1 = (max_y + min_y)/(max_y - min_y) + w1 = (max_x + min_x)/(max_x - min_x) tanhalffov = tan((fov/2)) s1 = 1/tanhalffov s2 = 1/(tanhalffov * (aspect_ratio)) @@ -310,10 +411,10 @@ if not torch.is_tensor(fov): fov = torch.tensor(fov, device=self.device) tanHalfFov = torch.tan((fov / 2)) - top = tanHalfFov * znear - bottom = -top - right = top * aspect_ratio - left = -right + max_y = tanHalfFov * znear + min_y = -max_y + max_x = max_y * aspect_ratio + min_x = -max_x # NOTE: In OpenGL the projection matrix changes the handedness of the # coordinate frame. i.e the NDC space postive z direction is the @@ -323,28 +424,19 @@ # so the so the z sign is 1.0. z_sign = 1.0 - P[:, 0, 0] = 2.0 * znear / (right - left) - P[:, 1, 1] = 2.0 * znear / (top - bottom) - P[:, 0, 2] = (right + left) / (right - left) - P[:, 1, 2] = (top + bottom) / (top - bottom) + P[:, 0, 0] = 2.0 * znear / (max_x - min_x) + P[:, 1, 1] = 2.0 * znear / (max_y - min_y) + P[:, 0, 2] = (max_x + min_x) / (max_x - min_x) + P[:, 1, 2] = (max_y + min_y) / (max_y - min_y) P[:, 3, 2] = z_sign * ones - # NOTE: This part of the matrix is for z renormalization in OpenGL - # which maps the z to [-1, 1]. This won't work yet as the torch3d - # rasterizer ignores faces which have z < 0. - # P[:, 2, 2] = z_sign * (far + near) / (far - near) - # P[:, 2, 3] = -2.0 * far * near / (far - near) - # P[:, 3, 2] = z_sign * torch.ones((N)) - # NOTE: This maps the z coordinate from [0, 1] where z = 0 if the point # is at the near clipping plane and z = 1 when the point is at the far - # clipping plane. This replaces the OpenGL z normalization to [-1, 1] - # until rasterization is changed to clip at z = -1. + # clipping plane. P[:, 2, 2] = z_sign * zfar / (zfar - znear) P[:, 2, 3] = -(zfar * znear) / (zfar - znear) - # OpenGL uses column vectors so need to transpose the projection matrix - # as torch3d uses row vectors. + # Transpose the projection matrix as PyTorch3D transforms use row vectors. transform = Transform3d(device=self.device) transform._matrix = P.transpose(1, 2).contiguous() return transform @@ -357,7 +449,7 @@ **kwargs ) -> torch.Tensor: """>! - OpenGL cameras further allow for passing depth in world units + FoV cameras further allow for passing depth in world units (`scaled_depth_input=False`) or in the [0, 1]-normalized units (`scaled_depth_input=True`) @@ -367,11 +459,11 @@ the world units.
""" - # obtain the relevant transformation to screen + # obtain the relevant transformation to ndc if world_coordinates: - to_screen_transform = self.get_full_projection_transform() + to_ndc_transform = self.get_full_projection_transform() else: - to_screen_transform = self.get_projection_transform() + to_ndc_transform = self.get_projection_transform() if scaled_depth_input: # the input is scaled depth, so we don't have to do anything @@ -390,45 +482,84 @@ class OpenGLPerspectiveCameras(CamerasBase): xy_sdepth = torch.cat((xy_depth[..., 0:2], sdepth), dim=-1) # unproject with inverse of the projection - unprojection_transform = to_screen_transform.inverse() + unprojection_transform = to_ndc_transform.inverse() return unprojection_transform.transform_points(xy_sdepth) -class OpenGLOrthographicCameras(CamerasBase): +def OpenGLOrthographicCameras( + znear=1.0, + zfar=100.0, + top=1.0, + bottom=-1.0, + left=-1.0, + right=1.0, + scale_xyz=((1.0, 1.0, 1.0),), # (1, 3) + R=r, + T=t, + device="cpu", +): + """ + OpenGLOrthographicCameras has been DEPRECATED. Use FoVOrthographicCameras instead. + Preserving OpenGLOrthographicCameras for backward compatibility. + """ + + warnings.warn( + """OpenGLOrthographicCameras is deprecated, + Use FoVOrthographicCameras instead. + OpenGLOrthographicCameras will be removed in future releases.""", + PendingDeprecationWarning, + ) + + return FoVOrthographicCameras( + znear=znear, + zfar=zfar, + max_y=top, + min_y=bottom, + max_x=right, + min_x=left, + scale_xyz=scale_xyz, + R=R, + T=T, + device=device, + ) + + +class FoVOrthographicCameras(CamerasBase): """ A class which stores a batch of parameters to generate a batch of - transformation matrices using the OpenGL convention for orthographic camera. + projection matrices by specifiying the field of view. + The definition of the parameters follow the OpenGL orthographic camera. """ def __init__( self, znear=1.0, zfar=100.0, - top=1.0, - bottom=-1.0, - left=-1.0, - right=1.0, + max_y=1.0, + min_y=-1.0, + max_x=1.0, + min_x=-1.0, scale_xyz=((1.0, 1.0, 1.0),), # (1, 3) R=r, T=t, device="cpu", ): """ - __init__(self, znear, zfar, top, bottom, left, right, scale_xyz, R, T, device) -> None # noqa + __init__(self, znear, zfar, max_y, min_y, max_x, min_x, scale_xyz, R, T, device) -> None # noqa Args: znear: near clipping plane of the view frustrum. zfar: far clipping plane of the view frustrum. - top: position of the top of the screen. - bottom: position of the bottom of the screen. - left: position of the left of the screen. - right: position of the right of the screen. + max_y: maximum y coordinate of the frustrum. + min_y: minimum y coordinate of the frustrum. + max_x: maximum x coordinate of the frustrum. + min_x: minumum x coordinage of the frustrum scale_xyz: scale factors for each axis of shape (N, 3). R: Rotation matrix of shape (N, 3, 3). T: Translation of shape (N, 3). device: torch.device or string. - Only need to set left, right, top, bottom for viewing frustrums + Only need to set min_x, max_x, min_y, max_y for viewing frustrums which are non symmetric about the origin. 
""" # The initializer formats all inputs to torch tensors and broadcasts @@ -437,10 +568,10 @@ class OpenGLOrthographicCameras(CamerasBase): device=device, znear=znear, zfar=zfar, - top=top, - bottom=bottom, - left=left, - right=right, + max_y=max_y, + min_y=min_y, + max_x=max_x, + min_x=min_x, scale_xyz=scale_xyz, R=R, T=T, @@ -448,7 +579,7 @@ class OpenGLOrthographicCameras(CamerasBase): def get_projection_transform(self, **kwargs) -> Transform3d: """ - Calculate the OpenGL orthographic projection matrix. + Calculate the orthographic projection matrix. Use column major order. Args: @@ -456,16 +587,16 @@ class OpenGLOrthographicCameras(CamerasBase): override the default values set in __init__. Return: P: a Transform3d object which represents a batch of projection - matrices of shape (N, 3, 3) + matrices of shape (N, 4, 4) .. code-block:: python - scale_x = 2/(right - left) - scale_y = 2/(top - bottom) - scale_z = 2/(far-near) - mid_x = (right + left)/(right - left) - mix_y = (top + bottom)/(top - bottom) - mid_z = (far + near)/(far−near) + scale_x = 2 / (max_x - min_x) + scale_y = 2 / (max_y - min_y) + scale_z = 2 / (far-near) + mid_x = (max_x + min_x) / (max_x - min_x) + mix_y = (max_y + min_y) / (max_y - min_y) + mid_z = (far + near) / (far−near) P = [ [scale_x, 0, 0, -mid_x], @@ -476,10 +607,10 @@ class OpenGLOrthographicCameras(CamerasBase): """ znear = kwargs.get("znear", self.znear) # pyre-ignore[16] zfar = kwargs.get("zfar", self.zfar) # pyre-ignore[16] - left = kwargs.get("left", self.left) # pyre-ignore[16] - right = kwargs.get("right", self.right) # pyre-ignore[16] - top = kwargs.get("top", self.top) # pyre-ignore[16] - bottom = kwargs.get("bottom", self.bottom) # pyre-ignore[16] + max_x = kwargs.get("max_x", self.max_x) # pyre-ignore[16] + min_x = kwargs.get("min_x", self.min_x) # pyre-ignore[16] + max_y = kwargs.get("max_y", self.max_y) # pyre-ignore[16] + min_y = kwargs.get("min_y", self.min_y) # pyre-ignore[16] scale_xyz = kwargs.get("scale_xyz", self.scale_xyz) # pyre-ignore[16] P = torch.zeros((self._N, 4, 4), dtype=torch.float32, device=self.device) @@ -489,10 +620,10 @@ class OpenGLOrthographicCameras(CamerasBase): # right handed coordinate system throughout. z_sign = +1.0 - P[:, 0, 0] = (2.0 / (right - left)) * scale_xyz[:, 0] - P[:, 1, 1] = (2.0 / (top - bottom)) * scale_xyz[:, 1] - P[:, 0, 3] = -(right + left) / (right - left) - P[:, 1, 3] = -(top + bottom) / (top - bottom) + P[:, 0, 0] = (2.0 / (max_x - min_x)) * scale_xyz[:, 0] + P[:, 1, 1] = (2.0 / (max_y - min_y)) * scale_xyz[:, 1] + P[:, 0, 3] = -(max_x + min_x) / (max_x - min_x) + P[:, 1, 3] = -(max_y + min_y) / (max_y - min_y) P[:, 3, 3] = ones # NOTE: This maps the z coordinate to the range [0, 1] and replaces the @@ -500,12 +631,6 @@ class OpenGLOrthographicCameras(CamerasBase): P[:, 2, 2] = z_sign * (1.0 / (zfar - znear)) * scale_xyz[:, 2] P[:, 2, 3] = -znear / (zfar - znear) - # NOTE: This part of the matrix is for z renormalization in OpenGL. - # The z is mapped to the range [-1, 1] but this won't work yet in - # pytorch3d as the rasterizer ignores faces which have z < 0. - # P[:, 2, 2] = z_sign * (2.0 / (far - near)) * scale[:, 2] - # P[:, 2, 3] = -(far + near) / (far - near) - transform = Transform3d(device=self.device) transform._matrix = P.transpose(1, 2).contiguous() return transform @@ -518,7 +643,7 @@ class OpenGLOrthographicCameras(CamerasBase): **kwargs ) -> torch.Tensor: """>! 
- OpenGL cameras further allow for passing depth in world units + FoV cameras further allow for passing depth in world units (`scaled_depth_input=False`) or in the [0, 1]-normalized units (`scaled_depth_input=True`) @@ -529,9 +654,9 @@ """ if world_coordinates: - to_screen_transform = self.get_full_projection_transform(**kwargs.copy()) + to_ndc_transform = self.get_full_projection_transform(**kwargs.copy()) else: - to_screen_transform = self.get_projection_transform(**kwargs.copy()) + to_ndc_transform = self.get_projection_transform(**kwargs.copy()) if scaled_depth_input: # the input depth is already scaled @@ -547,22 +672,88 @@ # cat xy and scaled depth xy_sdepth = torch.cat((xy_depth[..., :2], scaled_depth), dim=-1) # finally invert the transform - unprojection_transform = to_screen_transform.inverse() + unprojection_transform = to_ndc_transform.inverse() return unprojection_transform.transform_points(xy_sdepth) -class SfMPerspectiveCameras(CamerasBase): +############################################################ +# MultiView Camera Classes # +############################################################ +""" +Note that the MultiView Cameras accept parameters in both +screen and NDC space. +If the user specifies `image_size` at construction time then +we assume the parameters are in screen space. +""" + + +def SfMPerspectiveCameras( + focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu" +): + """ + SfMPerspectiveCameras has been DEPRECATED. Use PerspectiveCameras instead. + Preserving SfMPerspectiveCameras for backward compatibility. + """ + + warnings.warn( + """SfMPerspectiveCameras is deprecated, + Use PerspectiveCameras instead. + SfMPerspectiveCameras will be removed in future releases.""", + PendingDeprecationWarning, + ) + + return PerspectiveCameras( + focal_length=focal_length, + principal_point=principal_point, + R=R, + T=T, + device=device, + ) + + +class PerspectiveCameras(CamerasBase): """ A class which stores a batch of parameters to generate a batch of transformation matrices using the multi-view geometry convention for perspective camera. + + Parameters for this camera can be specified in NDC or in screen space. + If you wish to provide parameters in screen space, you NEED to provide + the image_size = (imwidth, imheight). + If you wish to provide parameters in NDC space, you should NOT provide + image_size. Providing a valid image_size will trigger a screen space to + NDC space transformation in the camera. + + For example, here is how to define cameras in the two spaces. + + ..
code-block:: python + # camera defined in screen space + cameras = PerspectiveCameras( + focal_length=((22.0, 15.0),), # (fx_screen, fy_screen) + principal_point=((192.0, 128.0),), # (px_screen, py_screen) + image_size=((256, 256),), # (imwidth, imheight) + ) + + # the equivalent camera defined in NDC space + cameras = PerspectiveCameras( + focal_length=((0.171875, 0.1171875),), # fx = fx_screen / half_imwidth, + # fy = fy_screen / half_imheight + principal_point=((-0.5, 0),), # px = - (px_screen - half_imwidth) / half_imwidth, + # py = - (py_screen - half_imheight) / half_imheight + ) """ def __init__( - self, focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu" + self, + focal_length=1.0, + principal_point=((0.0, 0.0),), + R=r, + T=t, + device="cpu", + image_size=((-1, -1),), ): """ - __init__(self, focal_length, principal_point, R, T, device) -> None + __init__(self, focal_length, principal_point, R, T, device, image_size) -> None Args: focal_length: Focal length of the camera in world units. @@ -574,6 +765,11 @@ R: Rotation matrix of shape (N, 3, 3) T: Translation matrix of shape (N, 3) device: torch.device or string + image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0 + is provided, the camera parameters are assumed to be in screen + space. They will be converted to NDC space. + If image_size is not provided, the parameters are assumed to + be in NDC space. """ # The initializer formats all inputs to torch tensors and broadcasts # all the inputs to have the same batch dimension where necessary. @@ -583,6 +779,7 @@ principal_point=principal_point, R=R, T=T, + image_size=image_size, ) def get_projection_transform(self, **kwargs) -> Transform3d: @@ -615,9 +812,20 @@ principal_point = kwargs.get("principal_point", self.principal_point) # pyre-ignore[16] focal_length = kwargs.get("focal_length", self.focal_length) + # pyre-ignore[16] + image_size = kwargs.get("image_size", self.image_size) + + # if imwidth > 0, parameters are in screen space + in_screen = image_size[0][0] > 0 + image_size = image_size if in_screen else None P = _get_sfm_calibration_matrix( - self._N, self.device, focal_length, principal_point, False + self._N, + self.device, + focal_length, + principal_point, + orthographic=False, + image_size=image_size, ) transform = Transform3d(device=self.device) @@ -628,29 +836,83 @@ self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs ) -> torch.Tensor: if world_coordinates: - to_screen_transform = self.get_full_projection_transform(**kwargs) + to_ndc_transform = self.get_full_projection_transform(**kwargs) else: - to_screen_transform = self.get_projection_transform(**kwargs) + to_ndc_transform = self.get_projection_transform(**kwargs) - unprojection_transform = to_screen_transform.inverse() + unprojection_transform = to_ndc_transform.inverse() xy_inv_depth = torch.cat( (xy_depth[..., :2], 1.0 / xy_depth[..., 2:3]), dim=-1 # type: ignore ) return unprojection_transform.transform_points(xy_inv_depth) -class SfMOrthographicCameras(CamerasBase): +def SfMOrthographicCameras( + focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu" +): + """ + SfMOrthographicCameras has been DEPRECATED. Use OrthographicCameras instead. + Preserving SfMOrthographicCameras for backward compatibility.
+ """ + + warnings.warn( + """SfMOrthographicCameras is deprecated, + Use OrthographicCameras instead. + SfMOrthographicCameras will be removed in future releases.""", + PendingDeprecationWarning, + ) + + return OrthographicCameras( + focal_length=focal_length, + principal_point=principal_point, + R=R, + T=T, + device=device, + ) + + +class OrthographicCameras(CamerasBase): """ A class which stores a batch of parameters to generate a batch of transformation matrices using the multi-view geometry convention for orthographic camera. + + Parameters for this camera can be specified in NDC or in screen space. + If you wish to provide parameters in screen space, you NEED to provide + the image_size = (imwidth, imheight). + If you wish to provide parameters in NDC space, you should NOT provide + image_size. Providing valid image_size will triger a screen space to + NDC space transformation in the camera. + + For example, here is how to define cameras on the two spaces. + + .. code-block:: python + # camera defined in screen space + cameras = OrthographicCameras( + focal_length=((22.0, 15.0),), # (fx, fy) + principal_point=((192.0, 128.0),), # (px, py) + image_size=((256, 256),), # (imwidth, imheight) + ) + + # the equivalent camera defined in NDC space + cameras = OrthographicCameras( + focal_length=((0.17875, 0.11718),), # := (fx / half_imwidth, fy / half_imheight) + principal_point=((-0.5, 0),), # := (- (px - half_imwidth) / half_imwidth, + - (py - half_imheight) / half_imheight) + ) """ def __init__( - self, focal_length=1.0, principal_point=((0.0, 0.0),), R=r, T=t, device="cpu" + self, + focal_length=1.0, + principal_point=((0.0, 0.0),), + R=r, + T=t, + device="cpu", + image_size=((-1, -1),), ): """ - __init__(self, focal_length, principal_point, R, T, device) -> None + __init__(self, focal_length, principal_point, R, T, device, image_size) -> None Args: focal_length: Focal length of the camera in world units. @@ -662,6 +924,11 @@ class SfMOrthographicCameras(CamerasBase): R: Rotation matrix of shape (N, 3, 3) T: Translation matrix of shape (N, 3) device: torch.device or string + image_size: If image_size = (imwidth, imheight) with imwidth, imheight > 0 + is provided, the camera parameters are assumed to be in screen + space. They will be converted to NDC space. + If image_size is not provided, the parameters are assumed to + be in NDC space. """ # The initializer formats all inputs to torch tensors and broadcasts # all the inputs to have the same batch dimension where necessary. 
@@ -671,6 +938,7 @@ class SfMOrthographicCameras(CamerasBase): principal_point=principal_point, R=R, T=T, + image_size=image_size, ) def get_projection_transform(self, **kwargs) -> Transform3d: @@ -703,9 +971,20 @@ class SfMOrthographicCameras(CamerasBase): principal_point = kwargs.get("principal_point", self.principal_point) # pyre-ignore[16] focal_length = kwargs.get("focal_length", self.focal_length) + # pyre-ignore[16] + image_size = kwargs.get("image_size", self.image_size) + + # if imwidth > 0, parameters are in screen space + in_screen = image_size[0][0] > 0 + image_size = image_size if in_screen else None P = _get_sfm_calibration_matrix( - self._N, self.device, focal_length, principal_point, True + self._N, + self.device, + focal_length, + principal_point, + orthographic=True, + image_size=image_size, ) transform = Transform3d(device=self.device) @@ -716,17 +995,26 @@ class SfMOrthographicCameras(CamerasBase): self, xy_depth: torch.Tensor, world_coordinates: bool = True, **kwargs ) -> torch.Tensor: if world_coordinates: - to_screen_transform = self.get_full_projection_transform(**kwargs) + to_ndc_transform = self.get_full_projection_transform(**kwargs) else: - to_screen_transform = self.get_projection_transform(**kwargs) + to_ndc_transform = self.get_projection_transform(**kwargs) - unprojection_transform = to_screen_transform.inverse() + unprojection_transform = to_ndc_transform.inverse() return unprojection_transform.transform_points(xy_depth) -# SfMCameras helper +################################################ +# Helper functions for cameras # +################################################ + + def _get_sfm_calibration_matrix( - N, device, focal_length, principal_point, orthographic: bool + N, + device, + focal_length, + principal_point, + orthographic: bool = False, + image_size=None, ) -> torch.Tensor: """ Returns a calibration matrix of a perspective/orthograpic camera. @@ -736,6 +1024,10 @@ def _get_sfm_calibration_matrix( focal_length: Focal length of the camera in world units. principal_point: xy coordinates of the center of the principal point of the camera in pixels. + orthographic: Boolean specifying if the camera is orthographic or not + image_size: (Optional) Specifying the image_size = (imwidth, imheight). + If not None, the camera parameters are assumed to be in screen space + and are transformed to NDC space. The calibration matrix `K` is set up as follows: @@ -769,7 +1061,7 @@ def _get_sfm_calibration_matrix( if not torch.is_tensor(focal_length): focal_length = torch.tensor(focal_length, device=device) - if len(focal_length.shape) in (0, 1) or focal_length.shape[1] == 1: + if focal_length.ndim in (0, 1) or focal_length.shape[1] == 1: fx = fy = focal_length else: fx, fy = focal_length.unbind(1) @@ -779,6 +1071,22 @@ def _get_sfm_calibration_matrix( px, py = principal_point.unbind(1) + if image_size is not None: + if not torch.is_tensor(image_size): + image_size = torch.tensor(image_size, device=device) + imwidth, imheight = image_size.unbind(1) + # make sure imwidth, imheight are valid (>0) + if (imwidth < 1).any() or (imheight < 1).any(): + raise ValueError( + "Camera parameters provided in screen space. Image width or height invalid." 
+ ) + half_imwidth = imwidth / 2.0 + half_imheight = imheight / 2.0 + fx = fx / half_imwidth + fy = fy / half_imheight + px = -(px - half_imwidth) / half_imwidth + py = -(py - half_imheight) / half_imheight + K = fx.new_zeros(N, 4, 4) K[:, 0, 0] = fx K[:, 1, 1] = fy diff --git a/tests/bm_barycentric_clipping.py b/tests/bm_barycentric_clipping.py index 0941a97c..df72987e 100644 --- a/tests/bm_barycentric_clipping.py +++ b/tests/bm_barycentric_clipping.py @@ -4,7 +4,7 @@ from itertools import product import torch from fvcore.common.benchmark import benchmark -from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform +from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform from pytorch3d.renderer.mesh.rasterizer import ( Fragments, MeshRasterizer, @@ -28,7 +28,7 @@ def baryclip_cuda( sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes) # Init transform R, T = look_at_view_transform(1.0, 0.0, 0.0) - cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T) # Init rasterizer raster_settings = RasterizationSettings( image_size=image_size, @@ -58,7 +58,7 @@ def baryclip_pytorch( sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes) # Init transform R, T = look_at_view_transform(1.0, 0.0, 0.0) - cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T) # Init rasterizer raster_settings = RasterizationSettings( image_size=image_size, diff --git a/tests/bm_mesh_rasterizer_transform.py b/tests/bm_mesh_rasterizer_transform.py index 97672504..0d875f3f 100644 --- a/tests/bm_mesh_rasterizer_transform.py +++ b/tests/bm_mesh_rasterizer_transform.py @@ -5,7 +5,7 @@ from itertools import product import torch from fvcore.common.benchmark import benchmark -from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform +from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer from pytorch3d.utils.ico_sphere import ico_sphere @@ -15,7 +15,7 @@ def rasterize_transform_with_init(num_meshes: int, ico_level: int = 5, device="c sphere_meshes = ico_sphere(ico_level, device).extend(num_meshes) # Init transform R, T = look_at_view_transform(1.0, 0.0, 0.0) - cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) + cameras = FoVPerspectiveCameras(device=device, R=R, T=T) # Init rasterizer rasterizer = MeshRasterizer(cameras=cameras) diff --git a/tests/data/test_FoVOrthographicCameras_silhouette.png b/tests/data/test_FoVOrthographicCameras_silhouette.png new file mode 100644 index 00000000..df0c3d0a Binary files /dev/null and b/tests/data/test_FoVOrthographicCameras_silhouette.png differ diff --git a/tests/data/test_FoVPerspectiveCameras_silhouette.png b/tests/data/test_FoVPerspectiveCameras_silhouette.png new file mode 100644 index 00000000..86d9217b Binary files /dev/null and b/tests/data/test_FoVPerspectiveCameras_silhouette.png differ diff --git a/tests/data/test_OrthographicCameras_silhouette.png b/tests/data/test_OrthographicCameras_silhouette.png new file mode 100644 index 00000000..df0c3d0a Binary files /dev/null and b/tests/data/test_OrthographicCameras_silhouette.png differ diff --git a/tests/data/test_PerspectiveCameras_silhouette.png b/tests/data/test_PerspectiveCameras_silhouette.png new file mode 100644 index 00000000..a6fd9978 Binary files /dev/null and 
b/tests/data/test_PerspectiveCameras_silhouette.png differ diff --git a/tests/data/test_silhouette.png b/tests/data/test_silhouette.png deleted file mode 100644 index efe34d92..00000000 Binary files a/tests/data/test_silhouette.png and /dev/null differ diff --git a/tests/data/test_simple_sphere_batched_flat_FoVPerspectiveCameras.png b/tests/data/test_simple_sphere_batched_flat_FoVPerspectiveCameras.png new file mode 100644 index 00000000..891d6be3 Binary files /dev/null and b/tests/data/test_simple_sphere_batched_flat_FoVPerspectiveCameras.png differ diff --git a/tests/data/test_simple_sphere_batched_gouraud_FoVPerspectiveCameras.png b/tests/data/test_simple_sphere_batched_gouraud_FoVPerspectiveCameras.png new file mode 100644 index 00000000..4f2d4cba Binary files /dev/null and b/tests/data/test_simple_sphere_batched_gouraud_FoVPerspectiveCameras.png differ diff --git a/tests/data/test_simple_sphere_batched_phong_FoVPerspectiveCameras.png b/tests/data/test_simple_sphere_batched_phong_FoVPerspectiveCameras.png new file mode 100644 index 00000000..516e7484 Binary files /dev/null and b/tests/data/test_simple_sphere_batched_phong_FoVPerspectiveCameras.png differ diff --git a/tests/data/test_simple_sphere_dark.png b/tests/data/test_simple_sphere_dark.png deleted file mode 100644 index 7e6c6ca9..00000000 Binary files a/tests/data/test_simple_sphere_dark.png and /dev/null differ diff --git a/tests/data/test_simple_sphere_dark_FoVOrthographicCameras.png b/tests/data/test_simple_sphere_dark_FoVOrthographicCameras.png new file mode 100644 index 00000000..2567cc1d Binary files /dev/null and b/tests/data/test_simple_sphere_dark_FoVOrthographicCameras.png differ diff --git a/tests/data/test_simple_sphere_dark_FoVPerspectiveCameras.png b/tests/data/test_simple_sphere_dark_FoVPerspectiveCameras.png new file mode 100644 index 00000000..0e3f8e7b Binary files /dev/null and b/tests/data/test_simple_sphere_dark_FoVPerspectiveCameras.png differ diff --git a/tests/data/test_simple_sphere_dark_OrthographicCameras.png b/tests/data/test_simple_sphere_dark_OrthographicCameras.png new file mode 100644 index 00000000..2567cc1d Binary files /dev/null and b/tests/data/test_simple_sphere_dark_OrthographicCameras.png differ diff --git a/tests/data/test_simple_sphere_dark_PerspectiveCameras.png b/tests/data/test_simple_sphere_dark_PerspectiveCameras.png new file mode 100644 index 00000000..f82fb7b6 Binary files /dev/null and b/tests/data/test_simple_sphere_dark_PerspectiveCameras.png differ diff --git a/tests/data/test_simple_sphere_dark_elevated_FoVOrthographicCameras.png b/tests/data/test_simple_sphere_dark_elevated_FoVOrthographicCameras.png new file mode 100644 index 00000000..132d948e Binary files /dev/null and b/tests/data/test_simple_sphere_dark_elevated_FoVOrthographicCameras.png differ diff --git a/tests/data/test_simple_sphere_dark_elevated_FoVPerspectiveCameras.png b/tests/data/test_simple_sphere_dark_elevated_FoVPerspectiveCameras.png new file mode 100644 index 00000000..889f9267 Binary files /dev/null and b/tests/data/test_simple_sphere_dark_elevated_FoVPerspectiveCameras.png differ diff --git a/tests/data/test_simple_sphere_dark_elevated_OrthographicCameras.png b/tests/data/test_simple_sphere_dark_elevated_OrthographicCameras.png new file mode 100644 index 00000000..132d948e Binary files /dev/null and b/tests/data/test_simple_sphere_dark_elevated_OrthographicCameras.png differ diff --git a/tests/data/test_simple_sphere_dark_elevated_PerspectiveCameras.png 
b/tests/data/test_simple_sphere_dark_elevated_PerspectiveCameras.png new file mode 100644 index 00000000..e7e6a7d6 Binary files /dev/null and b/tests/data/test_simple_sphere_dark_elevated_PerspectiveCameras.png differ diff --git a/tests/data/test_simple_sphere_dark_elevated_camera.png b/tests/data/test_simple_sphere_dark_elevated_camera.png deleted file mode 100644 index 7ad8d966..00000000 Binary files a/tests/data/test_simple_sphere_dark_elevated_camera.png and /dev/null differ diff --git a/tests/data/test_simple_sphere_light_flat.png b/tests/data/test_simple_sphere_light_flat.png deleted file mode 100644 index f8573d6d..00000000 Binary files a/tests/data/test_simple_sphere_light_flat.png and /dev/null differ diff --git a/tests/data/test_simple_sphere_light_flat_FoVOrthographicCameras.png b/tests/data/test_simple_sphere_light_flat_FoVOrthographicCameras.png new file mode 100644 index 00000000..ab846daa Binary files /dev/null and b/tests/data/test_simple_sphere_light_flat_FoVOrthographicCameras.png differ diff --git a/tests/data/test_simple_sphere_light_flat_FoVPerspectiveCameras.png b/tests/data/test_simple_sphere_light_flat_FoVPerspectiveCameras.png new file mode 100644 index 00000000..891d6be3 Binary files /dev/null and b/tests/data/test_simple_sphere_light_flat_FoVPerspectiveCameras.png differ diff --git a/tests/data/test_simple_sphere_light_flat_OrthographicCameras.png b/tests/data/test_simple_sphere_light_flat_OrthographicCameras.png new file mode 100644 index 00000000..ab846daa Binary files /dev/null and b/tests/data/test_simple_sphere_light_flat_OrthographicCameras.png differ diff --git a/tests/data/test_simple_sphere_light_flat_PerspectiveCameras.png b/tests/data/test_simple_sphere_light_flat_PerspectiveCameras.png new file mode 100644 index 00000000..457c16ab Binary files /dev/null and b/tests/data/test_simple_sphere_light_flat_PerspectiveCameras.png differ diff --git a/tests/data/test_simple_sphere_light_flat_elevated_FoVOrthographicCameras.png b/tests/data/test_simple_sphere_light_flat_elevated_FoVOrthographicCameras.png new file mode 100644 index 00000000..869d4a8f Binary files /dev/null and b/tests/data/test_simple_sphere_light_flat_elevated_FoVOrthographicCameras.png differ diff --git a/tests/data/test_simple_sphere_light_flat_elevated_FoVPerspectiveCameras.png b/tests/data/test_simple_sphere_light_flat_elevated_FoVPerspectiveCameras.png new file mode 100644 index 00000000..eb0b0622 Binary files /dev/null and b/tests/data/test_simple_sphere_light_flat_elevated_FoVPerspectiveCameras.png differ diff --git a/tests/data/test_simple_sphere_light_flat_elevated_OrthographicCameras.png b/tests/data/test_simple_sphere_light_flat_elevated_OrthographicCameras.png new file mode 100644 index 00000000..869d4a8f Binary files /dev/null and b/tests/data/test_simple_sphere_light_flat_elevated_OrthographicCameras.png differ diff --git a/tests/data/test_simple_sphere_light_flat_elevated_PerspectiveCameras.png b/tests/data/test_simple_sphere_light_flat_elevated_PerspectiveCameras.png new file mode 100644 index 00000000..ef564031 Binary files /dev/null and b/tests/data/test_simple_sphere_light_flat_elevated_PerspectiveCameras.png differ diff --git a/tests/data/test_simple_sphere_light_flat_elevated_camera.png b/tests/data/test_simple_sphere_light_flat_elevated_camera.png deleted file mode 100644 index d0eb8fb6..00000000 Binary files a/tests/data/test_simple_sphere_light_flat_elevated_camera.png and /dev/null differ diff --git a/tests/data/test_simple_sphere_light_gouraud.png 
deleted file mode 100644
index 57737e9f..00000000
Binary files a/tests/data/test_simple_sphere_light_gouraud.png and /dev/null differ
diff --git a/tests/data/test_simple_sphere_light_gouraud_FoVOrthographicCameras.png b/tests/data/test_simple_sphere_light_gouraud_FoVOrthographicCameras.png
new file mode 100644
index 00000000..bf56179d
Binary files /dev/null and b/tests/data/test_simple_sphere_light_gouraud_FoVOrthographicCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_gouraud_FoVPerspectiveCameras.png b/tests/data/test_simple_sphere_light_gouraud_FoVPerspectiveCameras.png
new file mode 100644
index 00000000..4f2d4cba
Binary files /dev/null and b/tests/data/test_simple_sphere_light_gouraud_FoVPerspectiveCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_gouraud_OrthographicCameras.png b/tests/data/test_simple_sphere_light_gouraud_OrthographicCameras.png
new file mode 100644
index 00000000..bf56179d
Binary files /dev/null and b/tests/data/test_simple_sphere_light_gouraud_OrthographicCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_gouraud_PerspectiveCameras.png b/tests/data/test_simple_sphere_light_gouraud_PerspectiveCameras.png
new file mode 100644
index 00000000..64c1e70a
Binary files /dev/null and b/tests/data/test_simple_sphere_light_gouraud_PerspectiveCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_gouraud_elevated_FoVOrthographicCameras.png b/tests/data/test_simple_sphere_light_gouraud_elevated_FoVOrthographicCameras.png
new file mode 100644
index 00000000..be92ce82
Binary files /dev/null and b/tests/data/test_simple_sphere_light_gouraud_elevated_FoVOrthographicCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_gouraud_elevated_FoVPerspectiveCameras.png b/tests/data/test_simple_sphere_light_gouraud_elevated_FoVPerspectiveCameras.png
new file mode 100644
index 00000000..b8aba25c
Binary files /dev/null and b/tests/data/test_simple_sphere_light_gouraud_elevated_FoVPerspectiveCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_gouraud_elevated_OrthographicCameras.png b/tests/data/test_simple_sphere_light_gouraud_elevated_OrthographicCameras.png
new file mode 100644
index 00000000..be92ce82
Binary files /dev/null and b/tests/data/test_simple_sphere_light_gouraud_elevated_OrthographicCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_gouraud_elevated_PerspectiveCameras.png b/tests/data/test_simple_sphere_light_gouraud_elevated_PerspectiveCameras.png
new file mode 100644
index 00000000..9b8d29d7
Binary files /dev/null and b/tests/data/test_simple_sphere_light_gouraud_elevated_PerspectiveCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_gouraud_elevated_camera.png b/tests/data/test_simple_sphere_light_gouraud_elevated_camera.png
deleted file mode 100644
index a39ef8f2..00000000
Binary files a/tests/data/test_simple_sphere_light_gouraud_elevated_camera.png and /dev/null differ
diff --git a/tests/data/test_simple_sphere_light_phong.png b/tests/data/test_simple_sphere_light_phong.png
deleted file mode 100644
index d3aa1a7e..00000000
Binary files a/tests/data/test_simple_sphere_light_phong.png and /dev/null differ
diff --git a/tests/data/test_simple_sphere_light_phong_FoVOrthographicCameras.png b/tests/data/test_simple_sphere_light_phong_FoVOrthographicCameras.png
new file mode 100644
index 00000000..d58b2109
Binary files /dev/null and b/tests/data/test_simple_sphere_light_phong_FoVOrthographicCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_phong_FoVPerspectiveCameras.png b/tests/data/test_simple_sphere_light_phong_FoVPerspectiveCameras.png
new file mode 100644
index 00000000..516e7484
Binary files /dev/null and b/tests/data/test_simple_sphere_light_phong_FoVPerspectiveCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_phong_OrthographicCameras.png b/tests/data/test_simple_sphere_light_phong_OrthographicCameras.png
new file mode 100644
index 00000000..d58b2109
Binary files /dev/null and b/tests/data/test_simple_sphere_light_phong_OrthographicCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_phong_PerspectiveCameras.png b/tests/data/test_simple_sphere_light_phong_PerspectiveCameras.png
new file mode 100644
index 00000000..a8690e08
Binary files /dev/null and b/tests/data/test_simple_sphere_light_phong_PerspectiveCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_phong_elevated_FoVOrthographicCameras.png b/tests/data/test_simple_sphere_light_phong_elevated_FoVOrthographicCameras.png
new file mode 100644
index 00000000..0d46b6a0
Binary files /dev/null and b/tests/data/test_simple_sphere_light_phong_elevated_FoVOrthographicCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_phong_elevated_FoVPerspectiveCameras.png b/tests/data/test_simple_sphere_light_phong_elevated_FoVPerspectiveCameras.png
new file mode 100644
index 00000000..aa54b67a
Binary files /dev/null and b/tests/data/test_simple_sphere_light_phong_elevated_FoVPerspectiveCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_phong_elevated_OrthographicCameras.png b/tests/data/test_simple_sphere_light_phong_elevated_OrthographicCameras.png
new file mode 100644
index 00000000..0d46b6a0
Binary files /dev/null and b/tests/data/test_simple_sphere_light_phong_elevated_OrthographicCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_phong_elevated_PerspectiveCameras.png b/tests/data/test_simple_sphere_light_phong_elevated_PerspectiveCameras.png
new file mode 100644
index 00000000..d8dd5f32
Binary files /dev/null and b/tests/data/test_simple_sphere_light_phong_elevated_PerspectiveCameras.png differ
diff --git a/tests/data/test_simple_sphere_light_phong_elevated_camera.png b/tests/data/test_simple_sphere_light_phong_elevated_camera.png
deleted file mode 100644
index b3fe6921..00000000
Binary files a/tests/data/test_simple_sphere_light_phong_elevated_camera.png and /dev/null differ
diff --git a/tests/test_cameras.py b/tests/test_cameras.py
index 175bd67e..88c78ea7 100644
--- a/tests/test_cameras.py
+++ b/tests/test_cameras.py
@@ -31,12 +31,16 @@
 import unittest
 
 import numpy as np
 import torch
 from common_testing import TestCaseMixin
+from pytorch3d.renderer.cameras import OpenGLOrthographicCameras  # deprecated
+from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras  # deprecated
+from pytorch3d.renderer.cameras import SfMOrthographicCameras  # deprecated
+from pytorch3d.renderer.cameras import SfMPerspectiveCameras  # deprecated
 from pytorch3d.renderer.cameras import (
     CamerasBase,
-    OpenGLOrthographicCameras,
-    OpenGLPerspectiveCameras,
-    SfMOrthographicCameras,
-    SfMPerspectiveCameras,
+    FoVOrthographicCameras,
+    FoVPerspectiveCameras,
+    OrthographicCameras,
+    PerspectiveCameras,
     camera_position_from_spherical_angles,
     get_world_to_view_transform,
     look_at_rotation,
@@ -109,6 +113,25 @@ def orthographic_project_naive(points, scale_xyz=(1.0, 1.0, 1.0)):
     return points
 
 
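+# Editor's note (added comment, not in the original patch): the helper below
+# mirrors the documented NDC -> screen conversion. Under PyTorch3D's convention
+# +X and +Y point left/up in NDC, so both axes flip when mapped to pixels:
+#   x_screen = (1 - x_ndc) * (image_width - 1) / 2
+#   y_screen = (1 - y_ndc) * (image_height - 1) / 2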
+def ndc_to_screen_points_naive(points, imsize):
+    """
+    Transforms points from PyTorch3D's NDC space to screen space
+    Args:
+        points: (N, V, 3) representing padded points
+        imsize: (N, 2) image size = (width, height)
+    Returns:
+        (N, V, 3) tensor of transformed points
+    """
+    imwidth, imheight = imsize.unbind(1)
+    imwidth = imwidth.view(-1, 1)
+    imheight = imheight.view(-1, 1)
+
+    x, y, z = points.unbind(2)
+    x = (1.0 - x) * (imwidth - 1) / 2.0
+    y = (1.0 - y) * (imheight - 1) / 2.0
+    return torch.stack((x, y, z), dim=2)
+
+
 class TestCameraHelpers(TestCaseMixin, unittest.TestCase):
     def setUp(self) -> None:
         super().setUp()
@@ -359,6 +382,10 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
             OpenGLOrthographicCameras,
             SfMOrthographicCameras,
             SfMPerspectiveCameras,
+            FoVOrthographicCameras,
+            FoVPerspectiveCameras,
+            OrthographicCameras,
+            PerspectiveCameras,
         ):
             cam = cam_type(R=R, T=T)
             RT_class = cam.get_world_to_view_transform()
@@ -374,6 +401,10 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
             OpenGLOrthographicCameras,
             SfMOrthographicCameras,
             SfMPerspectiveCameras,
+            FoVOrthographicCameras,
+            FoVPerspectiveCameras,
+            OrthographicCameras,
+            PerspectiveCameras,
         ):
             cam = cam_type(R=R, T=T)
             C = cam.get_camera_center()
@@ -398,13 +429,53 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
             cam_params["bottom"] = -(torch.rand(batch_size)) * 0.2 - 0.9
             cam_params["left"] = -(torch.rand(batch_size)) * 0.2 - 0.9
             cam_params["right"] = torch.rand(batch_size) * 0.2 + 0.9
-        elif cam_type in (SfMOrthographicCameras, SfMPerspectiveCameras):
+        elif cam_type in (FoVPerspectiveCameras, FoVOrthographicCameras):
+            cam_params["znear"] = torch.rand(batch_size) * 10 + 0.1
+            cam_params["zfar"] = torch.rand(batch_size) * 4 + 1 + cam_params["znear"]
+            if cam_type == FoVPerspectiveCameras:
+                cam_params["fov"] = torch.rand(batch_size) * 60 + 30
+                cam_params["aspect_ratio"] = torch.rand(batch_size) * 0.5 + 0.5
+            else:
+                cam_params["max_y"] = torch.rand(batch_size) * 0.2 + 0.9
+                cam_params["min_y"] = -(torch.rand(batch_size)) * 0.2 - 0.9
+                cam_params["min_x"] = -(torch.rand(batch_size)) * 0.2 - 0.9
+                cam_params["max_x"] = torch.rand(batch_size) * 0.2 + 0.9
+        elif cam_type in (
+            SfMOrthographicCameras,
+            SfMPerspectiveCameras,
+            OrthographicCameras,
+            PerspectiveCameras,
+        ):
             cam_params["focal_length"] = torch.rand(batch_size) * 10 + 0.1
             cam_params["principal_point"] = torch.randn((batch_size, 2))
+
         else:
             raise ValueError(str(cam_type))
         return cam_type(**cam_params)
 
+    @staticmethod
+    def init_equiv_cameras_ndc_screen(cam_type: CamerasBase, batch_size: int):
+        T = torch.randn(batch_size, 3) * 0.03
+        T[:, 2] = 4
+        R = so3_exponential_map(torch.randn(batch_size, 3) * 3.0)
+        screen_cam_params = {"R": R, "T": T}
+        ndc_cam_params = {"R": R, "T": T}
+        if cam_type in (OrthographicCameras, PerspectiveCameras):
+            ndc_cam_params["focal_length"] = torch.rand((batch_size, 2)) * 3.0
+            ndc_cam_params["principal_point"] = torch.randn((batch_size, 2))
+
+            image_size = torch.randint(low=2, high=64, size=(batch_size, 2))
+            screen_cam_params["image_size"] = image_size
+            screen_cam_params["focal_length"] = (
+                ndc_cam_params["focal_length"] * image_size / 2.0
+            )
+            screen_cam_params["principal_point"] = (
+                (1.0 - ndc_cam_params["principal_point"]) * image_size / 2.0
+            )
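+            # Editor's note (added comment, not in the original patch): these
+            # screen-space values invert the documented screen -> NDC conversion,
+            # i.e. it is assumed that
+            #   fx_screen = fx_ndc * image_width / 2
+            #   px_screen = (1 - px_ndc) * image_width / 2
+            # (likewise for y/height), making the two cameras equivalent.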
+        else:
+            raise ValueError(str(cam_type))
+        return cam_type(**ndc_cam_params), cam_type(**screen_cam_params)
+
     def test_unproject_points(self, batch_size=50, num_points=100):
         """
         Checks that an unprojection of a randomly projected point cloud
         stays the same.
         """
@@ -416,6 +487,10 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
             OpenGLPerspectiveCameras,
             OpenGLOrthographicCameras,
             SfMPerspectiveCameras,
+            FoVOrthographicCameras,
+            FoVPerspectiveCameras,
+            OrthographicCameras,
+            PerspectiveCameras,
         ):
             # init the cameras
             cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size)
@@ -437,9 +512,14 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
             else:
                 matching_xyz = xyz_cam
 
-            # if we have OpenGL cameras
+            # if we have FoV (= OpenGL) cameras
             # test for scaled_depth_input=True/False
-            if cam_type in (OpenGLPerspectiveCameras, OpenGLOrthographicCameras):
+            if cam_type in (
+                OpenGLPerspectiveCameras,
+                OpenGLOrthographicCameras,
+                FoVPerspectiveCameras,
+                FoVOrthographicCameras,
+            ):
                 for scaled_depth_input in (True, False):
                     if scaled_depth_input:
                         xy_depth_ = xyz_proj
@@ -459,6 +539,56 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
                 )
                 self.assertTrue(torch.allclose(xyz_unproj, matching_xyz, atol=1e-4))
 
+    def test_project_points_screen(self, batch_size=50, num_points=100):
+        """
+        Checks that projecting points to screen space agrees with a naive
+        conversion of the NDC projection to screen coordinates.
+        """
+
+        for cam_type in (
+            OpenGLOrthographicCameras,
+            OpenGLPerspectiveCameras,
+            SfMOrthographicCameras,
+            SfMPerspectiveCameras,
+            FoVOrthographicCameras,
+            FoVPerspectiveCameras,
+            OrthographicCameras,
+            PerspectiveCameras,
+        ):
+
+            # init the cameras
+            cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size)
+            # xyz - the ground truth point cloud
+            xyz = torch.randn(batch_size, num_points, 3) * 0.3
+            # image size
+            image_size = torch.randint(low=2, high=64, size=(batch_size, 2))
+            # project points
+            xyz_project_ndc = cameras.transform_points(xyz)
+            xyz_project_screen = cameras.transform_points_screen(xyz, image_size)
+            # naive
+            xyz_project_screen_naive = ndc_to_screen_points_naive(
+                xyz_project_ndc, image_size
+            )
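+            # Editor's note (added comment, not in the original patch):
+            # transform_points_screen is expected to compose the NDC projection
+            # with the same NDC -> screen mapping as the naive helper above,
+            # so both results should match exactly.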
+            self.assertClose(xyz_project_screen, xyz_project_screen_naive)
+
+    def test_equiv_project_points(self, batch_size=50, num_points=100):
+        """
+        Checks that a camera parametrized in NDC space and its equivalent
+        parametrized in screen space project points to the same NDC coordinates.
+        Applies only to OrthographicCameras and PerspectiveCameras.
+        """
+        for cam_type in (OrthographicCameras, PerspectiveCameras):
+            # init the cameras
+            (
+                ndc_cameras,
+                screen_cameras,
+            ) = TestCamerasCommon.init_equiv_cameras_ndc_screen(cam_type, batch_size)
+            # xyz - the ground truth point cloud
+            xyz = torch.randn(batch_size, num_points, 3) * 0.3
+            # project points
+            xyz_ndc_cam = ndc_cameras.transform_points(xyz)
+            xyz_screen_cam = screen_cameras.transform_points(xyz)
+            self.assertClose(xyz_ndc_cam, xyz_screen_cam, atol=1e-6)
+
     def test_clone(self, batch_size: int = 10):
         """
         Checks the clone function of the cameras.
@@ -468,6 +598,10 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
             OpenGLPerspectiveCameras,
             OpenGLOrthographicCameras,
             SfMPerspectiveCameras,
+            FoVOrthographicCameras,
+            FoVPerspectiveCameras,
+            OrthographicCameras,
+            PerspectiveCameras,
        ):
             cameras = TestCamerasCommon.init_random_cameras(cam_type, batch_size)
             cameras = cameras.to(torch.device("cpu"))
@@ -483,11 +617,16 @@ class TestCamerasCommon(TestCaseMixin, unittest.TestCase):
             self.assertTrue(val == val_clone)
 
 
-class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
+############################################################
+#                 FoVPerspective Camera                    #
+############################################################
+
+
+class TestFoVPerspectiveProjection(TestCaseMixin, unittest.TestCase):
     def test_perspective(self):
         far = 10.0
         near = 1.0
-        cameras = OpenGLPerspectiveCameras(znear=near, zfar=far, fov=60.0)
+        cameras = FoVPerspectiveCameras(znear=near, zfar=far, fov=60.0)
         P = cameras.get_projection_transform()
         # vertices are at the far clipping plane so z gets mapped to 1.
         vertices = torch.tensor([1, 2, far], dtype=torch.float32)
@@ -512,7 +651,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
         self.assertClose(v1.squeeze(), projected_verts)
 
     def test_perspective_kwargs(self):
-        cameras = OpenGLPerspectiveCameras(znear=5.0, zfar=100.0, fov=0.0)
+        cameras = FoVPerspectiveCameras(znear=5.0, zfar=100.0, fov=0.0)
         # Override defaults by passing in values to get_projection_transform
         far = 10.0
         P = cameras.get_projection_transform(znear=1.0, zfar=far, fov=60.0)
@@ -528,7 +667,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
         far = torch.tensor([10.0, 20.0], dtype=torch.float32)
         near = 1.0
         fov = torch.tensor(60.0)
-        cameras = OpenGLPerspectiveCameras(znear=near, zfar=far, fov=fov)
+        cameras = FoVPerspectiveCameras(znear=near, zfar=far, fov=fov)
         P = cameras.get_projection_transform()
         vertices = torch.tensor([1, 2, 10], dtype=torch.float32)
         z1 = 1.0  # vertices at far clipping plane so z = 1.0
@@ -550,7 +689,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
         far = torch.tensor([10.0])
         near = 1.0
         fov = torch.tensor(60.0, requires_grad=True)
-        cameras = OpenGLPerspectiveCameras(znear=near, zfar=far, fov=fov)
+        cameras = FoVPerspectiveCameras(znear=near, zfar=far, fov=fov)
         P = cameras.get_projection_transform()
         vertices = torch.tensor([1, 2, 10], dtype=torch.float32)
         vertices_batch = vertices[None, None, :]
@@ -566,7 +705,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
 
     def test_camera_class_init(self):
         device = torch.device("cuda:0")
-        cam = OpenGLPerspectiveCameras(znear=10.0, zfar=(100.0, 200.0))
+        cam = FoVPerspectiveCameras(znear=10.0, zfar=(100.0, 200.0))
 
         # Check broadcasting
         self.assertTrue(cam.znear.shape == (2,))
@@ -585,7 +724,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
         self.assertTrue(new_cam.device == device)
 
     def test_get_full_transform(self):
-        cam = OpenGLPerspectiveCameras()
+        cam = FoVPerspectiveCameras()
         T = torch.tensor([0.0, 0.0, 1.0]).view(1, -1)
         R = look_at_rotation(T)
         P = cam.get_full_projection_transform(R=R, T=T)
@@ -597,7 +736,7 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
         # Check transform_points methods works with default settings for
         # RT and P
         far = 10.0
-        cam = OpenGLPerspectiveCameras(znear=1.0, zfar=far, fov=60.0)
+        cam = FoVPerspectiveCameras(znear=1.0, zfar=far, fov=60.0)
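+        # Editor's note (added comment, not in the original patch): the points
+        # below lie on the far clipping plane (z == zfar), so the projection
+        # should map their z coordinate to exactly 1.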
         points = torch.tensor([1, 2, far], dtype=torch.float32)
         points = points.view(1, 1, 3).expand(5, 10, -1)
         projected_points = torch.tensor(
@@ -608,11 +747,16 @@ class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
         self.assertClose(new_points, projected_points)
 
 
-class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
+############################################################
+#                FoVOrthographic Camera                    #
+############################################################
+
+
+class TestFoVOrthographicProjection(TestCaseMixin, unittest.TestCase):
     def test_orthographic(self):
         far = 10.0
         near = 1.0
-        cameras = OpenGLOrthographicCameras(znear=near, zfar=far)
+        cameras = FoVOrthographicCameras(znear=near, zfar=far)
         P = cameras.get_projection_transform()
 
         vertices = torch.tensor([1, 2, far], dtype=torch.float32)
@@ -637,7 +781,7 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
         # applying the scale puts the z coordinate at the far clipping plane
         # so the z is mapped to 1.0
         projected_verts = torch.tensor([2, 1, 1], dtype=torch.float32)
-        cameras = OpenGLOrthographicCameras(znear=1.0, zfar=10.0, scale_xyz=scale)
+        cameras = FoVOrthographicCameras(znear=1.0, zfar=10.0, scale_xyz=scale)
         P = cameras.get_projection_transform()
         v1 = P.transform_points(vertices)
         v2 = orthographic_project_naive(vertices, scale)
@@ -645,7 +789,7 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
         self.assertClose(v1, projected_verts[None, None])
 
     def test_orthographic_kwargs(self):
-        cameras = OpenGLOrthographicCameras(znear=5.0, zfar=100.0)
+        cameras = FoVOrthographicCameras(znear=5.0, zfar=100.0)
         far = 10.0
         P = cameras.get_projection_transform(znear=1.0, zfar=far)
         vertices = torch.tensor([1, 2, far], dtype=torch.float32)
@@ -657,7 +801,7 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
     def test_orthographic_mixed_inputs_broadcast(self):
         far = torch.tensor([10.0, 20.0])
         near = 1.0
-        cameras = OpenGLOrthographicCameras(znear=near, zfar=far)
+        cameras = FoVOrthographicCameras(znear=near, zfar=far)
         P = cameras.get_projection_transform()
         vertices = torch.tensor([1.0, 2.0, 10.0], dtype=torch.float32)
         z2 = 1.0 / (20.0 - 1.0) * 10.0 + -1.0 / (20.0 - 1.0)
@@ -674,7 +818,7 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
         far = torch.tensor([10.0])
         near = 1.0
         scale = torch.tensor([[1.0, 1.0, 1.0]], requires_grad=True)
-        cameras = OpenGLOrthographicCameras(znear=near, zfar=far, scale_xyz=scale)
+        cameras = FoVOrthographicCameras(znear=near, zfar=far, scale_xyz=scale)
         P = cameras.get_projection_transform()
         vertices = torch.tensor([1.0, 2.0, 10.0], dtype=torch.float32)
         vertices_batch = vertices[None, None, :]
@@ -694,9 +838,14 @@ class TestOpenGLOrthographicProjection(TestCaseMixin, unittest.TestCase):
         self.assertClose(scale_grad, grad_scale)
 
 
-class TestSfMOrthographicProjection(TestCaseMixin, unittest.TestCase):
+############################################################
+#                 Orthographic Camera                      #
+############################################################
+
+
+class TestOrthographicProjection(TestCaseMixin, unittest.TestCase):
     def test_orthographic(self):
-        cameras = SfMOrthographicCameras()
+        cameras = OrthographicCameras()
         P = cameras.get_projection_transform()
 
         vertices = torch.randn([3, 4, 3], dtype=torch.float32)
@@ -711,9 +860,7 @@ class TestSfMOrthographicProjection(TestCaseMixin, unittest.TestCase):
         focal_length_x = 10.0
         focal_length_y = 15.0
 
-        cameras = SfMOrthographicCameras(
-            focal_length=((focal_length_x, focal_length_y),)
-        )
+        cameras = OrthographicCameras(focal_length=((focal_length_x, focal_length_y),))
         P = cameras.get_projection_transform()
 
         vertices = torch.randn([3, 4, 3], dtype=torch.float32)
@@ -730,9 +877,7 @@ class TestSfMOrthographicProjection(TestCaseMixin, unittest.TestCase):
         self.assertClose(v1, projected_verts)
 
     def test_orthographic_kwargs(self):
-        cameras = SfMOrthographicCameras(
-            focal_length=5.0, principal_point=((2.5, 2.5),)
-        )
+        cameras = OrthographicCameras(focal_length=5.0, principal_point=((2.5, 2.5),))
         P = cameras.get_projection_transform(
             focal_length=2.0, principal_point=((2.5, 3.5),)
         )
@@ -745,9 +890,14 @@ class TestSfMOrthographicProjection(TestCaseMixin, unittest.TestCase):
         self.assertClose(v1, projected_verts)
 
 
-class TestSfMPerspectiveProjection(TestCaseMixin, unittest.TestCase):
+############################################################
+#                  Perspective Camera                      #
+############################################################
+
+
+class TestPerspectiveProjection(TestCaseMixin, unittest.TestCase):
     def test_perspective(self):
-        cameras = SfMPerspectiveCameras()
+        cameras = PerspectiveCameras()
         P = cameras.get_projection_transform()
 
         vertices = torch.randn([3, 4, 3], dtype=torch.float32)
@@ -761,7 +911,7 @@ class TestSfMPerspectiveProjection(TestCaseMixin, unittest.TestCase):
         p0x = 15.0
         p0y = 30.0
 
-        cameras = SfMPerspectiveCameras(
+        cameras = PerspectiveCameras(
             focal_length=((focal_length_x, focal_length_y),),
             principal_point=((p0x, p0y),),
         )
@@ -777,7 +927,7 @@ class TestSfMPerspectiveProjection(TestCaseMixin, unittest.TestCase):
         self.assertClose(v3[..., :2], v2[..., :2])
 
     def test_perspective_kwargs(self):
-        cameras = SfMPerspectiveCameras(focal_length=5.0, principal_point=((2.5, 2.5),))
+        cameras = PerspectiveCameras(focal_length=5.0, principal_point=((2.5, 2.5),))
         P = cameras.get_projection_transform(
             focal_length=2.0, principal_point=((2.5, 3.5),)
         )
diff --git a/tests/test_r2n2.py b/tests/test_r2n2.py
index 0765e699..921f9b11 100644
--- a/tests/test_r2n2.py
+++ b/tests/test_r2n2.py
@@ -18,7 +18,7 @@ from pytorch3d.datasets import (
     render_cubified_voxels,
 )
 from pytorch3d.renderer import (
-    OpenGLPerspectiveCameras,
+    FoVPerspectiveCameras,
     PointLights,
     RasterizationSettings,
     look_at_view_transform,
@@ -211,7 +211,7 @@ class TestR2N2(TestCaseMixin, unittest.TestCase):
         # Render first three models in the dataset.
         R, T = look_at_view_transform(1.0, 1.0, 90)
-        cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
+        cameras = FoVPerspectiveCameras(R=R, T=T, device=device)
         raster_settings = RasterizationSettings(image_size=512)
         lights = PointLights(
             location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
diff --git a/tests/test_rasterizer.py b/tests/test_rasterizer.py
index da1c95e4..df13c317 100644
--- a/tests/test_rasterizer.py
+++ b/tests/test_rasterizer.py
@@ -7,7 +7,7 @@ from pathlib import Path
 import numpy as np
 import torch
 from PIL import Image
-from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
+from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
 from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
 from pytorch3d.renderer.points.rasterizer import (
     PointsRasterizationSettings,
@@ -43,7 +43,7 @@ class TestMeshRasterizer(unittest.TestCase):
 
         # Init rasterizer settings
         R, T = look_at_view_transform(2.7, 0, 0)
-        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
         raster_settings = RasterizationSettings(
             image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
         )
@@ -148,7 +148,7 @@ class TestPointRasterizer(unittest.TestCase):
         verts_padded[..., 0] += 0.2
         pointclouds = Pointclouds(points=verts_padded)
         R, T = look_at_view_transform(2.7, 0.0, 0.0)
-        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
         raster_settings = PointsRasterizationSettings(
             image_size=256, radius=5e-2, points_per_pixel=1
         )
diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py
index be5b9ec6..94e49862 100644
--- a/tests/test_render_meshes.py
+++ b/tests/test_render_meshes.py
@@ -4,6 +4,7 @@
 """
 Sanity checks for output images from the renderer.
 """
+import os
 import unittest
 from pathlib import Path
 
@@ -12,7 +13,13 @@ import torch
 from common_testing import TestCaseMixin, load_rgb_image
 from PIL import Image
 from pytorch3d.io import load_obj
-from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
+from pytorch3d.renderer.cameras import (
+    FoVOrthographicCameras,
+    FoVPerspectiveCameras,
+    OrthographicCameras,
+    PerspectiveCameras,
+    look_at_view_transform,
+)
 from pytorch3d.renderer.lighting import PointLights
 from pytorch3d.renderer.materials import Materials
 from pytorch3d.renderer.mesh import TexturesAtlas, TexturesUV, TexturesVertex
@@ -60,78 +67,94 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         if elevated_camera:
             # Elevated and rotated camera
             R, T = look_at_view_transform(dist=2.7, elev=45.0, azim=45.0)
-            postfix = "_elevated_camera"
+            postfix = "_elevated_"
             # If y axis is up, the spot of light should
             # be on the bottom left of the sphere.
         else:
             # No elevation or azimuth rotation
             R, T = look_at_view_transform(2.7, 0.0, 0.0)
-            postfix = ""
-        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
+            postfix = "_"
+        for cam_type in (
+            FoVPerspectiveCameras,
+            FoVOrthographicCameras,
+            PerspectiveCameras,
+            OrthographicCameras,
+        ):
+            cameras = cam_type(device=device, R=R, T=T)
 
-        # Init shader settings
-        materials = Materials(device=device)
-        lights = PointLights(device=device)
-        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
+            # Init shader settings
+            materials = Materials(device=device)
+            lights = PointLights(device=device)
+            lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
 
-        raster_settings = RasterizationSettings(
-            image_size=512, blur_radius=0.0, faces_per_pixel=1
-        )
-        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
-        blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
+            raster_settings = RasterizationSettings(
+                image_size=512, blur_radius=0.0, faces_per_pixel=1
+            )
+            rasterizer = MeshRasterizer(
+                cameras=cameras, raster_settings=raster_settings
+            )
+            blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
 
-        # Test several shaders
-        shaders = {
-            "phong": HardPhongShader,
-            "gouraud": HardGouraudShader,
-            "flat": HardFlatShader,
-        }
-        for (name, shader_init) in shaders.items():
-            shader = shader_init(
+            # Test several shaders
+            shaders = {
+                "phong": HardPhongShader,
+                "gouraud": HardGouraudShader,
+                "flat": HardFlatShader,
+            }
+            for (name, shader_init) in shaders.items():
+                shader = shader_init(
+                    lights=lights,
+                    cameras=cameras,
+                    materials=materials,
+                    blend_params=blend_params,
+                )
+                renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
+                images = renderer(sphere_mesh)
+                rgb = images[0, ..., :3].squeeze().cpu()
+                filename = "simple_sphere_light_%s%s%s.png" % (
+                    name,
+                    postfix,
+                    cam_type.__name__,
+                )
+
+                image_ref = load_rgb_image("test_%s" % filename, DATA_DIR)
+                self.assertClose(rgb, image_ref, atol=0.05)
+
+                if DEBUG:
+                    filename = "DEBUG_%s" % filename
+                    Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
+                        DATA_DIR / filename
+                    )
+
+            ########################################################
+            # Move the light to the +z axis in world space so it is
+            # behind the sphere. Note that +Z is in, +Y up,
+            # +X left for both world and camera space.
-        ########################################################
-        lights.location[..., 2] = -2.0
-        phong_shader = HardPhongShader(
-            lights=lights,
-            cameras=cameras,
-            materials=materials,
-            blend_params=blend_params,
-        )
-        phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
-        images = phong_renderer(sphere_mesh, lights=lights)
-        rgb = images[0, ..., :3].squeeze().cpu()
-        if DEBUG:
-            filename = "DEBUG_simple_sphere_dark%s.png" % postfix
-            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
-                DATA_DIR / filename
+            ########################################################
+            lights.location[..., 2] = -2.0
+            phong_shader = HardPhongShader(
+                lights=lights,
+                cameras=cameras,
+                materials=materials,
+                blend_params=blend_params,
+            )
+            phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
+            images = phong_renderer(sphere_mesh, lights=lights)
+            rgb = images[0, ..., :3].squeeze().cpu()
+            if DEBUG:
+                filename = "DEBUG_simple_sphere_dark%s%s.png" % (
+                    postfix,
+                    cam_type.__name__,
+                )
+                Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
+                    DATA_DIR / filename
+                )
+            image_ref_phong_dark = load_rgb_image(
+                "test_simple_sphere_dark%s%s.png" % (postfix, cam_type.__name__),
+                DATA_DIR,
             )
-
-        # Load reference image
-        image_ref_phong_dark = load_rgb_image(
-            "test_simple_sphere_dark%s.png" % postfix, DATA_DIR
-        )
-        self.assertClose(rgb, image_ref_phong_dark, atol=0.05)
+            self.assertClose(rgb, image_ref_phong_dark, atol=0.05)
 
     def test_simple_sphere_elevated_camera(self):
         """
@@ -142,6 +165,60 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         """
         self.test_simple_sphere(elevated_camera=True)
 
+    def test_simple_sphere_screen(self):
+        """
+        Test output when rendering with PerspectiveCameras & OrthographicCameras
+        in NDC vs screen space.
+        """
+        device = torch.device("cuda:0")
+
+        # Init mesh
+        sphere_mesh = ico_sphere(5, device)
+        verts_padded = sphere_mesh.verts_padded()
+        faces_padded = sphere_mesh.faces_padded()
+        feats = torch.ones_like(verts_padded, device=device)
+        textures = TexturesVertex(verts_features=feats)
+        sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
+
+        R, T = look_at_view_transform(2.7, 0.0, 0.0)
+
+        # Init shader settings
+        materials = Materials(device=device)
+        lights = PointLights(device=device)
+        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
+
+        raster_settings = RasterizationSettings(
+            image_size=512, blur_radius=0.0, faces_per_pixel=1
+        )
+        for cam_type in (PerspectiveCameras, OrthographicCameras):
+            cameras = cam_type(
+                device=device,
+                R=R,
+                T=T,
+                principal_point=((256.0, 256.0),),
+                focal_length=((256.0, 256.0),),
+                image_size=((512, 512),),
+            )
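+            # Editor's note (added comment, not in the original patch): for a
+            # 512x512 image these screen-space values (focal length 256,
+            # principal point (256, 256)) correspond to focal length 1.0 and
+            # principal point (0, 0) in NDC, so this should reproduce the
+            # NDC-parametrized phong renders checked above.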
+            rasterizer = MeshRasterizer(
+                cameras=cameras, raster_settings=raster_settings
+            )
+            blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
+
+            shader = HardPhongShader(
+                lights=lights,
+                cameras=cameras,
+                materials=materials,
+                blend_params=blend_params,
+            )
+            renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
+            images = renderer(sphere_mesh)
+            rgb = images[0, ..., :3].squeeze().cpu()
+            filename = "test_simple_sphere_light_phong_%s.png" % cam_type.__name__
+
+            image_ref = load_rgb_image(filename, DATA_DIR)
+            self.assertClose(rgb, image_ref, atol=0.05)
+
     def test_simple_sphere_batched(self):
         """
         Test a mesh with vertex textures can be extended to form a batch, and
@@ -165,7 +242,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         elev = torch.zeros_like(dist)
         azim = torch.zeros_like(dist)
         R, T = look_at_view_transform(dist, elev, azim)
-        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
         raster_settings = RasterizationSettings(
             image_size=512, blur_radius=0.0, faces_per_pixel=1
         )
@@ -193,12 +270,16 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
             renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
             images = renderer(sphere_meshes)
             image_ref = load_rgb_image(
-                "test_simple_sphere_light_%s.png" % name, DATA_DIR
+                "test_simple_sphere_light_%s_%s.png" % (name, type(cameras).__name__),
+                DATA_DIR,
             )
             for i in range(batch_size):
                 rgb = images[i, ..., :3].squeeze().cpu()
                 if i == 0 and DEBUG:
-                    filename = "DEBUG_simple_sphere_batched_%s.png" % name
+                    filename = "DEBUG_simple_sphere_batched_%s_%s.png" % (
+                        name,
+                        type(cameras).__name__,
+                    )
                     Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
                         DATA_DIR / filename
                     )
@@ -209,8 +290,6 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         Test silhouette blending. Also check that gradient calculation works.
         """
         device = torch.device("cuda:0")
-        ref_filename = "test_silhouette.png"
-        image_ref_filename = DATA_DIR / ref_filename
         sphere_mesh = ico_sphere(5, device)
         verts, faces = sphere_mesh.get_mesh_verts_faces(0)
         sphere_mesh = Meshes(verts=[verts], faces=[faces])
@@ -225,32 +304,45 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 
         # Init rasterizer settings
         R, T = look_at_view_transform(2.7, 0, 0)
-        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
+        for cam_type in (
+            FoVPerspectiveCameras,
+            FoVOrthographicCameras,
+            PerspectiveCameras,
+            OrthographicCameras,
+        ):
+            cameras = cam_type(device=device, R=R, T=T)
 
-        # Init renderer
-        renderer = MeshRenderer(
-            rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
-            shader=SoftSilhouetteShader(blend_params=blend_params),
-        )
-        images = renderer(sphere_mesh)
-        alpha = images[0, ..., 3].squeeze().cpu()
-        if DEBUG:
-            Image.fromarray((alpha.numpy() * 255).astype(np.uint8)).save(
-                DATA_DIR / "DEBUG_silhouette.png"
+            # Init renderer
+            renderer = MeshRenderer(
+                rasterizer=MeshRasterizer(
+                    cameras=cameras, raster_settings=raster_settings
+                ),
+                shader=SoftSilhouetteShader(blend_params=blend_params),
             )
+            images = renderer(sphere_mesh)
+            alpha = images[0, ..., 3].squeeze().cpu()
+            if DEBUG:
+                filename = os.path.join(
+                    DATA_DIR, "DEBUG_%s_silhouette.png" % (cam_type.__name__)
+                )
+                Image.fromarray((alpha.detach().numpy() * 255).astype(np.uint8)).save(
+                    filename
+                )
 
-        with Image.open(image_ref_filename) as raw_image_ref:
-            image_ref = torch.from_numpy(np.array(raw_image_ref))
+            ref_filename = "test_%s_silhouette.png" % (cam_type.__name__)
+            image_ref_filename = DATA_DIR / ref_filename
+            with Image.open(image_ref_filename) as raw_image_ref:
+                image_ref = torch.from_numpy(np.array(raw_image_ref))
 
-        image_ref = image_ref.to(dtype=torch.float32) / 255.0
-        self.assertClose(alpha, image_ref, atol=0.055)
+            image_ref = image_ref.to(dtype=torch.float32) / 255.0
+            self.assertClose(alpha, image_ref, atol=0.055)
 
-        # Check grad exist
-        verts.requires_grad = True
-        sphere_mesh = Meshes(verts=[verts], faces=[faces])
-        images = renderer(sphere_mesh)
-        images[0, ...].sum().backward()
-        self.assertIsNotNone(verts.grad)
+            # Check grad exist
+            verts.requires_grad = True
+            sphere_mesh = Meshes(verts=[verts], faces=[faces])
+            images = renderer(sphere_mesh)
+            images[0, ...].sum().backward()
+            self.assertIsNotNone(verts.grad)
 
     def test_texture_map(self):
         """
@@ -274,7 +366,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 
         # Init rasterizer settings
         R, T = look_at_view_transform(2.7, 0, 0)
-        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
 
         raster_settings = RasterizationSettings(
             image_size=512, blur_radius=0.0, faces_per_pixel=1
@@ -337,7 +429,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         ##########################################
 
         R, T = look_at_view_transform(2.7, 0, 180)
-        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
 
         # Move light to the front of the cow in world space
         lights.location = torch.tensor([0.0, 0.0, -2.0], device=device)[None]
@@ -367,7 +459,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         # Add blurring to rasterization
         #################################
         R, T = look_at_view_transform(2.7, 0, 180)
-        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
         blend_params = BlendParams(sigma=5e-4, gamma=1e-4)
         raster_settings = RasterizationSettings(
             image_size=512,
@@ -429,7 +521,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 
         # Init rasterizer settings
         R, T = look_at_view_transform(2.7, 0.0, 0.0)
-        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
         raster_settings = RasterizationSettings(
             image_size=512, blur_radius=0.0, faces_per_pixel=1
         )
@@ -490,7 +582,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
 
         # Init rasterizer settings
         R, T = look_at_view_transform(2.7, 0, 0)
-        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
 
         raster_settings = RasterizationSettings(
             image_size=512, blur_radius=0.0, faces_per_pixel=1, cull_backfaces=True
diff --git a/tests/test_render_points.py b/tests/test_render_points.py
index 0220fa73..c5820267 100644
--- a/tests/test_render_points.py
+++ b/tests/test_render_points.py
@@ -14,8 +14,8 @@ import torch
 from common_testing import TestCaseMixin, load_rgb_image
 from PIL import Image
 from pytorch3d.renderer.cameras import (
-    OpenGLOrthographicCameras,
-    OpenGLPerspectiveCameras,
+    FoVOrthographicCameras,
+    FoVPerspectiveCameras,
     look_at_view_transform,
 )
 from pytorch3d.renderer.points import (
@@ -47,7 +47,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
             points=verts_padded, features=torch.ones_like(verts_padded)
         )
         R, T = look_at_view_transform(2.7, 0.0, 0.0)
-        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
         raster_settings = PointsRasterizationSettings(
             image_size=256, radius=5e-2, points_per_pixel=1
         )
@@ -97,7 +97,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
         point_cloud = Pointclouds(points=[verts], features=[rgb_feats])
 
         R, T = look_at_view_transform(20, 10, 0)
-        cameras = OpenGLOrthographicCameras(device=device, R=R, T=T, znear=0.01)
+        cameras = FoVOrthographicCameras(device=device, R=R, T=T, znear=0.01)
 
         raster_settings = PointsRasterizationSettings(
             # Set image_size so it is not a multiple of 16 (min bin_size)
@@ -150,7 +150,7 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
         batch_size = 20
         pointclouds = pointclouds.extend(batch_size)
         R, T = look_at_view_transform(2.7, 0.0, 0.0)
-        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
+        cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
         raster_settings = PointsRasterizationSettings(
             image_size=256, radius=5e-2, points_per_pixel=1
         )
diff --git a/tests/test_shapenet_core.py b/tests/test_shapenet_core.py
index 4b225674..b974a700 100644
--- a/tests/test_shapenet_core.py
+++ b/tests/test_shapenet_core.py
@@ -12,7 +12,7 @@ from common_testing import TestCaseMixin, load_rgb_image
 from PIL import Image
 from pytorch3d.datasets import ShapeNetCore, collate_batched_meshes
 from pytorch3d.renderer import (
-    OpenGLPerspectiveCameras,
+    FoVPerspectiveCameras,
     PointLights,
     RasterizationSettings,
     look_at_view_transform,
@@ -174,7 +174,7 @@ class TestShapenetCore(TestCaseMixin, unittest.TestCase):
 
         # Rendering settings.
         R, T = look_at_view_transform(1.0, 1.0, 90)
-        cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
+        cameras = FoVPerspectiveCameras(R=R, T=T, device=device)
         raster_settings = RasterizationSettings(image_size=512)
         lights = PointLights(
             location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],