diff --git a/docs/modules/common.rst b/docs/modules/common.rst new file mode 100644 index 00000000..7ca68690 --- /dev/null +++ b/docs/modules/common.rst @@ -0,0 +1,6 @@ +pytorch3d.common +=========================== + +.. automodule:: pytorch3d.common + :members: + :undoc-members: diff --git a/docs/modules/index.rst b/docs/modules/index.rst index 88ddcb6d..4c95ef18 100644 --- a/docs/modules/index.rst +++ b/docs/modules/index.rst @@ -3,6 +3,7 @@ API Documentation .. toctree:: + common structures io loss diff --git a/pytorch3d/common/__init__.py b/pytorch3d/common/__init__.py index 10a55772..f34c1017 100644 --- a/pytorch3d/common/__init__.py +++ b/pytorch3d/common/__init__.py @@ -3,3 +3,7 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + +from .types import Device, make_device, get_device + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/pytorch3d/common/types.py b/pytorch3d/common/types.py index 622c24e5..ab15a184 100644 --- a/pytorch3d/common/types.py +++ b/pytorch3d/common/types.py @@ -13,10 +13,33 @@ Device = Union[str, torch.device] def make_device(device: Device) -> torch.device: + """ + Makes an actual torch.device object from the device specified as + either a string or torch.device object. + + Args: + device: Device (as str or torch.device) + + Returns: + A matching torch.device object + """ return torch.device(device) if isinstance(device, str) else device def get_device(x, device: Optional[Device] = None) -> torch.device: + """ + Gets the device of the specified variable x if it is a tensor, or + falls back to a default CPU device otherwise. Allows overriding by + providing an explicit device. + + Args: + x: a torch.Tensor to get the device from or another type + device: Device (as str or torch.device) to fall back to + + Returns: + A matching torch.device object + """ + # User overrides device if device is not None: return make_device(device)