mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Introduce device type and utility functions
Summary: Introduce device type and utility functions in common types module Reviewed By: nikhilaravi Differential Revision: D28970930 fbshipit-source-id: 191ec07390ed66a958c23eb2b43229312492e0b7
This commit is contained in:
parent
07da36d4c8
commit
48faf8eb7e
25
pytorch3d/common/types.py
Normal file
25
pytorch3d/common/types.py
Normal file
@ -0,0 +1,25 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
Device = Union[str, torch.device]
|
||||
|
||||
|
||||
def make_device(device: Device) -> torch.device:
|
||||
return torch.device(device) if isinstance(device, str) else device
|
||||
|
||||
|
||||
def get_device(x, device: Optional[Device] = None) -> torch.device:
|
||||
# User overrides device
|
||||
if device is not None:
|
||||
return make_device(device)
|
||||
|
||||
# Set device based on input tensor
|
||||
if torch.is_tensor(x):
|
||||
return x.device
|
||||
|
||||
# Default device is cpu
|
||||
return torch.device("cpu")
|
Loading…
x
Reference in New Issue
Block a user