Support for moving the renderer to a new device

Summary:
Support for moving all the tensors of the renderer to another device by calling `renderer.to(new_device)`

Currently the `MeshRenderer`, `MeshRasterizer` and `SoftPhongShader` (and other shaders) are all of type `nn.Module` which already supports easily moving tensors of submodules (defined as class attributes) to a different device. However the class attributes of the rasterizer and shader (e.g. cameras, lights, materials), are of type `TensorProperties`, not nn.Module so we need to explicity create a `to` method to move these tensors to device. Note that the `TensorProperties` class already has a `to` method so we only need to call `cameras.to(device)` and don't need to worry about the internal tensors.

The other option is of course making these other classes (cameras, lights etc) also of type nn.Module.

Reviewed By: gkioxari

Differential Revision: D23885107

fbshipit-source-id: d71565c442181f739de4d797076ed5d00fb67f8e
This commit is contained in:
Nikhila Ravi
2020-09-23 17:09:50 -07:00
committed by Facebook GitHub Bot
parent b1eee579fd
commit 956d3a010c
5 changed files with 107 additions and 1 deletions

View File

@@ -46,7 +46,7 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)
if torch.is_tensor(blend_params.background_color):
background_color = blend_params.background_color
background_color = blend_params.background_color.to(device)
else:
background_color = colors.new_tensor(blend_params.background_color) # (3)
@@ -163,6 +163,8 @@ def softmax_rgb_blend(
background = blend_params.background_color
if not torch.is_tensor(background):
background = torch.tensor(background, dtype=torch.float32, device=device)
else:
background = background.to(device)
# Weight for background color
eps = 1e-10

View File

@@ -76,6 +76,10 @@ class MeshRasterizer(nn.Module):
self.cameras = cameras
self.raster_settings = raster_settings
def to(self, device):
# Manually move to device cameras as it is not a subclass of nn.Module
self.cameras = self.cameras.to(device)
def transform(self, meshes_world, **kwargs) -> torch.Tensor:
"""
Args:

View File

@@ -33,6 +33,11 @@ class MeshRenderer(nn.Module):
self.rasterizer = rasterizer
self.shader = shader
def to(self, device):
# Rasterizer and shader have submodules which are not of type nn.Module
self.rasterizer.to(device)
self.shader.to(device)
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
"""
Render a batch of images from a batch of meshes by rasterizing and then
@@ -44,6 +49,7 @@ class MeshRenderer(nn.Module):
face f, clipping is required before interpolating the texture uv
coordinates and z buffer so that the colors and depths are limited to
the range for the corresponding face.
For this set rasterizer.raster_settings.clip_barycentric_coords=True
"""
fragments = self.rasterizer(meshes_world, **kwargs)
images = self.shader(fragments, meshes_world, **kwargs)

View File

@@ -50,6 +50,12 @@ class HardPhongShader(nn.Module):
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device):
# Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
@@ -98,6 +104,12 @@ class SoftPhongShader(nn.Module):
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device):
# Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
@@ -151,6 +163,12 @@ class HardGouraudShader(nn.Module):
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device):
# Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
@@ -203,6 +221,12 @@ class SoftGouraudShader(nn.Module):
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device):
# Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None:
@@ -272,6 +296,12 @@ class HardFlatShader(nn.Module):
self.cameras = cameras
self.blend_params = blend_params if blend_params is not None else BlendParams()
def to(self, device):
# Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device)
self.lights = self.lights.to(device)
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
cameras = kwargs.get("cameras", self.cameras)
if cameras is None: