# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import collections import dataclasses import time from contextlib import contextmanager from typing import Any, Callable, Dict import torch @contextmanager def evaluating(net: torch.nn.Module): """Temporarily switch to evaluation mode.""" istrain = net.training try: net.eval() yield net finally: if istrain: net.train() def try_to_cuda(t: Any) -> Any: """ Try to move the input variable `t` to a cuda device. Args: t: Input. Returns: t_cuda: `t` moved to a cuda device, if supported. """ try: t = t.cuda() except AttributeError: pass return t def try_to_cpu(t: Any) -> Any: """ Try to move the input variable `t` to a cpu device. Args: t: Input. Returns: t_cpu: `t` moved to a cpu device, if supported. """ try: t = t.cpu() except AttributeError: pass return t def dict_to_cuda(batch: Dict[Any, Any]) -> Dict[Any, Any]: """ Move all values in a dictionary to cuda if supported. Args: batch: Input dict. Returns: batch_cuda: `batch` moved to a cuda device, if supported. """ return {k: try_to_cuda(v) for k, v in batch.items()} def dict_to_cpu(batch): """ Move all values in a dictionary to cpu if supported. Args: batch: Input dict. Returns: batch_cpu: `batch` moved to a cpu device, if supported. """ return {k: try_to_cpu(v) for k, v in batch.items()} def dataclass_to_cuda_(obj): """ Move all contents of a dataclass to cuda inplace if supported. Args: batch: Input dataclass. Returns: batch_cuda: `batch` moved to a cuda device, if supported. """ for f in dataclasses.fields(obj): setattr(obj, f.name, try_to_cuda(getattr(obj, f.name))) return obj def dataclass_to_cpu_(obj): """ Move all contents of a dataclass to cpu inplace if supported. Args: batch: Input dataclass. Returns: batch_cuda: `batch` moved to a cpu device, if supported. """ for f in dataclasses.fields(obj): setattr(obj, f.name, try_to_cpu(getattr(obj, f.name))) return obj # TODO: test it def cat_dataclass(batch, tensor_collator: Callable): """ Concatenate all fields of a list of dataclasses `batch` to a single dataclass object using `tensor_collator`. Args: batch: Input list of dataclasses. Returns: concatenated_batch: All elements of `batch` concatenated to a single dataclass object. tensor_collator: The function used to concatenate tensor fields. """ elem = batch[0] collated = {} for f in dataclasses.fields(elem): elem_f = getattr(elem, f.name) if elem_f is None: collated[f.name] = None elif torch.is_tensor(elem_f): collated[f.name] = tensor_collator([getattr(e, f.name) for e in batch]) elif dataclasses.is_dataclass(elem_f): collated[f.name] = cat_dataclass( [getattr(e, f.name) for e in batch], tensor_collator ) elif isinstance(elem_f, collections.abc.Mapping): collated[f.name] = { k: tensor_collator([getattr(e, f.name)[k] for e in batch]) if elem_f[k] is not None else None for k in elem_f } else: raise ValueError("Unsupported field type for concatenation") return type(elem)(**collated) class Timer: """ A simple class for timing execution. Example: ``` with Timer(): print("This print statement is timed.") ``` """ def __init__(self, name="timer", quiet=False): self.name = name self.quiet = quiet def __enter__(self): self.start = time.time() return self def __exit__(self, *args): self.end = time.time() self.interval = self.end - self.start if not self.quiet: print("%20s: %1.6f sec" % (self.name, self.interval))