diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py index 5ab9ce9d..80751d9d 100644 --- a/pytorch3d/renderer/mesh/shader.py +++ b/pytorch3d/renderer/mesh/shader.py @@ -26,7 +26,7 @@ from .shading import flat_shading, gouraud_shading, phong_shading # - blend colors across top K faces per pixel. -class HardPhongShader(nn.Module): # pragma: no cover +class HardPhongShader(nn.Module): """ Per pixel lighting - the lighting model is applied using the interpolated coordinates and normals for each pixel. The blending function hard assigns @@ -58,7 +58,8 @@ class HardPhongShader(nn.Module): # pragma: no cover def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module - self.cameras = self.cameras.to(device) + if self.cameras is not None: + self.cameras = self.cameras.to(device) self.materials = self.materials.to(device) self.lights = self.lights.to(device) return self @@ -86,7 +87,7 @@ class HardPhongShader(nn.Module): # pragma: no cover return images -class SoftPhongShader(nn.Module): # pragma: no cover +class SoftPhongShader(nn.Module): """ Per pixel lighting - the lighting model is applied using the interpolated coordinates and normals for each pixel. The blending function returns the @@ -118,7 +119,8 @@ class SoftPhongShader(nn.Module): # pragma: no cover def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module - self.cameras = self.cameras.to(device) + if self.cameras is not None: + self.cameras = self.cameras.to(device) self.materials = self.materials.to(device) self.lights = self.lights.to(device) return self @@ -150,7 +152,7 @@ class SoftPhongShader(nn.Module): # pragma: no cover return images -class HardGouraudShader(nn.Module): # pragma: no cover +class HardGouraudShader(nn.Module): """ Per vertex lighting - the lighting model is applied to the vertex colors and the colors are then interpolated using the barycentric coordinates to @@ -183,7 +185,8 @@ class HardGouraudShader(nn.Module): # pragma: no cover def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module - self.cameras = self.cameras.to(device) + if self.cameras is not None: + self.cameras = self.cameras.to(device) self.materials = self.materials.to(device) self.lights = self.lights.to(device) return self @@ -214,7 +217,7 @@ class HardGouraudShader(nn.Module): # pragma: no cover return images -class SoftGouraudShader(nn.Module): # pragma: no cover +class SoftGouraudShader(nn.Module): """ Per vertex lighting - the lighting model is applied to the vertex colors and the colors are then interpolated using the barycentric coordinates to @@ -247,7 +250,8 @@ class SoftGouraudShader(nn.Module): # pragma: no cover def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module - self.cameras = self.cameras.to(device) + if self.cameras is not None: + self.cameras = self.cameras.to(device) self.materials = self.materials.to(device) self.lights = self.lights.to(device) return self @@ -277,7 +281,7 @@ class SoftGouraudShader(nn.Module): # pragma: no cover def TexturedSoftPhongShader( device: Device = "cpu", cameras=None, lights=None, materials=None, blend_params=None -): # pragma: no cover +): """ TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead. Preserving TexturedSoftPhongShader as a function for backwards compatibility. @@ -296,7 +300,7 @@ def TexturedSoftPhongShader( ) -class HardFlatShader(nn.Module): # pragma: no cover +class HardFlatShader(nn.Module): """ Per face lighting - the lighting model is applied using the average face position and the face normal. The blending function hard assigns @@ -328,7 +332,8 @@ class HardFlatShader(nn.Module): # pragma: no cover def to(self, device: Device): # Manually move to device modules which are not subclasses of nn.Module - self.cameras = self.cameras.to(device) + if self.cameras is not None: + self.cameras = self.cameras.to(device) self.materials = self.materials.to(device) self.lights = self.lights.to(device) return self @@ -355,7 +360,7 @@ class HardFlatShader(nn.Module): # pragma: no cover return images -class SoftSilhouetteShader(nn.Module): # pragma: no cover +class SoftSilhouetteShader(nn.Module): """ Calculate the silhouette by blending the top K faces for each pixel based on the 2d euclidean distance of the center of the pixel to the mesh face. diff --git a/tests/test_shader.py b/tests/test_shader.py new file mode 100644 index 00000000..74284c5d --- /dev/null +++ b/tests/test_shader.py @@ -0,0 +1,87 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import unittest + +import torch +from common_testing import TestCaseMixin +from pytorch3d.renderer.cameras import ( + look_at_view_transform, + PerspectiveCameras, +) +from pytorch3d.renderer.mesh.rasterizer import Fragments +from pytorch3d.renderer.mesh.shader import ( + HardFlatShader, + HardGouraudShader, + HardPhongShader, + SoftPhongShader, +) +from pytorch3d.structures.meshes import Meshes + + +class TestShader(TestCaseMixin, unittest.TestCase): + def test_to(self): + cpu_device = torch.device("cpu") + cuda_device = torch.device("cuda") + + R, T = look_at_view_transform() + + shader_classes = [ + HardFlatShader, + HardGouraudShader, + HardPhongShader, + SoftPhongShader, + ] + + for shader_class in shader_classes: + for cameras_class in (None, PerspectiveCameras): + if cameras_class is None: + cameras = None + else: + cameras = PerspectiveCameras(device=cpu_device, R=R, T=T) + + cpu_shader = shader_class(device=cpu_device, cameras=cameras) + if cameras is None: + self.assertIsNone(cpu_shader.cameras) + else: + self.assertEqual(cpu_device, cpu_shader.cameras.device) + self.assertEqual(cpu_device, cpu_shader.materials.device) + self.assertEqual(cpu_device, cpu_shader.lights.device) + + cuda_shader = cpu_shader.to(cuda_device) + self.assertIs(cpu_shader, cuda_shader) + if cameras is None: + self.assertIsNone(cuda_shader.cameras) + else: + self.assertEqual(cuda_device, cuda_shader.cameras.device) + self.assertEqual(cuda_device, cuda_shader.materials.device) + self.assertEqual(cuda_device, cuda_shader.lights.device) + + def test_cameras_check(self): + verts = torch.tensor( + [[-1, -1, 0], [1, -1, 1], [1, 1, 0], [-1, 1, 1]], dtype=torch.float32 + ) + faces = torch.tensor([[0, 1, 2], [2, 3, 0]], dtype=torch.int64) + meshes = Meshes(verts=[verts], faces=[faces]) + + pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2) + barycentric_coords = torch.tensor( + [[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], dtype=torch.float32 + ).view(1, 1, 1, 2, -1) + fragments = Fragments( + pix_to_face=pix_to_face, + bary_coords=barycentric_coords, + zbuf=torch.ones_like(pix_to_face), + dists=torch.ones_like(pix_to_face), + ) + shader_classes = [ + HardFlatShader, + HardGouraudShader, + HardPhongShader, + SoftPhongShader, + ] + + for shader_class in shader_classes: + shader = shader_class() + + with self.assertRaises(ValueError): + shader(fragments, meshes)