diff --git a/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb b/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb index a7b60921..dbe06830 100644 --- a/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb +++ b/docs/tutorials/camera_position_optimization_with_differentiable_rendering.ipynb @@ -79,6 +79,7 @@ "metadata": { "colab": {}, "colab_type": "code", + "collapsed": true, "id": "w9mH5iVprQdZ" }, "outputs": [], @@ -260,7 +261,7 @@ " cameras=cameras, \n", " raster_settings=raster_settings\n", " ),\n", - " shader=HardPhongShader(device=device, lights=lights)\n", + " shader=HardPhongShader(device=device, cameras=cameras, lights=lights)\n", ")" ] }, diff --git a/pytorch3d/renderer/mesh/rasterizer.py b/pytorch3d/renderer/mesh/rasterizer.py index a5c9a2e8..1a4d4abc 100644 --- a/pytorch3d/renderer/mesh/rasterizer.py +++ b/pytorch3d/renderer/mesh/rasterizer.py @@ -54,7 +54,7 @@ class MeshRasterizer(nn.Module): Meshes. """ - def __init__(self, cameras, raster_settings=None): + def __init__(self, cameras=None, raster_settings=None): """ Args: cameras: A cameras object which has a `transform_points` method @@ -88,6 +88,11 @@ class MeshRasterizer(nn.Module): be moved into forward. """ cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of MeshRasterizer" + raise ValueError(msg) + verts_world = meshes_world.verts_padded() verts_world_packed = meshes_world.verts_packed() verts_screen = cameras.transform_points(verts_world, **kwargs) diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index 5ac34420..9deb2e3b 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -10,7 +10,6 @@ from ..blending import ( sigmoid_alpha_blend, softmax_rgb_blend, ) -from ..cameras import OpenGLPerspectiveCameras from ..lighting import PointLights from ..materials import Materials from .shading import flat_shading, gouraud_shading, phong_shading @@ -46,13 +45,16 @@ class HardPhongShader(nn.Module): self.materials = ( materials if materials is not None else Materials(device=device) ) - self.cameras = ( - cameras if cameras is not None else OpenGLPerspectiveCameras(device=device) - ) + self.cameras = cameras def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: - texels = interpolate_vertex_colors(fragments, meshes) cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of HardPhongShader" + raise ValueError(msg) + + texels = interpolate_vertex_colors(fragments, meshes) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) colors = phong_shading( @@ -89,14 +91,16 @@ class SoftPhongShader(nn.Module): self.materials = ( materials if materials is not None else Materials(device=device) ) - self.cameras = ( - cameras if cameras is not None else OpenGLPerspectiveCameras(device=device) - ) + self.cameras = cameras self.blend_params = blend_params if blend_params is not None else BlendParams() def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: - texels = interpolate_vertex_colors(fragments, meshes) cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of SoftPhongShader" + raise ValueError(msg) + texels = 
interpolate_vertex_colors(fragments, meshes) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) colors = phong_shading( @@ -132,12 +136,14 @@ class HardGouraudShader(nn.Module): self.materials = ( materials if materials is not None else Materials(device=device) ) - self.cameras = ( - cameras if cameras is not None else OpenGLPerspectiveCameras(device=device) - ) + self.cameras = cameras def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of HardGouraudShader" + raise ValueError(msg) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) pixel_colors = gouraud_shading( @@ -174,13 +180,15 @@ class SoftGouraudShader(nn.Module): self.materials = ( materials if materials is not None else Materials(device=device) ) - self.cameras = ( - cameras if cameras is not None else OpenGLPerspectiveCameras(device=device) - ) + self.cameras = cameras self.blend_params = blend_params if blend_params is not None else BlendParams() def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of SoftGouraudShader" + raise ValueError(msg) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) pixel_colors = gouraud_shading( @@ -219,14 +227,16 @@ class TexturedSoftPhongShader(nn.Module): self.materials = ( materials if materials is not None else Materials(device=device) ) - self.cameras = ( - cameras if cameras is not None else OpenGLPerspectiveCameras(device=device) - ) + self.cameras = cameras self.blend_params = blend_params if blend_params is not None else BlendParams() def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: - texels = interpolate_texture_map(fragments, meshes) cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of TexturedSoftPhongShader" + raise ValueError(msg) + texels = interpolate_texture_map(fragments, meshes) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) blend_params = kwargs.get("blend_params", self.blend_params) @@ -262,13 +272,15 @@ class HardFlatShader(nn.Module): self.materials = ( materials if materials is not None else Materials(device=device) ) - self.cameras = ( - cameras if cameras is not None else OpenGLPerspectiveCameras(device=device) - ) + self.cameras = cameras def forward(self, fragments, meshes, **kwargs) -> torch.Tensor: - texels = interpolate_vertex_colors(fragments, meshes) cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of HardFlatShader" + raise ValueError(msg) + texels = interpolate_vertex_colors(fragments, meshes) lights = kwargs.get("lights", self.lights) materials = kwargs.get("materials", self.materials) colors = flat_shading( diff --git a/pytorch3d/renderer/points/rasterizer.py b/pytorch3d/renderer/points/rasterizer.py index a732eccb..9b9e8bbf 100644 --- a/pytorch3d/renderer/points/rasterizer.py +++ b/pytorch3d/renderer/points/rasterizer.py @@ -48,7 +48,7 @@ class PointsRasterizer(nn.Module): This class implements methods for rasterizing a batch of pointclouds.
""" - def __init__(self, cameras, raster_settings=None): + def __init__(self, cameras=None, raster_settings=None): """ cameras: A cameras object which has a `transform_points` method which returns the transformed points after applying the @@ -80,6 +80,10 @@ class PointsRasterizer(nn.Module): be moved into forward. """ cameras = kwargs.get("cameras", self.cameras) + if cameras is None: + msg = "Cameras must be specified either at initialization \ + or in the forward pass of PointsRasterizer" + raise ValueError(msg) pts_world = point_clouds.points_padded() pts_world_packed = point_clouds.points_packed() diff --git a/tests/test_rasterizer.py b/tests/test_rasterizer.py index 9d82e0c9..b65147a5 100644 --- a/tests/test_rasterizer.py +++ b/tests/test_rasterizer.py @@ -9,6 +9,11 @@ import torch from PIL import Image from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings +from pytorch3d.renderer.points.rasterizer import ( + PointsRasterizationSettings, + PointsRasterizer, +) +from pytorch3d.structures import Pointclouds from pytorch3d.utils.ico_sphere import ico_sphere @@ -99,3 +104,101 @@ class TestMeshRasterizer(unittest.TestCase): DATA_DIR / "DEBUG_test_rasterized_sphere_zoom.png" ) self.assertTrue(torch.allclose(image, image_ref)) + + ################################# + # 4. Test init without cameras. + ################################## + + # Create a new empty rasterizer: + rasterizer = MeshRasterizer() + + # Check that omitting the cameras in both initialization + # and the forward pass throws an error: + with self.assertRaisesRegex(ValueError, "Cameras must be specified"): + rasterizer(sphere_mesh) + + # Now pass in the cameras as a kwarg + fragments = rasterizer( + sphere_mesh, cameras=cameras, raster_settings=raster_settings + ) + image = fragments.pix_to_face[0, ..., 0].squeeze().cpu() + # Convert pix_to_face to a binary mask + image[image >= 0] = 1.0 + image[image < 0] = 0.0 + + if DEBUG: + Image.fromarray((image.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / "DEBUG_test_rasterized_sphere.png" + ) + + self.assertTrue(torch.allclose(image, image_ref)) + + +class TestPointRasterizer(unittest.TestCase): + def test_simple_sphere(self): + device = torch.device("cuda:0") + + # Load reference image + ref_filename = "test_simple_pointcloud_sphere.png" + image_ref_filename = DATA_DIR / ref_filename + + # Rescale image_ref to the 0 - 1 range and convert to a binary mask. + image_ref = convert_image_to_binary_mask(image_ref_filename).to(torch.int32) + + sphere_mesh = ico_sphere(1, device) + verts_padded = sphere_mesh.verts_padded() + verts_padded[..., 1] += 0.2 + verts_padded[..., 0] += 0.2 + pointclouds = Pointclouds(points=verts_padded) + R, T = look_at_view_transform(2.7, 0.0, 0.0) + cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T) + raster_settings = PointsRasterizationSettings( + image_size=256, radius=5e-2, points_per_pixel=1 + ) + + ################################# + # 1. Test init without cameras. + ################################## + + # Initialize without passing in the cameras + rasterizer = PointsRasterizer() + + # Check that omitting the cameras in both initialization + # and the forward pass throws an error: + with self.assertRaisesRegex(ValueError, "Cameras must be specified"): + rasterizer(pointclouds) + + ########################################## + # 2. 
Test rasterizing a single pointcloud + ########################################## + + fragments = rasterizer( + pointclouds, cameras=cameras, raster_settings=raster_settings + ) + + # Convert idx to a binary mask + image = fragments.idx[0, ..., 0].squeeze().cpu() + image[image >= 0] = 1.0 + image[image < 0] = 0.0 + + if DEBUG: + Image.fromarray((image.numpy() * 255).astype(np.uint8)).save( + DATA_DIR / "DEBUG_test_rasterized_sphere_points.png" + ) + + self.assertTrue(torch.allclose(image, image_ref[..., 0])) + + ######################################## + # 3. Test with a batch of pointclouds + ######################################## + + batch_size = 10 + pointclouds = pointclouds.extend(batch_size) + fragments = rasterizer( + pointclouds, cameras=cameras, raster_settings=raster_settings + ) + for i in range(batch_size): + image = fragments.idx[i, ..., 0].squeeze().cpu() + image[image >= 0] = 1.0 + image[image < 0] = 0.0 + self.assertTrue(torch.allclose(image, image_ref[..., 0]))
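For reviewers, a minimal usage sketch (not part of the diff) of the calling convention this change enables: cameras may now be omitted when constructing `MeshRasterizer` (and likewise `PointsRasterizer` and the shaders) and supplied via a `cameras=` kwarg in the forward pass instead. The camera placement values and mesh below are illustrative only; the GPU device mirrors what the tests use.

```python
import torch
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
from pytorch3d.utils.ico_sphere import ico_sphere

# Mirror the tests and run on the GPU.
device = torch.device("cuda:0")
mesh = ico_sphere(2, device)

# Construct the rasterizer without cameras; only the raster settings are fixed up front.
rasterizer = MeshRasterizer(raster_settings=RasterizationSettings(image_size=256))

# Supply the cameras per call instead. A different camera can be used on every
# forward pass without rebuilding the rasterizer.
R, T = look_at_view_transform(dist=2.7, elev=10.0, azim=20.0)
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
fragments = rasterizer(mesh, cameras=cameras)
print(fragments.pix_to_face.shape)  # (1, 256, 256, faces_per_pixel)

# Omitting the cameras both at initialization and in the forward pass
# now raises the new ValueError instead of silently using a default camera.
try:
    MeshRasterizer()(mesh)
except ValueError as err:
    print(err)  # "Cameras must be specified either at initialization ..."
```

The same pattern applies to the shaders: the tutorial cell updated above now passes `cameras=cameras` to `HardPhongShader` at construction, but the cameras could equally be provided as a kwarg when the shader is called.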