mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-16 01:15:59 +08:00
Adding utility methods to TensorProperties
Summary:
Context: in the code we are releasing with CO3D dataset, we use `cuda()` on TensorProperties like Pointclouds and Cameras where we recursively move batch to a GPU. It would be good to push it to a release so we don’t need to depend on the nightly build.
Additionally, I aligned the logic of `.to("cuda")` without device index to the one of `torch.Tensor` where the current device is populated to index. It should not affect any actual use cases but some tests had to be changed.
Reviewed By: bottler
Differential Revision: D29659529
fbshipit-source-id: abe58aeaca14bacc68da3e6cf5ae07df3353e3ce
This commit is contained in:
committed by
Facebook GitHub Bot
parent
fa44a05567
commit
0c02ae907e
@@ -8,7 +8,7 @@
|
||||
import copy
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -174,6 +174,12 @@ class TensorProperties(nn.Module):
|
||||
setattr(self, k, v.to(device_))
|
||||
return self
|
||||
|
||||
def cpu(self) -> "TensorProperties":
|
||||
return self.to("cpu")
|
||||
|
||||
def cuda(self, device: Optional[int] = None) -> "TensorProperties":
|
||||
return self.to(f"cuda:{device}" if device is not None else "cuda")
|
||||
|
||||
def clone(self, other) -> "TensorProperties":
|
||||
"""
|
||||
Update the tensor properties of other with the cloned properties of self.
|
||||
|
||||
Reference in New Issue
Block a user