mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Adding utility methods to TensorProperties
Summary: Context: in the code we are releasing with CO3D dataset, we use `cuda()` on TensorProperties like Pointclouds and Cameras where we recursively move batch to a GPU. It would be good to push it to a release so we don’t need to depend on the nightly build. Additionally, I aligned the logic of `.to("cuda")` without device index to the one of `torch.Tensor` where the current device is populated to index. It should not affect any actual use cases but some tests had to be changed. Reviewed By: bottler Differential Revision: D29659529 fbshipit-source-id: abe58aeaca14bacc68da3e6cf5ae07df3353e3ce
This commit is contained in:
parent
fa44a05567
commit
0c02ae907e
@ -15,7 +15,8 @@ Device = Union[str, torch.device]
|
||||
def make_device(device: Device) -> torch.device:
|
||||
"""
|
||||
Makes an actual torch.device object from the device specified as
|
||||
either a string or torch.device object.
|
||||
either a string or torch.device object. If the device is `cuda` without
|
||||
a specific index, the index of the current device is assigned.
|
||||
|
||||
Args:
|
||||
device: Device (as str or torch.device)
|
||||
@ -23,7 +24,12 @@ def make_device(device: Device) -> torch.device:
|
||||
Returns:
|
||||
A matching torch.device object
|
||||
"""
|
||||
return torch.device(device) if isinstance(device, str) else device
|
||||
device = torch.device(device) if isinstance(device, str) else device
|
||||
if device.type == "cuda" and device.index is None: # pyre-ignore[16]
|
||||
# If cuda but with no index, then the current cuda device is indicated.
|
||||
# In that case, we fix to that device
|
||||
device = torch.device(f"cuda:{torch.cuda.current_device()}")
|
||||
return device
|
||||
|
||||
|
||||
def get_device(x, device: Optional[Device] = None) -> torch.device:
|
||||
|
@ -8,7 +8,7 @@
|
||||
import copy
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -174,6 +174,12 @@ class TensorProperties(nn.Module):
|
||||
setattr(self, k, v.to(device_))
|
||||
return self
|
||||
|
||||
def cpu(self) -> "TensorProperties":
|
||||
return self.to("cpu")
|
||||
|
||||
def cuda(self, device: Optional[int] = None) -> "TensorProperties":
|
||||
return self.to(f"cuda:{device}" if device is not None else "cuda")
|
||||
|
||||
def clone(self, other) -> "TensorProperties":
|
||||
"""
|
||||
Update the tensor properties of other with the cloned properties of self.
|
||||
|
@ -709,9 +709,9 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(cpu_device, mesh.device)
|
||||
self.assertIs(mesh, converted_mesh)
|
||||
|
||||
cuda_device = torch.device("cuda")
|
||||
cuda_device = torch.device("cuda:0")
|
||||
|
||||
converted_mesh = mesh.to("cuda")
|
||||
converted_mesh = mesh.to("cuda:0")
|
||||
self.assertEqual(cuda_device, converted_mesh.device)
|
||||
self.assertEqual(cpu_device, mesh.device)
|
||||
self.assertIsNot(mesh, converted_mesh)
|
||||
|
@ -39,7 +39,17 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase):
|
||||
example = TensorPropertiesTestClass(x=10.0, y=(100.0, 200.0))
|
||||
device = torch.device("cuda:0")
|
||||
new_example = example.to(device=device)
|
||||
self.assertTrue(new_example.device == device)
|
||||
self.assertEqual(new_example.device, device)
|
||||
|
||||
example_cpu = example.cpu()
|
||||
self.assertEqual(example_cpu.device, torch.device("cpu"))
|
||||
|
||||
example_gpu = example.cuda()
|
||||
self.assertEqual(example_gpu.device.type, "cuda")
|
||||
self.assertIsNotNone(example_gpu.device.index)
|
||||
|
||||
example_gpu1 = example.cuda(1)
|
||||
self.assertEqual(example_gpu1.device, torch.device("cuda:1"))
|
||||
|
||||
def test_clone(self):
|
||||
# Check clone method
|
||||
|
@ -22,7 +22,7 @@ from pytorch3d.structures.meshes import Meshes
|
||||
class TestShader(TestCaseMixin, unittest.TestCase):
|
||||
def test_to(self):
|
||||
cpu_device = torch.device("cpu")
|
||||
cuda_device = torch.device("cuda")
|
||||
cuda_device = torch.device("cuda:0")
|
||||
|
||||
R, T = look_at_view_transform()
|
||||
|
||||
|
@ -50,9 +50,9 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
self.assertIsNot(t, cpu_t)
|
||||
|
||||
cuda_device = torch.device("cuda")
|
||||
cuda_device = torch.device("cuda:0")
|
||||
|
||||
cuda_t = t.to("cuda")
|
||||
cuda_t = t.to("cuda:0")
|
||||
self.assertEqual(cuda_device, cuda_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertEqual(torch.float32, cuda_t.dtype)
|
||||
|
Loading…
x
Reference in New Issue
Block a user