Address black + isort fbsource linter warnings

Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff)

Reviewed By: nikhilaravi

Differential Revision: D20558373

fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
This commit is contained in:
Patrick Labatut
2020-03-29 14:46:33 -07:00
committed by Facebook GitHub Bot
parent eb512ffde3
commit d57daa6f85
110 changed files with 705 additions and 1850 deletions

View File

@@ -1,9 +1,10 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import numpy as np
import warnings
from typing import Any, Union
import numpy as np
import torch
@@ -45,10 +46,7 @@ class TensorAccessor(object):
# Convert the attribute to a tensor if it is not a tensor.
if not torch.is_tensor(value):
value = torch.tensor(
value,
device=v.device,
dtype=v.dtype,
requires_grad=v.requires_grad,
value, device=v.device, dtype=v.dtype, requires_grad=v.requires_grad
)
# Check the shapes match the existing shape and the shape of the index.
@@ -253,9 +251,7 @@ class TensorProperties(object):
return self
def format_tensor(
input, dtype=torch.float32, device: str = "cpu"
) -> torch.Tensor:
def format_tensor(input, dtype=torch.float32, device: str = "cpu") -> torch.Tensor:
"""
Helper function for converting a scalar value to a tensor.
@@ -276,9 +272,7 @@ def format_tensor(
return input
def convert_to_tensors_and_broadcast(
*args, dtype=torch.float32, device: str = "cpu"
):
def convert_to_tensors_and_broadcast(*args, dtype=torch.float32, device: str = "cpu"):
"""
Helper function to handle parsing an arbitrary number of inputs (*args)
which all need to have the same batch dimension.