	renderer: add support for rendering high dimensional textures for classification/segmentation use cases (#1248)
Summary: For 3D segmentation problems it's really useful to be able to train the models from multiple viewpoints using PyTorch3D as the renderer. Currently, due to hardcoded assumptions in a few spots, the mesh renderer only supports rendering RGB (3-dimensional) data. You can encode the classification information as 3-channel data, but if you have more than 3 classes you're out of luck. This relaxes the assumptions so that rendering semantic classes works with `HardFlatShader` and `AmbientLights` with no diffuse/specular component. The other shaders/lights don't make sense for classification since they mutate the texture values in some way. This only requires changes in `Materials` and `AmbientLights`; the bulk of the code is the unit test.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1248

Test Plan: Added a unit test that renders a 5-dimensional texture and compares dimensions 2-5 to a stored picture.

Reviewed By: bottler

Differential Revision: D37764610

Pulled By: d4l3k

fbshipit-source-id: 031895724d9318a6f6bab5b31055bb3f438176a5
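The sketch below is not part of the commit: a condensed, hedged illustration of the pipeline the summary describes, assuming a hypothetical 5-class setup with random per-vertex class scores and illustrative variable names. It mirrors the new unit test at the end of the diff, plus an argmax to turn the rendered channels into a per-pixel label map.

# Minimal sketch (not part of the commit) of the N-channel rendering pipeline the
# summary describes. The 5-class setup, random per-vertex scores and variable names
# are illustrative assumptions.
import torch
from pytorch3d.renderer import (
    AmbientLights,
    BlendParams,
    HardFlatShader,
    Materials,
    MeshRasterizer,
    MeshRenderer,
    PerspectiveCameras,
    RasterizationSettings,
    TexturesVertex,
    look_at_view_transform,
)
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere

device = torch.device("cuda:0")
C = 5                  # number of semantic classes, encoded as texture channels
white = ((1.0,) * C,)  # C-channel "white" for lights/materials

mesh = ico_sphere(3, device)
verts, faces = mesh.verts_padded(), mesh.faces_padded()
feats = torch.rand(*verts.shape[:-1], C, device=device)  # per-vertex class scores (N, V, C)
mesh = Meshes(verts=verts, faces=faces, textures=TexturesVertex(verts_features=feats))

R, T = look_at_view_transform(dist=2.7, elev=0.0, azim=0.0)
cameras = PerspectiveCameras(device=device, R=R, T=T)
renderer = MeshRenderer(
    rasterizer=MeshRasterizer(
        cameras=cameras,
        raster_settings=RasterizationSettings(
            image_size=256, blur_radius=0.0, faces_per_pixel=1
        ),
    ),
    # HardFlatShader + AmbientLights: the shader/light combination that does not
    # mutate the texture values, so the C channels pass through unchanged.
    shader=HardFlatShader(
        cameras=cameras,
        lights=AmbientLights(device=device, ambient_color=white),
        materials=Materials(
            device=device, ambient_color=white, diffuse_color=white, specular_color=white
        ),
        blend_params=BlendParams(background_color=(0.0,) * C),
    ),
)

images = renderer(mesh)              # (N, H, W, C + 1): C class channels + alpha
labels = images[..., :C].argmax(-1)  # per-pixel class prediction, (N, H, W)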
parent aa8b03f31d
commit 8d10ba52b2
@@ -292,6 +292,9 @@ class AmbientLights(TensorProperties):
     A light object representing the same color of light everywhere.
     By default, this is white, which effectively means lighting is
     not used in rendering.
+
+    Unlike other lights this supports an arbitrary number of channels, not just 3 for RGB.
+    The ambient_color input determines the number of channels.
     """
 
     def __init__(self, *, ambient_color=None, device: Device = "cpu") -> None:
@@ -304,9 +307,11 @@ class AmbientLights(TensorProperties):
             device: Device (as str or torch.device) on which the tensors should be located
 
         The ambient_color if provided, should be
-            - 3 element tuple/list or list of lists
-            - torch tensor of shape (1, 3)
-            - torch tensor of shape (N, 3)
+            - tuple/list of C-element tuples of floats
+            - torch tensor of shape (1, C)
+            - torch tensor of shape (N, C)
+        where C is the number of channels and N is batch size.
+        For RGB, C is 3.
         """
         if ambient_color is None:
             ambient_color = ((1.0, 1.0, 1.0),)
@@ -317,10 +322,14 @@ class AmbientLights(TensorProperties):
         return super().clone(other)
 
     def diffuse(self, normals, points) -> torch.Tensor:
-        return torch.zeros_like(points)
+        return self._zeros_channels(points)
 
     def specular(self, normals, points, camera_position, shininess) -> torch.Tensor:
-        return torch.zeros_like(points)
+        return self._zeros_channels(points)
+
+    def _zeros_channels(self, points: torch.Tensor) -> torch.Tensor:
+        ch = self.ambient_color.shape[-1]
+        return torch.zeros(*points.shape[:-1], ch, device=points.device)
 
 
 def _validate_light_properties(obj) -> None:
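A small sketch (not part of the diff) of the behavior these hunks introduce, with illustrative shapes: the channel count is taken from ambient_color, and the diffuse/specular terms become all-zero tensors with that many channels.

import torch
from pytorch3d.renderer import AmbientLights

C = 5
lights = AmbientLights(ambient_color=torch.ones(1, C))
points = torch.rand(1, 128, 3)   # arbitrary xyz positions
normals = torch.rand(1, 128, 3)

# Both terms are zero, but now with C channels rather than the 3 of `points`.
assert lights.diffuse(normals, points).shape == (1, 128, C)
assert lights.specular(
    normals, points, camera_position=torch.zeros(1, 3), shininess=64.0
).shape == (1, 128, C)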
@@ -27,9 +27,9 @@ class Materials(TensorProperties):
     ) -> None:
         """
         Args:
-            ambient_color: RGB ambient reflectivity of the material
-            diffuse_color: RGB diffuse reflectivity of the material
-            specular_color: RGB specular reflectivity of the material
+            ambient_color: ambient reflectivity of the material
+            diffuse_color: diffuse reflectivity of the material
+            specular_color: specular reflectivity of the material
             shininess: The specular exponent for the material. This defines
                 the focus of the specular highlight with a high value
                 resulting in a concentrated highlight. Shininess values
@@ -37,7 +37,8 @@ class Materials(TensorProperties):
             device: Device (as str or torch.device) on which the tensors should be located
 
         ambient_color, diffuse_color and specular_color can be of shape
-        (1, 3) or (N, 3). shininess can be of shape (1) or (N).
+        (1, C) or (N, C) where C is typically 3 (for RGB). shininess can be of shape (1,)
+        or (N,).
 
         The colors and shininess are broadcast against each other so need to
         have either the same batch dimension or batch dimension = 1.
@@ -49,11 +50,12 @@ class Materials(TensorProperties):
             specular_color=specular_color,
             shininess=shininess,
         )
+        C = self.ambient_color.shape[-1]
         for n in ["ambient_color", "diffuse_color", "specular_color"]:
             t = getattr(self, n)
-            if t.shape[-1] != 3:
-                msg = "Expected %s to have shape (N, 3); got %r"
-                raise ValueError(msg % (n, t.shape))
+            if t.shape[-1] != C:
+                msg = "Expected %s to have shape (N, %d); got %r"
+                raise ValueError(msg % (n, C, t.shape))
         if self.shininess.shape != torch.Size([self._N]):
             msg = "shininess should have shape (N); got %r"
             raise ValueError(msg % repr(self.shininess.shape))
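A hedged sketch (not part of the diff) of the relaxed validation above: the three color tensors may have any channel count C, but must agree with each other. The values below are illustrative.

import torch
from pytorch3d.renderer import Materials

C = 5
# All three colors share C channels: accepted.
materials = Materials(
    ambient_color=torch.ones(1, C),
    diffuse_color=torch.ones(1, C),
    specular_color=torch.ones(1, C),
)

# A mismatched channel count still raises, with C taken from ambient_color.
try:
    Materials(
        ambient_color=torch.ones(1, C),
        diffuse_color=torch.ones(1, 3),
        specular_color=torch.ones(1, C),
    )
except ValueError as err:
    print(err)  # Expected diffuse_color to have shape (N, 5); got torch.Size([1, 3])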
							
								
								
									
										
BIN  tests/data/test_nd_sphere.png  Normal file  (binary file not shown; size: 56 KiB)
@@ -1236,3 +1236,81 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
                 "test_simple_sphere_light_phong_%s.png" % cam_type.__name__, DATA_DIR
             )
             self.assertClose(rgb, image_ref, atol=0.05)
+
+    def test_nd_sphere(self):
+        """
+        Test that the renderer can handle textures with more than 3 channels and
+        not just 3 channel RGB.
+        """
+        torch.manual_seed(1)
+        device = torch.device("cuda:0")
+        C = 5
+        WHITE = ((1.0,) * C,)
+        BLACK = ((0.0,) * C,)
+
+        # Init mesh
+        sphere_mesh = ico_sphere(5, device)
+        verts_padded = sphere_mesh.verts_padded()
+        faces_padded = sphere_mesh.faces_padded()
+        feats = torch.ones(*verts_padded.shape[:-1], C, device=device)
+        n_verts = feats.shape[1]
+        # make some non-uniform pattern
+        feats *= torch.arange(0, 10, step=10 / n_verts, device=device).unsqueeze(1)
+        textures = TexturesVertex(verts_features=feats)
+        sphere_mesh = Meshes(verts=verts_padded, faces=faces_padded, textures=textures)
+
+        # No elevation or azimuth rotation
+        R, T = look_at_view_transform(2.7, 0.0, 0.0)
+
+        cameras = PerspectiveCameras(device=device, R=R, T=T)
+
+        # Init shader settings
+        materials = Materials(
+            device=device,
+            ambient_color=WHITE,
+            diffuse_color=WHITE,
+            specular_color=WHITE,
+        )
+        lights = AmbientLights(
+            device=device,
+            ambient_color=WHITE,
+        )
+        lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
+
+        raster_settings = RasterizationSettings(
+            image_size=512, blur_radius=0.0, faces_per_pixel=1
+        )
+        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
+        blend_params = BlendParams(
+            1e-4,
+            1e-4,
+            background_color=BLACK[0],
+        )
+
+        # only test HardFlatShader since that's the only one that makes
+        # sense for classification
+        shader = HardFlatShader(
+            lights=lights,
+            cameras=cameras,
+            materials=materials,
+            blend_params=blend_params,
+        )
+        renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
+        images = renderer(sphere_mesh)
+
+        self.assertEqual(images.shape[-1], C + 1)
+        self.assertClose(images.amax(), torch.tensor(10.0), atol=0.01)
+        self.assertClose(images.amin(), torch.tensor(0.0), atol=0.01)
+
+        # grab last 3 color channels
+        rgb = (images[0, ..., C - 3 : C] / 10).squeeze().cpu()
+        filename = "test_nd_sphere.png"
+
+        if DEBUG:
+            debug_filename = "DEBUG_%s" % filename
+            Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
+                DATA_DIR / debug_filename
+            )
+
+        image_ref = load_rgb_image(filename, DATA_DIR)
+        self.assertClose(rgb, image_ref, atol=0.05)
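The summary's motivation is training segmentation models from rendered viewpoints; one hedged, illustrative way to consume the renderer's C-channel output (not part of this commit) is to treat the channels as per-pixel logits:

import torch
import torch.nn.functional as F

def segmentation_loss(images: torch.Tensor, target_labels: torch.Tensor, C: int) -> torch.Tensor:
    # images: (N, H, W, C + 1) from MeshRenderer with a C-channel texture;
    # drop the alpha channel and move channels first for cross_entropy.
    logits = images[..., :C].permute(0, 3, 1, 2)   # (N, C, H, W)
    return F.cross_entropy(logits, target_labels)  # target_labels: (N, H, W) int64 class ids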