diff --git a/pytorch3d/renderer/blending.py b/pytorch3d/renderer/blending.py
index 4bc4822b..952f68c8 100644
--- a/pytorch3d/renderer/blending.py
+++ b/pytorch3d/renderer/blending.py
@@ -46,7 +46,7 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
     is_background = fragments.pix_to_face[..., 0] < 0  # (N, H, W)
 
     if torch.is_tensor(blend_params.background_color):
-        background_color = blend_params.background_color
+        background_color = blend_params.background_color.to(device)
     else:
         background_color = colors.new_tensor(blend_params.background_color)  # (3)
 
@@ -163,6 +163,8 @@ def softmax_rgb_blend(
     background = blend_params.background_color
     if not torch.is_tensor(background):
         background = torch.tensor(background, dtype=torch.float32, device=device)
+    else:
+        background = background.to(device)
 
     # Weight for background color
     eps = 1e-10
diff --git a/pytorch3d/renderer/mesh/rasterizer.py b/pytorch3d/renderer/mesh/rasterizer.py
index 04044f5f..e2983feb 100644
--- a/pytorch3d/renderer/mesh/rasterizer.py
+++ b/pytorch3d/renderer/mesh/rasterizer.py
@@ -76,6 +76,11 @@ class MeshRasterizer(nn.Module):
         self.cameras = cameras
         self.raster_settings = raster_settings
 
+    def to(self, device):
+        # Manually move the cameras to the device; cameras is not an nn.Module subclass
+        self.cameras = self.cameras.to(device)
+        return self
+
     def transform(self, meshes_world, **kwargs) -> torch.Tensor:
         """
         Args:
diff --git a/pytorch3d/renderer/mesh/renderer.py b/pytorch3d/renderer/mesh/renderer.py
index 8e5789bb..9e366616 100644
--- a/pytorch3d/renderer/mesh/renderer.py
+++ b/pytorch3d/renderer/mesh/renderer.py
@@ -33,6 +33,12 @@ class MeshRenderer(nn.Module):
         self.rasterizer = rasterizer
         self.shader = shader
 
+    def to(self, device):
+        # The rasterizer and shader hold members which are not nn.Module subclasses
+        self.rasterizer.to(device)
+        self.shader.to(device)
+        return self
+
     def forward(self, meshes_world, **kwargs) -> torch.Tensor:
         """
         Render a batch of images from a batch of meshes by rasterizing and then
@@ -44,6 +50,7 @@ class MeshRenderer(nn.Module):
         face f, clipping is required before interpolating
         the texture uv coordinates and z buffer so that the colors and depths
         are limited to the range for the corresponding face.
+        To enable this clipping, set rasterizer.raster_settings.clip_barycentric_coords=True
         """
         fragments = self.rasterizer(meshes_world, **kwargs)
         images = self.shader(fragments, meshes_world, **kwargs)
diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py
index c1b5b791..eb8d080c 100644
--- a/pytorch3d/renderer/mesh/shader.py
+++ b/pytorch3d/renderer/mesh/shader.py
@@ -50,6 +50,13 @@ class HardPhongShader(nn.Module):
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()
 
+    def to(self, device):
+        # Manually move to the device members which are not nn.Module subclasses
+        self.cameras = self.cameras.to(device)
+        self.materials = self.materials.to(device)
+        self.lights = self.lights.to(device)
+        return self
+
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
         if cameras is None:
@@ -98,6 +105,13 @@ class SoftPhongShader(nn.Module):
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()
 
+    def to(self, device):
+        # Manually move to the device members which are not nn.Module subclasses
+        self.cameras = self.cameras.to(device)
+        self.materials = self.materials.to(device)
+        self.lights = self.lights.to(device)
+        return self
+
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
         if cameras is None:
@@ -151,6 +165,13 @@ class HardGouraudShader(nn.Module):
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()
 
+    def to(self, device):
+        # Manually move to the device members which are not nn.Module subclasses
+        self.cameras = self.cameras.to(device)
+        self.materials = self.materials.to(device)
+        self.lights = self.lights.to(device)
+        return self
+
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
         if cameras is None:
@@ -203,6 +224,13 @@ class SoftGouraudShader(nn.Module):
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()
 
+    def to(self, device):
+        # Manually move to the device members which are not nn.Module subclasses
+        self.cameras = self.cameras.to(device)
+        self.materials = self.materials.to(device)
+        self.lights = self.lights.to(device)
+        return self
+
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
         if cameras is None:
@@ -272,6 +300,13 @@ class HardFlatShader(nn.Module):
         self.cameras = cameras
         self.blend_params = blend_params if blend_params is not None else BlendParams()
 
+    def to(self, device):
+        # Manually move to the device members which are not nn.Module subclasses
+        self.cameras = self.cameras.to(device)
+        self.materials = self.materials.to(device)
+        self.lights = self.lights.to(device)
+        return self
+
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
         if cameras is None:
diff --git a/tests/test_render_meshes.py b/tests/test_render_meshes.py
index 4f0c2ba2..511f95db 100644
--- a/tests/test_render_meshes.py
+++ b/tests/test_render_meshes.py
@@ -1042,3 +1042,67 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
         )
 
         self.assertClose(rgb, image_ref, atol=0.05)
+
+    def test_to(self):
+        # Test moving all the tensors in the renderer to a new device
+        # to support multi-GPU rendering.
+        device1 = torch.device("cpu")
+
+        R, T = look_at_view_transform(1500, 0.0, 0.0)
+
+        # Init shader settings
+        materials = Materials(device=device1)
+        lights = PointLights(device=device1)
+        lights.location = torch.tensor([0.0, 0.0, +1000.0], device=device1)[None]
+
+        raster_settings = RasterizationSettings(
+            image_size=256, blur_radius=0.0, faces_per_pixel=1
+        )
+        cameras = FoVPerspectiveCameras(
+            device=device1, R=R, T=T, aspect_ratio=1.0, fov=60.0, zfar=100
+        )
+        rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
+
+        blend_params = BlendParams(
+            1e-4,
+            1e-4,
+            background_color=torch.zeros(3, dtype=torch.float32, device=device1),
+        )
+
+        shader = SoftPhongShader(
+            lights=lights,
+            cameras=cameras,
+            materials=materials,
+            blend_params=blend_params,
+        )
+        renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
+
+        def _check_props_on_device(renderer, device):
+            self.assertEqual(renderer.rasterizer.cameras.device, device)
+            self.assertEqual(renderer.shader.cameras.device, device)
+            self.assertEqual(renderer.shader.lights.device, device)
+            self.assertEqual(renderer.shader.lights.ambient_color.device, device)
+            self.assertEqual(renderer.shader.materials.device, device)
+            self.assertEqual(renderer.shader.materials.ambient_color.device, device)
+
+        mesh = ico_sphere(2, device1)
+        verts_padded = mesh.verts_padded()
+        textures = TexturesVertex(
+            verts_features=torch.ones_like(verts_padded, device=device1)
+        )
+        mesh.textures = textures
+        _check_props_on_device(renderer, device1)
+
+        # Test rendering on the CPU
+        output_images = renderer(mesh)
+        self.assertEqual(output_images.device, device1)
+
+        # Move the renderer and mesh to another device and re-render.
+        # This also tests that background_color is correctly moved to
+        # the new device.
+        device2 = torch.device("cuda:0")
+        renderer.to(device2)
+        mesh = mesh.to(device2)
+        _check_props_on_device(renderer, device2)
+        output_images = renderer(mesh)
+        self.assertEqual(output_images.device, device2)
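Usage sketch (illustrative, not part of the diff): with the to() methods added above, a renderer built on one device can be moved wholesale to another before rendering. The snippet below is a minimal example assuming the public pytorch3d.renderer / pytorch3d.utils import paths and an available CUDA device; the camera distance, image size and constant vertex textures are placeholder values.

    import torch
    from pytorch3d.renderer import (
        FoVPerspectiveCameras,
        MeshRasterizer,
        MeshRenderer,
        PointLights,
        RasterizationSettings,
        SoftPhongShader,
        TexturesVertex,
        look_at_view_transform,
    )
    from pytorch3d.utils import ico_sphere

    # Build the full renderer on the CPU.
    cpu = torch.device("cpu")
    R, T = look_at_view_transform(2.7, 0.0, 0.0)
    cameras = FoVPerspectiveCameras(device=cpu, R=R, T=T)
    raster_settings = RasterizationSettings(
        image_size=128, blur_radius=0.0, faces_per_pixel=1
    )
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
        shader=SoftPhongShader(cameras=cameras, lights=PointLights(device=cpu)),
    )

    # A textured mesh to render; SoftPhongShader requires textures.
    mesh = ico_sphere(2, cpu)
    mesh.textures = TexturesVertex(verts_features=torch.ones_like(mesh.verts_padded()))

    # A single call now moves the cameras, lights and materials held by the
    # rasterizer and shader; since to() returns self, it can also be chained.
    gpu = torch.device("cuda:0")
    renderer = renderer.to(gpu)
    images = renderer(mesh.to(gpu))
    assert images.device == gpu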