mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-23 15:50:39 +08:00
Return self in the to method for the renderer classes
Summary: Add `return self` to the `to` function for the renderer classes. Reviewed By: bottler Differential Revision: D25534487 fbshipit-source-id: e8dbd35524f0bd40e835439e93184b5a1f1532ca
This commit is contained in:
committed by
Facebook GitHub Bot
parent
831e64efb0
commit
19340462e4
@@ -99,6 +99,11 @@ class PointsRasterizer(nn.Module):
|
||||
point_clouds = point_clouds.update_padded(pts_screen)
|
||||
return point_clouds
|
||||
|
||||
def to(self, device):
|
||||
# Manually move to device cameras as it is not a subclass of nn.Module
|
||||
self.cameras = self.cameras.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, point_clouds, **kwargs) -> PointFragments:
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -32,6 +32,13 @@ class PointsRenderer(nn.Module):
|
||||
self.rasterizer = rasterizer
|
||||
self.compositor = compositor
|
||||
|
||||
def to(self, device):
|
||||
# Manually move to device rasterizer as the cameras
|
||||
# within the class are not of type nn.Module
|
||||
self.rasterizer = self.rasterizer.to(device)
|
||||
self.compositor = self.compositor.to(device)
|
||||
return self
|
||||
|
||||
def forward(self, point_clouds, **kwargs) -> torch.Tensor:
|
||||
fragments = self.rasterizer(point_clouds, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user