mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 04:12:48 +08:00
Introduce device type and utility functions
Summary: Introduce device type and utility functions in common types module Reviewed By: nikhilaravi Differential Revision: D28970930 fbshipit-source-id: 191ec07390ed66a958c23eb2b43229312492e0b7
This commit is contained in:
parent
07da36d4c8
commit
48faf8eb7e
25
pytorch3d/common/types.py
Normal file
25
pytorch3d/common/types.py
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
Device = Union[str, torch.device]
|
||||||
|
|
||||||
|
|
||||||
|
def make_device(device: Device) -> torch.device:
|
||||||
|
return torch.device(device) if isinstance(device, str) else device
|
||||||
|
|
||||||
|
|
||||||
|
def get_device(x, device: Optional[Device] = None) -> torch.device:
|
||||||
|
# User overrides device
|
||||||
|
if device is not None:
|
||||||
|
return make_device(device)
|
||||||
|
|
||||||
|
# Set device based on input tensor
|
||||||
|
if torch.is_tensor(x):
|
||||||
|
return x.device
|
||||||
|
|
||||||
|
# Default device is cpu
|
||||||
|
return torch.device("cpu")
|
Loading…
x
Reference in New Issue
Block a user