mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 23:30:35 +08:00
Tidy uses of torch.device in Pointclouds
Summary: Tidy uses of `torch.device` in `Pointclouds`: - Allow `str` or `torch.device` in `Pointclouds.to()` method - Consistently use `torch.device` for internal type - Fix comparison of devices Reviewed By: nikhilaravi Differential Revision: D28970221 fbshipit-source-id: 3ca7104d4c0d9b20b0cff4f00e3ce931c5f1528a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
633d66f1f0
commit
1db40ac566
@@ -2,6 +2,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
from ..common.types import Device, make_device
|
||||
from . import utils as struct_utils
|
||||
|
||||
|
||||
@@ -130,7 +131,7 @@ class Pointclouds:
|
||||
Refer to comments above for descriptions of List and Padded
|
||||
representations.
|
||||
"""
|
||||
self.device = None
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
# Indicates whether the clouds in the list/batch have the same number
|
||||
# of points.
|
||||
@@ -175,7 +176,6 @@ class Pointclouds:
|
||||
if isinstance(points, list):
|
||||
self._points_list = points
|
||||
self._N = len(self._points_list)
|
||||
self.device = torch.device("cpu")
|
||||
self.valid = torch.zeros((self._N,), dtype=torch.bool, device=self.device)
|
||||
self._num_points_per_cloud = []
|
||||
|
||||
@@ -700,7 +700,7 @@ class Pointclouds:
|
||||
setattr(other, k, v.detach())
|
||||
return other
|
||||
|
||||
def to(self, device, copy: bool = False):
|
||||
def to(self, device: Device, copy: bool = False):
|
||||
"""
|
||||
Match functionality of torch.Tensor.to()
|
||||
If copy = True or the self Tensor is on a different device, the
|
||||
@@ -709,34 +709,39 @@ class Pointclouds:
|
||||
then self is returned.
|
||||
|
||||
Args:
|
||||
device: Device id for the new tensor.
|
||||
device: Device (as str or torch.device) for the new tensor.
|
||||
copy: Boolean indicator whether or not to clone self. Default False.
|
||||
|
||||
Returns:
|
||||
Pointclouds object.
|
||||
"""
|
||||
if not copy and self.device == device:
|
||||
device_ = make_device(device)
|
||||
|
||||
if not copy and self.device == device_:
|
||||
return self
|
||||
|
||||
other = self.clone()
|
||||
if self.device != device:
|
||||
other.device = device
|
||||
if other._N > 0:
|
||||
other._points_list = [v.to(device) for v in other.points_list()]
|
||||
if other._normals_list is not None:
|
||||
other._normals_list = [n.to(device) for n in other.normals_list()]
|
||||
if other._features_list is not None:
|
||||
other._features_list = [f.to(device) for f in other.features_list()]
|
||||
for k in self._INTERNAL_TENSORS:
|
||||
v = getattr(self, k)
|
||||
if torch.is_tensor(v):
|
||||
setattr(other, k, v.to(device))
|
||||
if self.device == device_:
|
||||
return other
|
||||
|
||||
other.device = device_
|
||||
if other._N > 0:
|
||||
other._points_list = [v.to(device_) for v in other.points_list()]
|
||||
if other._normals_list is not None:
|
||||
other._normals_list = [n.to(device_) for n in other.normals_list()]
|
||||
if other._features_list is not None:
|
||||
other._features_list = [f.to(device_) for f in other.features_list()]
|
||||
for k in self._INTERNAL_TENSORS:
|
||||
v = getattr(self, k)
|
||||
if torch.is_tensor(v):
|
||||
setattr(other, k, v.to(device_))
|
||||
return other
|
||||
|
||||
def cpu(self):
|
||||
return self.to(torch.device("cpu"))
|
||||
return self.to("cpu")
|
||||
|
||||
def cuda(self):
|
||||
return self.to(torch.device("cuda"))
|
||||
return self.to("cuda")
|
||||
|
||||
def get_cloud(self, index: int):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user