From 48faf8eb7ed1293f05156214bcfbb2dbbfa5c431 Mon Sep 17 00:00:00 2001 From: Patrick Labatut Date: Wed, 9 Jun 2021 15:48:56 -0700 Subject: [PATCH] Introduce device type and utility functions Summary: Introduce device type and utility functions in common types module Reviewed By: nikhilaravi Differential Revision: D28970930 fbshipit-source-id: 191ec07390ed66a958c23eb2b43229312492e0b7 --- pytorch3d/common/types.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 pytorch3d/common/types.py diff --git a/pytorch3d/common/types.py b/pytorch3d/common/types.py new file mode 100644 index 00000000..4a3bf10d --- /dev/null +++ b/pytorch3d/common/types.py @@ -0,0 +1,25 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from typing import Optional, Union + +import torch + + +Device = Union[str, torch.device] + + +def make_device(device: Device) -> torch.device: + return torch.device(device) if isinstance(device, str) else device + + +def get_device(x, device: Optional[Device] = None) -> torch.device: + # User overrides device + if device is not None: + return make_device(device) + + # Set device based on input tensor + if torch.is_tensor(x): + return x.device + + # Default device is cpu + return torch.device("cpu")