mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	allow cameras to be None in rasterizer initialization
Summary: Fix to enable a mesh/point rasterizer to be initialized without having to specify the camera. Reviewed By: jcjohnson, gkioxari Differential Revision: D21362359 fbshipit-source-id: 4f84ea18ad9f179c7b7c2289ebf9422a2f5e26de
This commit is contained in:
		
							parent
							
								
									9c5ab57156
								
							
						
					
					
						commit
						17ca6ecd81
					
				@ -79,6 +79,7 @@
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "colab": {},
 | 
			
		||||
    "colab_type": "code",
 | 
			
		||||
    "collapsed": true,
 | 
			
		||||
    "id": "w9mH5iVprQdZ"
 | 
			
		||||
   },
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
@ -260,7 +261,7 @@
 | 
			
		||||
    "        cameras=cameras, \n",
 | 
			
		||||
    "        raster_settings=raster_settings\n",
 | 
			
		||||
    "    ),\n",
 | 
			
		||||
    "    shader=HardPhongShader(device=device, lights=lights)\n",
 | 
			
		||||
    "    shader=HardPhongShader(device=device, cameras=cameras, lights=lights)\n",
 | 
			
		||||
    ")"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
 | 
			
		||||
@ -54,7 +54,7 @@ class MeshRasterizer(nn.Module):
 | 
			
		||||
    Meshes.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, cameras, raster_settings=None):
 | 
			
		||||
    def __init__(self, cameras=None, raster_settings=None):
 | 
			
		||||
        """
 | 
			
		||||
        Args:
 | 
			
		||||
            cameras: A cameras object which has a  `transform_points` method
 | 
			
		||||
@ -88,6 +88,11 @@ class MeshRasterizer(nn.Module):
 | 
			
		||||
        be moved into forward.
 | 
			
		||||
        """
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of MeshRasterizer"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
        verts_world = meshes_world.verts_padded()
 | 
			
		||||
        verts_world_packed = meshes_world.verts_packed()
 | 
			
		||||
        verts_screen = cameras.transform_points(verts_world, **kwargs)
 | 
			
		||||
 | 
			
		||||
@ -10,7 +10,6 @@ from ..blending import (
 | 
			
		||||
    sigmoid_alpha_blend,
 | 
			
		||||
    softmax_rgb_blend,
 | 
			
		||||
)
 | 
			
		||||
from ..cameras import OpenGLPerspectiveCameras
 | 
			
		||||
from ..lighting import PointLights
 | 
			
		||||
from ..materials import Materials
 | 
			
		||||
from .shading import flat_shading, gouraud_shading, phong_shading
 | 
			
		||||
