mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
fix recent lint
Summary: lint clean again Reviewed By: patricklabatut Differential Revision: D20868775 fbshipit-source-id: ade4301c1012c5c6943186432465215701d635a9
This commit is contained in:
parent
90dc7a0856
commit
b87058c62a
@ -1,12 +1,12 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
import torch
|
|
||||||
|
|
||||||
from pytorch3d.structures.pointclouds import Pointclouds
|
import torch
|
||||||
from pytorch3d.structures import utils as strutil
|
|
||||||
from pytorch3d.ops import utils as oputil
|
from pytorch3d.ops import utils as oputil
|
||||||
|
from pytorch3d.structures import utils as strutil
|
||||||
|
from pytorch3d.structures.pointclouds import Pointclouds
|
||||||
|
|
||||||
|
|
||||||
def corresponding_points_alignment(
|
def corresponding_points_alignment(
|
||||||
@ -77,9 +77,7 @@ def corresponding_points_alignment(
|
|||||||
weights = strutil.list_to_padded(weights)[..., 0]
|
weights = strutil.list_to_padded(weights)[..., 0]
|
||||||
|
|
||||||
if Xt.shape[:2] != weights.shape:
|
if Xt.shape[:2] != weights.shape:
|
||||||
raise ValueError(
|
raise ValueError("weights should have the same first two dimensions as X.")
|
||||||
"weights should have the same first two dimensions as X."
|
|
||||||
)
|
|
||||||
|
|
||||||
b, n, dim = Xt.shape
|
b, n, dim = Xt.shape
|
||||||
|
|
||||||
@ -120,9 +118,7 @@ def corresponding_points_alignment(
|
|||||||
U, S, V = torch.svd(XYcov)
|
U, S, V = torch.svd(XYcov)
|
||||||
|
|
||||||
# identity matrix used for fixing reflections
|
# identity matrix used for fixing reflections
|
||||||
E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(
|
E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1)
|
||||||
b, 1, 1
|
|
||||||
)
|
|
||||||
|
|
||||||
if not allow_reflection:
|
if not allow_reflection:
|
||||||
# reflection test:
|
# reflection test:
|
||||||
|
@ -27,7 +27,7 @@ def wmean(
|
|||||||
* if `weights` is None => `mean(x, dim)`,
|
* if `weights` is None => `mean(x, dim)`,
|
||||||
* otherwise => `sum(x*w, dim) / max{sum(w, dim), eps}`.
|
* otherwise => `sum(x*w, dim) / max{sum(w, dim), eps}`.
|
||||||
"""
|
"""
|
||||||
args = dict(dim=dim, keepdim=keepdim)
|
args = {"dim": dim, "keepdim": keepdim}
|
||||||
|
|
||||||
if weight is None:
|
if weight is None:
|
||||||
return x.mean(**args)
|
return x.mean(**args)
|
||||||
@ -38,7 +38,6 @@ def wmean(
|
|||||||
):
|
):
|
||||||
raise ValueError("wmean: weights are not compatible with the tensor")
|
raise ValueError("wmean: weights are not compatible with the tensor")
|
||||||
|
|
||||||
return (
|
return (x * weight[..., None]).sum(**args) / weight[..., None].sum(**args).clamp(
|
||||||
(x * weight[..., None]).sum(**args)
|
eps
|
||||||
/ weight[..., None].sum(**args).clamp(eps)
|
|
||||||
)
|
)
|
||||||
|
@ -3,8 +3,8 @@
|
|||||||
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from fvcore.common.benchmark import benchmark
|
|
||||||
|
|
||||||
|
from fvcore.common.benchmark import benchmark
|
||||||
from test_points_alignment import TestCorrespondingPointsAlignment
|
from test_points_alignment import TestCorrespondingPointsAlignment
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,8 +1,7 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -57,7 +56,5 @@ class TestCaseMixin(unittest.TestCase):
|
|||||||
input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
|
input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
close = np.allclose(
|
close = np.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
|
||||||
input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
|
|
||||||
)
|
|
||||||
self.assertTrue(close, msg)
|
self.assertTrue(close, msg)
|
||||||
|
@ -3,11 +3,10 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
|
|
||||||
from pytorch3d.ops import utils as oputil
|
from pytorch3d.ops import utils as oputil
|
||||||
|
|
||||||
|
|
||||||
class TestOpsUtils(TestCaseMixin, unittest.TestCase):
|
class TestOpsUtils(TestCaseMixin, unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
super().setUp()
|
super().setUp()
|
||||||
@ -62,8 +61,7 @@ class TestOpsUtils(TestCaseMixin, unittest.TestCase):
|
|||||||
# test dim
|
# test dim
|
||||||
weight = torch.rand(x.shape[0], n_points, device=device)
|
weight = torch.rand(x.shape[0], n_points, device=device)
|
||||||
weight_np = np.tile(
|
weight_np = np.tile(
|
||||||
weight[:, :, None].cpu().data.numpy(),
|
weight[:, :, None].cpu().data.numpy(), (1, 1, x_np.shape[-1])
|
||||||
(1, 1, x_np.shape[-1]),
|
|
||||||
)
|
)
|
||||||
mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
|
mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
|
||||||
mean_gt = np.average(x_np, axis=0, weights=weight_np)
|
mean_gt = np.average(x_np, axis=0, weights=weight_np)
|
||||||
|
@ -2,12 +2,11 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
|
|
||||||
from pytorch3d.ops import points_alignment
|
from pytorch3d.ops import points_alignment
|
||||||
from pytorch3d.structures.pointclouds import Pointclouds
|
from pytorch3d.structures.pointclouds import Pointclouds
|
||||||
from pytorch3d.transforms import rotation_conversions
|
from pytorch3d.transforms import rotation_conversions
|
||||||
@ -54,18 +53,14 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
|||||||
# generate random rotation matrices with orthogonalization of
|
# generate random rotation matrices with orthogonalization of
|
||||||
# random normal square matrices, followed by a transformation
|
# random normal square matrices, followed by a transformation
|
||||||
# that ensures determinant(R)==1
|
# that ensures determinant(R)==1
|
||||||
H = torch.randn(
|
H = torch.randn(batch_size, dim, dim, dtype=torch.float32, device=device)
|
||||||
batch_size, dim, dim, dtype=torch.float32, device=device
|
|
||||||
)
|
|
||||||
U, _, V = torch.svd(H)
|
U, _, V = torch.svd(H)
|
||||||
E = torch.eye(dim, dtype=torch.float32, device=device)[None].repeat(
|
E = torch.eye(dim, dtype=torch.float32, device=device)[None].repeat(
|
||||||
batch_size, 1, 1
|
batch_size, 1, 1
|
||||||
)
|
)
|
||||||
E[:, -1, -1] = torch.det(torch.bmm(U, V.transpose(2, 1)))
|
E[:, -1, -1] = torch.det(torch.bmm(U, V.transpose(2, 1)))
|
||||||
R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))
|
R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))
|
||||||
assert torch.allclose(
|
assert torch.allclose(torch.det(R), R.new_ones(batch_size), atol=1e-4)
|
||||||
torch.det(R), R.new_ones(batch_size), atol=1e-4
|
|
||||||
)
|
|
||||||
|
|
||||||
return R
|
return R
|
||||||
|
|
||||||
@ -94,19 +89,13 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
|||||||
dtype=torch.int64,
|
dtype=torch.int64,
|
||||||
)
|
)
|
||||||
X_list = [
|
X_list = [
|
||||||
torch.randn(
|
torch.randn(int(n_pt), dim, device=device, dtype=torch.float32)
|
||||||
int(n_pt), dim, device=device, dtype=torch.float32
|
|
||||||
)
|
|
||||||
for n_pt in n_points_per_batch
|
for n_pt in n_points_per_batch
|
||||||
]
|
]
|
||||||
X = Pointclouds(X_list)
|
X = Pointclouds(X_list)
|
||||||
else:
|
else:
|
||||||
X = torch.randn(
|
X = torch.randn(
|
||||||
batch_size,
|
batch_size, n_points, dim, device=device, dtype=torch.float32
|
||||||
n_points,
|
|
||||||
dim,
|
|
||||||
device=device,
|
|
||||||
dtype=torch.float32,
|
|
||||||
)
|
)
|
||||||
X = Pointclouds(list(X))
|
X = Pointclouds(list(X))
|
||||||
else:
|
else:
|
||||||
@ -143,11 +132,7 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
|||||||
# randomly select one of the dimensions to reflect for each
|
# randomly select one of the dimensions to reflect for each
|
||||||
# element in the batch
|
# element in the batch
|
||||||
dim_to_reflect = torch.randint(
|
dim_to_reflect = torch.randint(
|
||||||
low=0,
|
low=0, high=dim, size=(batch_size,), device=device, dtype=torch.int64
|
||||||
high=dim,
|
|
||||||
size=(batch_size,),
|
|
||||||
device=device,
|
|
||||||
dtype=torch.int64,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# convert dim_to_reflect to a batch of reflection matrices M
|
# convert dim_to_reflect to a batch of reflection matrices M
|
||||||
@ -211,8 +196,7 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
|||||||
weights *= (weights * template.size()[1] > 0.3).to(weights)
|
weights *= (weights * template.size()[1] > 0.3).to(weights)
|
||||||
if use_pointclouds: # convert to List[Tensor]
|
if use_pointclouds: # convert to List[Tensor]
|
||||||
weights = [
|
weights = [
|
||||||
w[:npts]
|
w[:npts] for w, npts in zip(weights, X.num_points_per_cloud())
|
||||||
for w, npts in zip(weights, X.num_points_per_cloud())
|
|
||||||
]
|
]
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
@ -255,7 +239,7 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
|||||||
use_point_clouds_cases = (
|
use_point_clouds_cases = (
|
||||||
(True, False) if dim == 3 and n_points > 3 else (False,)
|
(True, False) if dim == 3 and n_points > 3 else (False,)
|
||||||
)
|
)
|
||||||
for random_weights in (False, True,):
|
for random_weights in (False, True):
|
||||||
for use_pointclouds in use_point_clouds_cases:
|
for use_pointclouds in use_point_clouds_cases:
|
||||||
for estimate_scale in (False, True):
|
for estimate_scale in (False, True):
|
||||||
for reflect in (False, True):
|
for reflect in (False, True):
|
||||||
@ -325,8 +309,7 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
|||||||
weights *= (weights * template.size()[1] > 0.3).to(weights)
|
weights *= (weights * template.size()[1] > 0.3).to(weights)
|
||||||
if use_pointclouds: # convert to List[Tensor]
|
if use_pointclouds: # convert to List[Tensor]
|
||||||
weights = [
|
weights = [
|
||||||
w[:npts]
|
w[:npts] for w, npts in zip(weights, X.num_points_per_cloud())
|
||||||
for w, npts in zip(weights, X.num_points_per_cloud())
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# apply the generated transformation to the generated
|
# apply the generated transformation to the generated
|
||||||
@ -374,9 +357,9 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
X_t_est = _apply_pcl_transformation(X_noisy, R_n, T_n, s=s_n)
|
X_t_est = _apply_pcl_transformation(X_noisy, R_n, T_n, s=s_n)
|
||||||
|
|
||||||
return (
|
return (((X_t_est - X_t) * weights[..., None]) ** 2).sum(
|
||||||
((X_t_est - X_t) * weights[..., None]) ** 2
|
dim=(1, 2)
|
||||||
).sum(dim=(1, 2)) / weights.sum(dim=-1)
|
) / weights.sum(dim=-1)
|
||||||
|
|
||||||
# check that using weights leads to lower weighted_MSE(X_noisy, X_t)
|
# check that using weights leads to lower weighted_MSE(X_noisy, X_t)
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
@ -386,9 +369,7 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
|||||||
if reflect and not allow_reflection:
|
if reflect and not allow_reflection:
|
||||||
# check that all rotations have det=1
|
# check that all rotations have det=1
|
||||||
self._assert_all_close(
|
self._assert_all_close(
|
||||||
torch.det(R_est),
|
torch.det(R_est), R_est.new_ones(batch_size), assert_error_message
|
||||||
R_est.new_ones(batch_size),
|
|
||||||
assert_error_message,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -430,6 +411,4 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
|||||||
if weights is None:
|
if weights is None:
|
||||||
self.assertClose(a_, b_, atol=atol, msg=err_message)
|
self.assertClose(a_, b_, atol=atol, msg=err_message)
|
||||||
else:
|
else:
|
||||||
self.assertClose(
|
self.assertClose(a_ * weights, b_ * weights, atol=atol, msg=err_message)
|
||||||
a_ * weights, b_ * weights, atol=atol, msg=err_message
|
|
||||||
)
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user