fix recent lint

Summary: lint clean again

Reviewed By: patricklabatut

Differential Revision: D20868775

fbshipit-source-id: ade4301c1012c5c6943186432465215701d635a9
This commit is contained in:
Jeremy Reizenstein
2020-04-06 06:38:50 -07:00
committed by Facebook GitHub Bot
parent 90dc7a0856
commit b87058c62a
6 changed files with 29 additions and 60 deletions

View File

@@ -1,12 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import warnings
from typing import List, Optional, Tuple, Union
import torch
from typing import List, Tuple, Union
from pytorch3d.structures.pointclouds import Pointclouds
from pytorch3d.structures import utils as strutil
import torch
from pytorch3d.ops import utils as oputil
from pytorch3d.structures import utils as strutil
from pytorch3d.structures.pointclouds import Pointclouds
def corresponding_points_alignment(
@@ -77,9 +77,7 @@ def corresponding_points_alignment(
weights = strutil.list_to_padded(weights)[..., 0]
if Xt.shape[:2] != weights.shape:
raise ValueError(
"weights should have the same first two dimensions as X."
)
raise ValueError("weights should have the same first two dimensions as X.")
b, n, dim = Xt.shape
@@ -120,9 +118,7 @@ def corresponding_points_alignment(
U, S, V = torch.svd(XYcov)
# identity matrix used for fixing reflections
E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(
b, 1, 1
)
E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1)
if not allow_reflection:
# reflection test:

View File

@@ -27,7 +27,7 @@ def wmean(
* if `weights` is None => `mean(x, dim)`,
* otherwise => `sum(x*w, dim) / max{sum(w, dim), eps}`.
"""
args = dict(dim=dim, keepdim=keepdim)
args = {"dim": dim, "keepdim": keepdim}
if weight is None:
return x.mean(**args)
@@ -38,7 +38,6 @@ def wmean(
):
raise ValueError("wmean: weights are not compatible with the tensor")
return (
(x * weight[..., None]).sum(**args)
/ weight[..., None].sum(**args).clamp(eps)
return (x * weight[..., None]).sum(**args) / weight[..., None].sum(**args).clamp(
eps
)