mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Support for moving the renderer to a new device
Summary: Support for moving all the tensors of the renderer to another device by calling `renderer.to(new_device)` Currently the `MeshRenderer`, `MeshRasterizer` and `SoftPhongShader` (and other shaders) are all of type `nn.Module` which already supports easily moving tensors of submodules (defined as class attributes) to a different device. However the class attributes of the rasterizer and shader (e.g. cameras, lights, materials), are of type `TensorProperties`, not nn.Module so we need to explicity create a `to` method to move these tensors to device. Note that the `TensorProperties` class already has a `to` method so we only need to call `cameras.to(device)` and don't need to worry about the internal tensors. The other option is of course making these other classes (cameras, lights etc) also of type nn.Module. Reviewed By: gkioxari Differential Revision: D23885107 fbshipit-source-id: d71565c442181f739de4d797076ed5d00fb67f8e
This commit is contained in:
parent
b1eee579fd
commit
956d3a010c
@ -46,7 +46,7 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)
|
||||
|
||||
if torch.is_tensor(blend_params.background_color):
|
||||
background_color = blend_params.background_color
|
||||
background_color = blend_params.background_color.to(device)
|
||||
else:
|
||||
background_color = colors.new_tensor(blend_params.background_color) # (3)
|
||||
|
||||
@ -163,6 +163,8 @@ def softmax_rgb_blend(
|
||||
background = blend_params.background_color
|
||||
if not torch.is_tensor(background):
|
||||
background = torch.tensor(background, dtype=torch.float32, device=device)
|
||||
else:
|
||||
background = background.to(device)
|
||||
|
||||
# Weight for background color
|
||||
eps = 1e-10
|
||||
|
@ -76,6 +76,10 @@ class MeshRasterizer(nn.Module):
|
||||
self.cameras = cameras
|
||||
self.raster_settings = raster_settings
|
||||
|
||||
def to(self, device):
|
||||
# Manually move to device cameras as it is not a subclass of nn.Module
|
||||
self.cameras = self.cameras.to(device)
|
||||
|
||||
def transform(self, meshes_world, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
|
@ -33,6 +33,11 @@ class MeshRenderer(nn.Module):
|
||||
self.rasterizer = rasterizer
|
||||
self.shader = shader
|
||||
|
||||
def to(self, device):
|
||||
# Rasterizer and shader have submodules which are not of type nn.Module
|
||||
self.rasterizer.to(device)
|
||||
self.shader.to(device)
|
||||
|
||||
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
Render a batch of images from a batch of meshes by rasterizing and then
|
||||
@ -44,6 +49,7 @@ class MeshRenderer(nn.Module):
|
||||
face f, clipping is required before interpolating the texture uv
|
||||
coordinates and z buffer so that the colors and depths are limited to
|
||||
the range for the corresponding face.
|
||||
For this set rasterizer.raster_settings.clip_barycentric_coords=True
|
||||
"""
|
||||
fragments = self.rasterizer(meshes_world, **kwargs)
|
||||
images = self.shader(fragments, meshes_world, **kwargs)
|
||||
|
@ -50,6 +50,12 @@ class HardPhongShader(nn.Module):
|
||||
self.cameras = cameras
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
def to(self, device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
self.cameras = self.cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
@ -98,6 +104,12 @@ class SoftPhongShader(nn.Module):
|
||||
self.cameras = cameras
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
def to(self, device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
self.cameras = self.cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
@ -151,6 +163,12 @@ class HardGouraudShader(nn.Module):
|
||||
self.cameras = cameras
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
def to(self, device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
self.cameras = self.cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
@ -203,6 +221,12 @@ class SoftGouraudShader(nn.Module):
|
||||
self.cameras = cameras
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
def to(self, device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
self.cameras = self.cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
@ -272,6 +296,12 @@ class HardFlatShader(nn.Module):
|
||||
self.cameras = cameras
|
||||
self.blend_params = blend_params if blend_params is not None else BlendParams()
|
||||
|
||||
def to(self, device):
|
||||
# Manually move to device modules which are not subclasses of nn.Module
|
||||
self.cameras = self.cameras.to(device)
|
||||
self.materials = self.materials.to(device)
|
||||
self.lights = self.lights.to(device)
|
||||
|
||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||
cameras = kwargs.get("cameras", self.cameras)
|
||||
if cameras is None:
|
||||
|
@ -1042,3 +1042,67 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
def test_to(self):
|
||||
# Test moving all the tensors in the renderer to a new device
|
||||
# to support multigpu rendering.
|
||||
device1 = torch.device("cpu")
|
||||
|
||||
R, T = look_at_view_transform(1500, 0.0, 0.0)
|
||||
|
||||
# Init shader settings
|
||||
materials = Materials(device=device1)
|
||||
lights = PointLights(device=device1)
|
||||
lights.location = torch.tensor([0.0, 0.0, +1000.0], device=device1)[None]
|
||||
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=256, blur_radius=0.0, faces_per_pixel=1
|
||||
)
|
||||
cameras = FoVPerspectiveCameras(
|
||||
device=device1, R=R, T=T, aspect_ratio=1.0, fov=60.0, zfar=100
|
||||
)
|
||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
|
||||
blend_params = BlendParams(
|
||||
1e-4,
|
||||
1e-4,
|
||||
background_color=torch.zeros(3, dtype=torch.float32, device=device1),
|
||||
)
|
||||
|
||||
shader = SoftPhongShader(
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||
|
||||
def _check_props_on_device(renderer, device):
|
||||
self.assertEqual(renderer.rasterizer.cameras.device, device)
|
||||
self.assertEqual(renderer.shader.cameras.device, device)
|
||||
self.assertEqual(renderer.shader.lights.device, device)
|
||||
self.assertEqual(renderer.shader.lights.ambient_color.device, device)
|
||||
self.assertEqual(renderer.shader.materials.device, device)
|
||||
self.assertEqual(renderer.shader.materials.ambient_color.device, device)
|
||||
|
||||
mesh = ico_sphere(2, device1)
|
||||
verts_padded = mesh.verts_padded()
|
||||
textures = TexturesVertex(
|
||||
verts_features=torch.ones_like(verts_padded, device=device1)
|
||||
)
|
||||
mesh.textures = textures
|
||||
_check_props_on_device(renderer, device1)
|
||||
|
||||
# Test rendering on cpu
|
||||
output_images = renderer(mesh)
|
||||
self.assertEqual(output_images.device, device1)
|
||||
|
||||
# Move renderer and mesh to another device and re render
|
||||
# This also tests that background_color is correctly moved to
|
||||
# the new device
|
||||
device2 = torch.device("cuda:0")
|
||||
renderer.to(device2)
|
||||
mesh = mesh.to(device2)
|
||||
_check_props_on_device(renderer, device2)
|
||||
output_images = renderer(mesh)
|
||||
self.assertEqual(output_images.device, device2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user