mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Summary: Quote formats, spelling Reviewed By: shapovalov Differential Revision: D40913734 fbshipit-source-id: d6dea65d5204b3c463c656a07ef9b447b7be6a0a
204 lines
5.0 KiB
Python
204 lines
5.0 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
|
|
import collections
|
|
import dataclasses
|
|
import time
|
|
from contextlib import contextmanager
|
|
from typing import Any, Callable, Dict, Iterable, Iterator
|
|
|
|
import torch
|
|
|
|
|
|
@contextmanager
def evaluating(net: torch.nn.Module):
    """
    Context manager that runs its body with `net` in evaluation mode.

    On exit (including when the body raises) the module is put back into
    training mode, but only if it was in training mode on entry.

    Args:
        net: The module to switch to evaluation mode.

    Yields:
        net: The same module, now in evaluation mode.
    """
    # Remember the mode before touching the module so it can be restored.
    was_training = net.training
    try:
        net.eval()
        yield net
    finally:
        if was_training:
            net.train()
|
|
|
|
|
|
def try_to_cuda(t: Any) -> Any:
    """
    Call `t.cuda()` and return the result; return `t` itself when that
    raises `AttributeError` (e.g. `t` has no `cuda` method).

    Args:
        t: Input.

    Returns:
        t_cuda: The result of `t.cuda()`, or `t` unchanged if an
            `AttributeError` was raised.
    """
    try:
        return t.cuda()
    except AttributeError:
        # Best effort: objects without device support pass through as-is.
        return t
|
|
|
|
|
|
def try_to_cpu(t: Any) -> Any:
    """
    Call `t.cpu()` and return the result; return `t` itself when that
    raises `AttributeError` (e.g. `t` has no `cpu` method).

    Args:
        t: Input.

    Returns:
        t_cpu: The result of `t.cpu()`, or `t` unchanged if an
            `AttributeError` was raised.
    """
    try:
        return t.cpu()
    except AttributeError:
        # Best effort: objects without device support pass through as-is.
        return t
|
|
|
|
|
|
def dict_to_cuda(batch: Dict[Any, Any]) -> Dict[Any, Any]:
    """
    Build a new dictionary whose values are the values of `batch` passed
    through `try_to_cuda`. Keys are kept as they are and `batch` itself is
    not modified.

    Args:
        batch: Input dict.

    Returns:
        batch_cuda: A new dict with each value moved to a cuda device,
            if supported.
    """
    moved = {}
    for key, value in batch.items():
        moved[key] = try_to_cuda(value)
    return moved
|
|
|
|
|
|
def dict_to_cpu(batch: Dict[Any, Any]) -> Dict[Any, Any]:
    """
    Build a new dictionary whose values are the values of `batch` passed
    through `try_to_cpu`. Keys are kept as they are and `batch` itself is
    not modified.

    Args:
        batch: Input dict.

    Returns:
        batch_cpu: A new dict with each value moved to a cpu device,
            if supported.
    """
    return {k: try_to_cpu(v) for k, v in batch.items()}
|
|
|
|
|
|
def dataclass_to_cuda_(obj):
    """
    Move all contents of a dataclass to cuda inplace if supported.

    Every field reported by `dataclasses.fields(obj)` is read, passed through
    `try_to_cuda`, and written back onto `obj` with `setattr`.

    Args:
        obj: Input dataclass. It is modified in place.

    Returns:
        obj: The same object that was passed in, with its fields moved to
            a cuda device, if supported.
    """
    for field in dataclasses.fields(obj):
        current = getattr(obj, field.name)
        setattr(obj, field.name, try_to_cuda(current))
    return obj
|
|
|
|
|
|
def dataclass_to_cpu_(obj):
    """
    Move all contents of a dataclass to cpu inplace if supported.

    Every field reported by `dataclasses.fields(obj)` is read, passed through
    `try_to_cpu`, and written back onto `obj` with `setattr`.

    Args:
        obj: Input dataclass. It is modified in place.

    Returns:
        obj: The same object that was passed in, with its fields moved to
            a cpu device, if supported.
    """
    for field in dataclasses.fields(obj):
        current = getattr(obj, field.name)
        setattr(obj, field.name, try_to_cpu(current))
    return obj
