From 19340462e42d3694c33255b12f2611d9eec1b18b Mon Sep 17 00:00:00 2001
From: Nikhila Ravi
Date: Mon, 14 Dec 2020 15:26:24 -0800
Subject: [PATCH] Return self in the `to` method for the renderer classes

Summary: Add `return self` to the `to` function for the renderer classes.

Reviewed By: bottler

Differential Revision: D25534487

fbshipit-source-id: e8dbd35524f0bd40e835439e93184b5a1f1532ca
---
 pytorch3d/renderer/mesh/rasterizer.py   |  1 +
 pytorch3d/renderer/mesh/renderer.py     |  1 +
 pytorch3d/renderer/mesh/shader.py       |  5 ++
 pytorch3d/renderer/points/rasterizer.py |  5 ++
 pytorch3d/renderer/points/renderer.py   |  7 +++
 tests/test_render_multigpu.py           | 62 +++++++++++++++++++++++--
 6 files changed, 77 insertions(+), 4 deletions(-)

diff --git a/pytorch3d/renderer/mesh/rasterizer.py b/pytorch3d/renderer/mesh/rasterizer.py
index f4284e8f..13ffe184 100644
--- a/pytorch3d/renderer/mesh/rasterizer.py
+++ b/pytorch3d/renderer/mesh/rasterizer.py
@@ -79,6 +79,7 @@ class MeshRasterizer(nn.Module):
     def to(self, device):
         # Manually move to device cameras as it is not a subclass of nn.Module
         self.cameras = self.cameras.to(device)
+        return self
 
     def transform(self, meshes_world, **kwargs) -> torch.Tensor:
         """
diff --git a/pytorch3d/renderer/mesh/renderer.py b/pytorch3d/renderer/mesh/renderer.py
index fe2ba5d1..5dea1b57 100644
--- a/pytorch3d/renderer/mesh/renderer.py
+++ b/pytorch3d/renderer/mesh/renderer.py
@@ -37,6 +37,7 @@ class MeshRenderer(nn.Module):
         # Rasterizer and shader have submodules which are not of type nn.Module
         self.rasterizer.to(device)
         self.shader.to(device)
+        return self
 
     def forward(self, meshes_world, **kwargs) -> torch.Tensor:
         """
diff --git a/pytorch3d/renderer/mesh/shader.py b/pytorch3d/renderer/mesh/shader.py
index 11b98409..c213457d 100644
--- a/pytorch3d/renderer/mesh/shader.py
+++ b/pytorch3d/renderer/mesh/shader.py
@@ -55,6 +55,7 @@ class HardPhongShader(nn.Module):
         self.cameras = self.cameras.to(device)
         self.materials = self.materials.to(device)
         self.lights = self.lights.to(device)
+        return self
 
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
@@ -109,6 +110,7 @@ class SoftPhongShader(nn.Module):
         self.cameras = self.cameras.to(device)
         self.materials = self.materials.to(device)
         self.lights = self.lights.to(device)
+        return self
 
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
@@ -168,6 +170,7 @@ class HardGouraudShader(nn.Module):
         self.cameras = self.cameras.to(device)
         self.materials = self.materials.to(device)
         self.lights = self.lights.to(device)
+        return self
 
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
@@ -226,6 +229,7 @@ class SoftGouraudShader(nn.Module):
         self.cameras = self.cameras.to(device)
         self.materials = self.materials.to(device)
         self.lights = self.lights.to(device)
+        return self
 
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
@@ -301,6 +305,7 @@ class HardFlatShader(nn.Module):
         self.cameras = self.cameras.to(device)
         self.materials = self.materials.to(device)
         self.lights = self.lights.to(device)
+        return self
 
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
diff --git a/pytorch3d/renderer/points/rasterizer.py b/pytorch3d/renderer/points/rasterizer.py
index 7a7ebcda..85e93e42 100644
--- a/pytorch3d/renderer/points/rasterizer.py
+++ b/pytorch3d/renderer/points/rasterizer.py
@@ -99,6 +99,11 @@ class PointsRasterizer(nn.Module):
         point_clouds = point_clouds.update_padded(pts_screen)
         return point_clouds
 
+    def to(self, device):
+        # Manually move to device cameras as it is not a subclass of nn.Module
+        self.cameras = self.cameras.to(device)
+        return self
+
     def forward(self, point_clouds, **kwargs) -> PointFragments:
         """
         Args:
diff --git a/pytorch3d/renderer/points/renderer.py b/pytorch3d/renderer/points/renderer.py
index 0f5f3120..31566369 100644
--- a/pytorch3d/renderer/points/renderer.py
+++ b/pytorch3d/renderer/points/renderer.py
@@ -32,6 +32,13 @@ class PointsRenderer(nn.Module):
         self.rasterizer = rasterizer
         self.compositor = compositor
 
+    def to(self, device):
+        # Manually move to device rasterizer as the cameras
+        # within the class are not of type nn.Module
+        self.rasterizer = self.rasterizer.to(device)
+        self.compositor = self.compositor.to(device)
+        return self
+
     def forward(self, point_clouds, **kwargs) -> torch.Tensor:
         fragments = self.rasterizer(point_clouds, **kwargs)
 
diff --git a/tests/test_render_multigpu.py b/tests/test_render_multigpu.py
index 298ddb7f..2e29aa0a 100644
--- a/tests/test_render_multigpu.py
+++ b/tests/test_render_multigpu.py
@@ -6,18 +6,22 @@ import torch
 import torch.nn as nn
 from common_testing import TestCaseMixin, get_random_cuda_device
 from pytorch3d.renderer import (
+    AlphaCompositor,
     BlendParams,
     HardGouraudShader,
     Materials,
     MeshRasterizer,
     MeshRenderer,
     PointLights,
+    PointsRasterizationSettings,
+    PointsRasterizer,
+    PointsRenderer,
     RasterizationSettings,
     SoftPhongShader,
     TexturesVertex,
 )
 from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
-from pytorch3d.structures.meshes import Meshes
+from pytorch3d.structures import Meshes, Pointclouds
 from pytorch3d.utils.ico_sphere import ico_sphere
 
 
@@ -27,7 +31,7 @@ GPU_LIST = list({get_random_cuda_device() for _ in range(NUM_GPUS)})
 print("GPUs: %s" % ", ".join(GPU_LIST))
 
 
-class TestRenderMultiGPU(TestCaseMixin, unittest.TestCase):
+class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
     def _check_mesh_renderer_props_on_device(self, renderer, device):
         """
         Helper function to check that all the properties of the mesh
@@ -99,7 +103,7 @@ class TestRenderMultiGPU(TestCaseMixin, unittest.TestCase):
         # This also tests that background_color is correctly moved to
         # the new device
         device2 = torch.device("cuda:0")
-        renderer.to(device2)
+        renderer = renderer.to(device2)
         mesh = mesh.to(device2)
         self._check_mesh_renderer_props_on_device(renderer, device2)
         output_images = renderer(mesh)
@@ -137,7 +141,7 @@ class TestRenderMultiGPU(TestCaseMixin, unittest.TestCase):
 
             def forward(self, verts, texs):
                 batch_size = verts.size(0)
-                self.renderer.to(verts.device)
+                self.renderer = self.renderer.to(verts.device)
                 tex = TexturesVertex(verts_features=texs)
                 faces = self.faces.expand(batch_size, -1, -1).to(verts.device)
                 mesh = Meshes(verts, faces, tex).to(verts.device)
@@ -157,3 +161,53 @@ class TestRenderMultiGPU(TestCaseMixin, unittest.TestCase):
         # Test a few iterations
         for _ in range(100):
             model(verts, texs)
+
+
+class TestRenderPointsMultiGPU(TestCaseMixin, unittest.TestCase):
+    def _check_points_renderer_props_on_device(self, renderer, device):
+        """
+        Helper function to check that all the properties have
+        been moved to the correct device.
+        """
+        # Cameras
+        self.assertEqual(renderer.rasterizer.cameras.device, device)
+        self.assertEqual(renderer.rasterizer.cameras.R.device, device)
+        self.assertEqual(renderer.rasterizer.cameras.T.device, device)
+
+    def test_points_renderer_to(self):
+        """
+        Test moving all the tensors in the points renderer to a new device.
+        """
+
+        device1 = torch.device("cpu")
+
+        R, T = look_at_view_transform(1500, 0.0, 0.0)
+
+        raster_settings = PointsRasterizationSettings(
+            image_size=256, radius=0.001, points_per_pixel=1
+        )
+        cameras = FoVPerspectiveCameras(
+            device=device1, R=R, T=T, aspect_ratio=1.0, fov=60.0, zfar=100
+        )
+        rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
+
+        renderer = PointsRenderer(rasterizer=rasterizer, compositor=AlphaCompositor())
+
+        mesh = ico_sphere(2, device1)
+        verts_padded = mesh.verts_padded()
+        pointclouds = Pointclouds(
+            points=verts_padded, features=torch.randn_like(verts_padded)
+        )
+        self._check_points_renderer_props_on_device(renderer, device1)
+
+        # Test rendering on cpu
+        output_images = renderer(pointclouds)
+        self.assertEqual(output_images.device, device1)
+
+        # Move renderer and pointclouds to another device and re-render
+        device2 = torch.device("cuda:0")
+        renderer = renderer.to(device2)
+        pointclouds = pointclouds.to(device2)
+        self._check_points_renderer_props_on_device(renderer, device2)
+        output_images = renderer(pointclouds)
+        self.assertEqual(output_images.device, device2)
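
Usage sketch (illustrative, not part of the patch): with `to` returning the module itself, a renderer can be moved and re-assigned in one expression, mirroring the `renderer = renderer.to(device2)` pattern the updated tests use. The constructors below follow the imports in tests/test_render_multigpu.py; the shader choice and image_size=128 are arbitrary example values.

    import torch
    from pytorch3d.renderer import (
        MeshRasterizer,
        MeshRenderer,
        RasterizationSettings,
        SoftPhongShader,
    )
    from pytorch3d.renderer.cameras import FoVPerspectiveCameras

    # Build a renderer on the default (CPU) device.
    cameras = FoVPerspectiveCameras()
    rasterizer = MeshRasterizer(
        cameras=cameras, raster_settings=RasterizationSettings(image_size=128)
    )
    renderer = MeshRenderer(
        rasterizer=rasterizer, shader=SoftPhongShader(cameras=cameras)
    )

    # Before this change the `to` methods returned None, so re-assigning the
    # result would silently replace the renderer with None. Returning self
    # makes the usual nn.Module idiom safe:
    if torch.cuda.is_available():
        renderer = renderer.to(torch.device("cuda:0"))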