mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Return self in the to
method for the renderer classes
Summary: Add `return self` to the `to` function for the renderer classes. Reviewed By: bottler Differential Revision: D25534487 fbshipit-source-id: e8dbd35524f0bd40e835439e93184b5a1f1532ca
This commit is contained in:
parent
831e64efb0
commit
19340462e4
@ -79,6 +79,7 @@ class MeshRasterizer(nn.Module):
|
|||||||
def to(self, device):
|
def to(self, device):
|
||||||
# Manually move to device cameras as it is not a subclass of nn.Module
|
# Manually move to device cameras as it is not a subclass of nn.Module
|
||||||
self.cameras = self.cameras.to(device)
|
self.cameras = self.cameras.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
def transform(self, meshes_world, **kwargs) -> torch.Tensor:
|
def transform(self, meshes_world, **kwargs) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
|
@ -37,6 +37,7 @@ class MeshRenderer(nn.Module):
|
|||||||
# Rasterizer and shader have submodules which are not of type nn.Module
|
# Rasterizer and shader have submodules which are not of type nn.Module
|
||||||
self.rasterizer.to(device)
|
self.rasterizer.to(device)
|
||||||
self.shader.to(device)
|
self.shader.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
|
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
|
@ -55,6 +55,7 @@ class HardPhongShader(nn.Module):
|
|||||||
self.cameras = self.cameras.to(device)
|
self.cameras = self.cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
self.lights = self.lights.to(device)
|
self.lights = self.lights.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
@ -109,6 +110,7 @@ class SoftPhongShader(nn.Module):
|
|||||||
self.cameras = self.cameras.to(device)
|
self.cameras = self.cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
self.lights = self.lights.to(device)
|
self.lights = self.lights.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
@ -168,6 +170,7 @@ class HardGouraudShader(nn.Module):
|
|||||||
self.cameras = self.cameras.to(device)
|
self.cameras = self.cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
self.lights = self.lights.to(device)
|
self.lights = self.lights.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
@ -226,6 +229,7 @@ class SoftGouraudShader(nn.Module):
|
|||||||
self.cameras = self.cameras.to(device)
|
self.cameras = self.cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
self.lights = self.lights.to(device)
|
self.lights = self.lights.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
@ -301,6 +305,7 @@ class HardFlatShader(nn.Module):
|
|||||||
self.cameras = self.cameras.to(device)
|
self.cameras = self.cameras.to(device)
|
||||||
self.materials = self.materials.to(device)
|
self.materials = self.materials.to(device)
|
||||||
self.lights = self.lights.to(device)
|
self.lights = self.lights.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
|
||||||
cameras = kwargs.get("cameras", self.cameras)
|
cameras = kwargs.get("cameras", self.cameras)
|
||||||
|
@ -99,6 +99,11 @@ class PointsRasterizer(nn.Module):
|
|||||||
point_clouds = point_clouds.update_padded(pts_screen)
|
point_clouds = point_clouds.update_padded(pts_screen)
|
||||||
return point_clouds
|
return point_clouds
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
# Manually move to device cameras as it is not a subclass of nn.Module
|
||||||
|
self.cameras = self.cameras.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
def forward(self, point_clouds, **kwargs) -> PointFragments:
|
def forward(self, point_clouds, **kwargs) -> PointFragments:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
|
@ -32,6 +32,13 @@ class PointsRenderer(nn.Module):
|
|||||||
self.rasterizer = rasterizer
|
self.rasterizer = rasterizer
|
||||||
self.compositor = compositor
|
self.compositor = compositor
|
||||||
|
|
||||||
|
def to(self, device):
|
||||||
|
# Manually move to device rasterizer as the cameras
|
||||||
|
# within the class are not of type nn.Module
|
||||||
|
self.rasterizer = self.rasterizer.to(device)
|
||||||
|
self.compositor = self.compositor.to(device)
|
||||||
|
return self
|
||||||
|
|
||||||
def forward(self, point_clouds, **kwargs) -> torch.Tensor:
|
def forward(self, point_clouds, **kwargs) -> torch.Tensor:
|
||||||
fragments = self.rasterizer(point_clouds, **kwargs)
|
fragments = self.rasterizer(point_clouds, **kwargs)
|
||||||
|
|
||||||
|
@ -6,18 +6,22 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from common_testing import TestCaseMixin, get_random_cuda_device
|
from common_testing import TestCaseMixin, get_random_cuda_device
|
||||||
from pytorch3d.renderer import (
|
from pytorch3d.renderer import (
|
||||||
|
AlphaCompositor,
|
||||||
BlendParams,
|
BlendParams,
|
||||||
HardGouraudShader,
|
HardGouraudShader,
|
||||||
Materials,
|
Materials,
|
||||||
MeshRasterizer,
|
MeshRasterizer,
|
||||||
MeshRenderer,
|
MeshRenderer,
|
||||||
PointLights,
|
PointLights,
|
||||||
|
PointsRasterizationSettings,
|
||||||
|
PointsRasterizer,
|
||||||
|
PointsRenderer,
|
||||||
RasterizationSettings,
|
RasterizationSettings,
|
||||||
SoftPhongShader,
|
SoftPhongShader,
|
||||||
TexturesVertex,
|
TexturesVertex,
|
||||||
)
|
)
|
||||||
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
|
from pytorch3d.renderer.cameras import FoVPerspectiveCameras, look_at_view_transform
|
||||||
from pytorch3d.structures.meshes import Meshes
|
from pytorch3d.structures import Meshes, Pointclouds
|
||||||
from pytorch3d.utils.ico_sphere import ico_sphere
|
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||||
|
|
||||||
|
|
||||||
@ -27,7 +31,7 @@ GPU_LIST = list({get_random_cuda_device() for _ in range(NUM_GPUS)})
|
|||||||
print("GPUs: %s" % ", ".join(GPU_LIST))
|
print("GPUs: %s" % ", ".join(GPU_LIST))
|
||||||
|
|
||||||
|
|
||||||
class TestRenderMultiGPU(TestCaseMixin, unittest.TestCase):
|
class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
|
||||||
def _check_mesh_renderer_props_on_device(self, renderer, device):
|
def _check_mesh_renderer_props_on_device(self, renderer, device):
|
||||||
"""
|
"""
|
||||||
Helper function to check that all the properties of the mesh
|
Helper function to check that all the properties of the mesh
|
||||||
@ -99,7 +103,7 @@ class TestRenderMultiGPU(TestCaseMixin, unittest.TestCase):
|
|||||||
# This also tests that background_color is correctly moved to
|
# This also tests that background_color is correctly moved to
|
||||||
# the new device
|
# the new device
|
||||||
device2 = torch.device("cuda:0")
|
device2 = torch.device("cuda:0")
|
||||||
renderer.to(device2)
|
renderer = renderer.to(device2)
|
||||||
mesh = mesh.to(device2)
|
mesh = mesh.to(device2)
|
||||||
self._check_mesh_renderer_props_on_device(renderer, device2)
|
self._check_mesh_renderer_props_on_device(renderer, device2)
|
||||||
output_images = renderer(mesh)
|
output_images = renderer(mesh)
|
||||||
@ -137,7 +141,7 @@ class TestRenderMultiGPU(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def forward(self, verts, texs):
|
def forward(self, verts, texs):
|
||||||
batch_size = verts.size(0)
|
batch_size = verts.size(0)
|
||||||
self.renderer.to(verts.device)
|
self.renderer = self.renderer.to(verts.device)
|
||||||
tex = TexturesVertex(verts_features=texs)
|
tex = TexturesVertex(verts_features=texs)
|
||||||
faces = self.faces.expand(batch_size, -1, -1).to(verts.device)
|
faces = self.faces.expand(batch_size, -1, -1).to(verts.device)
|
||||||
mesh = Meshes(verts, faces, tex).to(verts.device)
|
mesh = Meshes(verts, faces, tex).to(verts.device)
|
||||||
@ -157,3 +161,53 @@ class TestRenderMultiGPU(TestCaseMixin, unittest.TestCase):
|
|||||||
# Test a few iterations
|
# Test a few iterations
|
||||||
for _ in range(100):
|
for _ in range(100):
|
||||||
model(verts, texs)
|
model(verts, texs)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRenderPointssMultiGPU(TestCaseMixin, unittest.TestCase):
|
||||||
|
def _check_points_renderer_props_on_device(self, renderer, device):
|
||||||
|
"""
|
||||||
|
Helper function to check that all the properties have
|
||||||
|
been moved to the correct device.
|
||||||
|
"""
|
||||||
|
# Cameras
|
||||||
|
self.assertEqual(renderer.rasterizer.cameras.device, device)
|
||||||
|
self.assertEqual(renderer.rasterizer.cameras.R.device, device)
|
||||||
|
self.assertEqual(renderer.rasterizer.cameras.T.device, device)
|
||||||
|
|
||||||
|
def test_points_renderer_to(self):
|
||||||
|
"""
|
||||||
|
Test moving all the tensors in the points renderer to a new device.
|
||||||
|
"""
|
||||||
|
|
||||||
|
device1 = torch.device("cpu")
|
||||||
|
|
||||||
|
R, T = look_at_view_transform(1500, 0.0, 0.0)
|
||||||
|
|
||||||
|
raster_settings = PointsRasterizationSettings(
|
||||||
|
image_size=256, radius=0.001, points_per_pixel=1
|
||||||
|
)
|
||||||
|
cameras = FoVPerspectiveCameras(
|
||||||
|
device=device1, R=R, T=T, aspect_ratio=1.0, fov=60.0, zfar=100
|
||||||
|
)
|
||||||
|
rasterizer = PointsRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||||
|
|
||||||
|
renderer = PointsRenderer(rasterizer=rasterizer, compositor=AlphaCompositor())
|
||||||
|
|
||||||
|
mesh = ico_sphere(2, device1)
|
||||||
|
verts_padded = mesh.verts_padded()
|
||||||
|
pointclouds = Pointclouds(
|
||||||
|
points=verts_padded, features=torch.randn_like(verts_padded)
|
||||||
|
)
|
||||||
|
self._check_points_renderer_props_on_device(renderer, device1)
|
||||||
|
|
||||||
|
# Test rendering on cpu
|
||||||
|
output_images = renderer(pointclouds)
|
||||||
|
self.assertEqual(output_images.device, device1)
|
||||||
|
|
||||||
|
# Move renderer and pointclouds to another device and re render
|
||||||
|
device2 = torch.device("cuda:0")
|
||||||
|
renderer = renderer.to(device2)
|
||||||
|
pointclouds = pointclouds.to(device2)
|
||||||
|
self._check_points_renderer_props_on_device(renderer, device2)
|
||||||
|
output_images = renderer(pointclouds)
|
||||||
|
self.assertEqual(output_images.device, device2)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user