|
|
|
|
|
|
# TODO: test it
|
|
def cat_dataclass(batch, tensor_collator: Callable):
    """
    Concatenate all fields of a list of dataclasses `batch` to a single
    dataclass object using `tensor_collator`.

    The first element of `batch` decides how each field is handled:

    - `None` stays `None`;
    - a tensor field is collated by calling `tensor_collator` on the list of
      that field's values across `batch`;
    - a dataclass field is collated by a recursive call to `cat_dataclass`;
    - a mapping field becomes a plain `dict` with the keys of the first
      element's mapping; each value is `None` if the first element's value
      is `None`, otherwise the result of `tensor_collator` on that key's
      values across `batch`.

    Args:
        batch: Input list of dataclasses.
        tensor_collator: The function used to concatenate tensor fields.
            It is called with a list of values, one per element of `batch`.

    Returns:
        concatenated_batch: All elements of `batch` concatenated to a single
            dataclass object of the same type as `batch[0]`.

    Raises:
        ValueError: If a field of the first element is not `None`, a tensor,
            a dataclass or a mapping.
    """
    template = batch[0]
    merged = {}

    for field in dataclasses.fields(template):
        name = field.name
        sample = getattr(template, name)

        if sample is None:
            merged[name] = None
            continue

        if torch.is_tensor(sample):
            merged[name] = tensor_collator([getattr(item, name) for item in batch])
            continue

        if dataclasses.is_dataclass(sample):
            nested = [getattr(item, name) for item in batch]
            merged[name] = cat_dataclass(nested, tensor_collator)
            continue

        if not isinstance(sample, collections.abc.Mapping):
            raise ValueError("Unsupported field type for concatenation")

        # Mapping field: collate key by key, following the first element's keys.
        per_key = {}
        for key in sample:
            if sample[key] is None:
                per_key[key] = None
            else:
                per_key[key] = tensor_collator(
                    [getattr(item, name)[key] for item in batch]
                )
        merged[name] = per_key

    return type(template)(**merged)
|
|
|
|
|
|
def recursive_visitor(it: Iterable[Any]) -> Iterator[Any]:
    """
    Lazily flatten an arbitrarily nested iterable.

    Elements that are themselves iterable are descended into recursively;
    `str` and `bytes` objects are treated as leaves and yielded whole.

    Args:
        it: The (possibly nested) iterable to traverse.

    Yields:
        The leaf elements of `it`, in traversal order.
    """
    for item in it:
        # Strings and bytes are iterable but must not be split into characters.
        is_leaf = isinstance(item, (str, bytes)) or not isinstance(item, Iterable)
        if is_leaf:
            yield item
        else:
            yield from recursive_visitor(item)
|
|
|
|
|
|
def get_inlier_indicators(
    tensor: torch.Tensor, dim: int, outlier_rate: float
) -> torch.Tensor:
    """
    Compute a boolean mask over dimension `dim` of `tensor` that marks the
    positions which are not among the extreme values.

    With `n = tensor.shape[dim]`, the number of elements removed from each end
    is `k = int(min(outlier_rate, 1.0) * n / 2)`. A position along `dim` is
    marked `False` if it is among the `k` largest or the `k` smallest values
    along `dim` in any slice of `tensor`; all other positions are `True`.

    Args:
        tensor: The input values.
        dim: The dimension along which outliers are searched. Negative
            values index from the last dimension.
        outlier_rate: Fraction of positions to treat as outliers; values
            above 1.0 are clamped to 1.0.

    Returns:
        keep_indices: A 1D `torch.bool` tensor of length `tensor.shape[dim]`.
    """
    dim_size = tensor.shape[dim]
    remove_elements = int(min(outlier_rate, 1.0) * dim_size / 2)
    hi = torch.topk(tensor, remove_elements, dim=dim).indices
    lo = torch.topk(-tensor, remove_elements, dim=dim).indices
    # Size the mask from `shape[dim]` rather than the slice
    # `shape[dim : dim + 1]`, which is wrong or empty for a negative `dim`.
    keep_indices = tensor.new_ones((dim_size,), dtype=torch.bool)
    # The indices found in every slice are pooled; duplicates are harmless
    # because each write stores the same value.
    remove_indices = torch.cat([hi.reshape(-1), lo.reshape(-1)])
    keep_indices[remove_indices] = False
    return keep_indices
|
|
|
|
|
|
class Timer:
    """
    A simple class for timing execution.

    Used as a context manager: the wall-clock time from `time.time()` is
    recorded on entry and on exit, and unless `quiet` is set the elapsed
    time is printed on exit.

    Example::

        with Timer():
            print("This print statement is timed.")

    Attributes set while the timer is used:
        start: `time.time()` at entry of the `with` block.
        end: `time.time()` at exit of the `with` block.
        interval: `end - start`, in seconds.
    """

    def __init__(self, name="timer", quiet=False):
        """
        Args:
            name: Label printed in front of the elapsed time.
            quiet: If true, nothing is printed on exit.
        """
        self.name = name
        self.quiet = quiet

    def __enter__(self):
        """Record the start time and return the timer itself."""
        self.start = time.time()
        return self

    def __exit__(self, *args):
        """Record the end time and the elapsed interval; print it unless quiet."""
        finished = time.time()
        self.end = finished
        self.interval = finished - self.start
        if self.quiet:
            return
        print("%20s: %1.6f sec" % (self.name, self.interval))
|