@ -46,13 +45,16 @@ class HardPhongShader(nn.Module):
 | 
			
		||||
        self.materials = (
 | 
			
		||||
            materials if materials is not None else Materials(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = (
 | 
			
		||||
            cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = cameras
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        texels = interpolate_vertex_colors(fragments, meshes)
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of HardPhongShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
        texels = interpolate_vertex_colors(fragments, meshes)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
        colors = phong_shading(
 | 
			
		||||
@ -89,14 +91,16 @@ class SoftPhongShader(nn.Module):
 | 
			
		||||
        self.materials = (
 | 
			
		||||
            materials if materials is not None else Materials(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = (
 | 
			
		||||
            cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = cameras
 | 
			
		||||
        self.blend_params = blend_params if blend_params is not None else BlendParams()
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        texels = interpolate_vertex_colors(fragments, meshes)
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of SoftPhongShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
        texels = interpolate_vertex_colors(fragments, meshes)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
        colors = phong_shading(
 | 
			
		||||
@ -132,12 +136,14 @@ class HardGouraudShader(nn.Module):
 | 
			
		||||
        self.materials = (
 | 
			
		||||
            materials if materials is not None else Materials(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = (
 | 
			
		||||
            cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = cameras
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of SoftPhongShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
        pixel_colors = gouraud_shading(
 | 
			
		||||
@ -174,13 +180,15 @@ class SoftGouraudShader(nn.Module):
 | 
			
		||||
        self.materials = (
 | 
			
		||||
            materials if materials is not None else Materials(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = (
 | 
			
		||||
            cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = cameras
 | 
			
		||||
        self.blend_params = blend_params if blend_params is not None else BlendParams()
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of SoftPhongShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
        pixel_colors = gouraud_shading(
 | 
			
		||||
@ -219,14 +227,16 @@ class TexturedSoftPhongShader(nn.Module):
 | 
			
		||||
        self.materials = (
 | 
			
		||||
            materials if materials is not None else Materials(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = (
 | 
			
		||||
            cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = cameras
 | 
			
		||||
        self.blend_params = blend_params if blend_params is not None else BlendParams()
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        texels = interpolate_texture_map(fragments, meshes)
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of SoftPhongShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
        texels = interpolate_texture_map(fragments, meshes)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
        blend_params = kwargs.get("blend_params", self.blend_params)
 | 
			
		||||
@ -262,13 +272,15 @@ class HardFlatShader(nn.Module):
 | 
			
		||||
        self.materials = (
 | 
			
		||||
            materials if materials is not None else Materials(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = (
 | 
			
		||||
            cameras if cameras is not None else OpenGLPerspectiveCameras(device=device)
 | 
			
		||||
        )
 | 
			
		||||
        self.cameras = cameras
 | 
			
		||||
 | 
			
		||||
    def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
 | 
			
		||||
        texels = interpolate_vertex_colors(fragments, meshes)
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of SoftPhongShader"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
        texels = interpolate_vertex_colors(fragments, meshes)
 | 
			
		||||
        lights = kwargs.get("lights", self.lights)
 | 
			
		||||
        materials = kwargs.get("materials", self.materials)
 | 
			
		||||
        colors = flat_shading(
 | 
			
		||||
 | 
			
		||||
@ -48,7 +48,7 @@ class PointsRasterizer(nn.Module):
 | 
			
		||||
    This class implements methods for rasterizing a batch of pointclouds.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(self, cameras, raster_settings=None):
 | 
			
		||||
    def __init__(self, cameras=None, raster_settings=None):
 | 
			
		||||
        """
 | 
			
		||||
        cameras: A cameras object which has a  `transform_points` method
 | 
			
		||||
                which returns the transformed points after applying the
 | 
			
		||||
@ -80,6 +80,10 @@ class PointsRasterizer(nn.Module):
 | 
			
		||||
        be moved into forward.
 | 
			
		||||
        """
 | 
			
		||||
        cameras = kwargs.get("cameras", self.cameras)
 | 
			
		||||
        if cameras is None:
 | 
			
		||||
            msg = "Cameras must be specified either at initialization \
 | 
			
		||||
                or in the forward pass of PointsRasterizer"
 | 
			
		||||
            raise ValueError(msg)
 | 
			
		||||
 | 
			
		||||
        pts_world = point_clouds.points_padded()
 | 
			
		||||
        pts_world_packed = point_clouds.points_packed()
 | 
			
		||||
 | 
			
		||||
@ -9,6 +9,11 @@ import torch
 | 
			
		||||
from PIL import Image
 | 
			
		||||
from pytorch3d.renderer.cameras import OpenGLPerspectiveCameras, look_at_view_transform
 | 
			
		||||
from pytorch3d.renderer.mesh.rasterizer import MeshRasterizer, RasterizationSettings
 | 
			
		||||
from pytorch3d.renderer.points.rasterizer import (
 | 
			
		||||
    PointsRasterizationSettings,
 | 
			
		||||
    PointsRasterizer,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.structures import Pointclouds
 | 
			
		||||
from pytorch3d.utils.ico_sphere import ico_sphere
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -99,3 +104,101 @@ class TestMeshRasterizer(unittest.TestCase):
 | 
			
		||||
                DATA_DIR / "DEBUG_test_rasterized_sphere_zoom.png"
 | 
			
		||||
            )
 | 
			
		||||
        self.assertTrue(torch.allclose(image, image_ref))
 | 
			
		||||
 | 
			
		||||
        #################################
 | 
			
		||||
        #  4. Test init without cameras.
 | 
			
		||||
        ##################################
 | 
			
		||||
 | 
			
		||||
        # Create a new empty rasterizer:
 | 
			
		||||
        rasterizer = MeshRasterizer()
 | 
			
		||||
 | 
			
		||||
        # Check that omitting the cameras in both initialization
 | 
			
		||||
        # and the forward pass throws an error:
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
 | 
			
		||||
            rasterizer(sphere_mesh)
 | 
			
		||||
 | 
			
		||||
        # Now pass in the cameras as a kwarg
 | 
			
		||||
        fragments = rasterizer(
 | 
			
		||||
            sphere_mesh, cameras=cameras, raster_settings=raster_settings
 | 
			
		||||
        )
 | 
			
		||||
        image = fragments.pix_to_face[0, ..., 0].squeeze().cpu()
 | 
			
		||||
        # Convert pix_to_face to a binary mask
 | 
			
		||||
        image[image >= 0] = 1.0
 | 
			
		||||
        image[image < 0] = 0.0
 | 
			
		||||
 | 
			
		||||
        if DEBUG:
 | 
			
		||||
            Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
 | 
			
		||||
                DATA_DIR / "DEBUG_test_rasterized_sphere.png"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(torch.allclose(image, image_ref))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestPointRasterizer(unittest.TestCase):
 | 
			
		||||
    def test_simple_sphere(self):
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
 | 
			
		||||
        # Load reference image
 | 
			
		||||
        ref_filename = "test_simple_pointcloud_sphere.png"
 | 
			
		||||
        image_ref_filename = DATA_DIR / ref_filename
 | 
			
		||||
 | 
			
		||||
        # Rescale image_ref to the 0 - 1 range and convert to a binary mask.
 | 
			
		||||
        image_ref = convert_image_to_binary_mask(image_ref_filename).to(torch.int32)
 | 
			
		||||
 | 
			
		||||
        sphere_mesh = ico_sphere(1, device)
 | 
			
		||||
        verts_padded = sphere_mesh.verts_padded()
 | 
			
		||||
        verts_padded[..., 1] += 0.2
 | 
			
		||||
        verts_padded[..., 0] += 0.2
 | 
			
		||||
        pointclouds = Pointclouds(points=verts_padded)
 | 
			
		||||
        R, T = look_at_view_transform(2.7, 0.0, 0.0)
 | 
			
		||||
        cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
 | 
			
		||||
        raster_settings = PointsRasterizationSettings(
 | 
			
		||||
            image_size=256, radius=5e-2, points_per_pixel=1
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        #################################
 | 
			
		||||
        #  1. Test init without cameras.
 | 
			
		||||
        ##################################
 | 
			
		||||
 | 
			
		||||
        # Initialize without passing in the cameras
 | 
			
		||||
        rasterizer = PointsRasterizer()
 | 
			
		||||
 | 
			
		||||
        # Check that omitting the cameras in both initialization
 | 
			
		||||
        # and the forward pass throws an error:
 | 
			
		||||
        with self.assertRaisesRegex(ValueError, "Cameras must be specified"):
 | 
			
		||||
            rasterizer(pointclouds)
 | 
			
		||||
 | 
			
		||||
        ##########################################
 | 
			
		||||
        # 2. Test rasterizing a single pointcloud
 | 
			
		||||
        ##########################################
 | 
			
		||||
 | 
			
		||||
        fragments = rasterizer(
 | 
			
		||||
            pointclouds, cameras=cameras, raster_settings=raster_settings
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Convert idx to a binary mask
 | 
			
		||||
        image = fragments.idx[0, ..., 0].squeeze().cpu()
 | 
			
		||||
        image[image >= 0] = 1.0
 | 
			
		||||
        image[image < 0] = 0.0
 | 
			
		||||
 | 
			
		||||
        if DEBUG:
 | 
			
		||||
            Image.fromarray((image.numpy() * 255).astype(np.uint8)).save(
 | 
			
		||||
                DATA_DIR / "DEBUG_test_rasterized_sphere_points.png"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        self.assertTrue(torch.allclose(image, image_ref[..., 0]))
 | 
			
		||||
 | 
			
		||||
        ########################################
 | 
			
		||||
        #  3. Test with a batch of pointclouds
 | 
			
		||||
        ########################################
 | 
			
		||||
 | 
			
		||||
        batch_size = 10
 | 
			
		||||
        pointclouds = pointclouds.extend(batch_size)
 | 
			
		||||
        fragments = rasterizer(
 | 
			
		||||
            pointclouds, cameras=cameras, raster_settings=raster_settings
 | 
			
		||||
        )
 | 
			
		||||
        for i in range(batch_size):
 | 
			
		||||
            image = fragments.idx[i, ..., 0].squeeze().cpu()
 | 
			
		||||
            image[image >= 0] = 1.0
 | 
			
		||||
            image[image < 0] = 0.0
 | 
			
		||||
            self.assertTrue(torch.allclose(image, image_ref[..., 0]))
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user