mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-01-17 03:40:34 +08:00
fix recent lint
Summary: lint clean again Reviewed By: patricklabatut Differential Revision: D20868775 fbshipit-source-id: ade4301c1012c5c6943186432465215701d635a9
This commit is contained in:
committed by
Facebook GitHub Bot
parent
90dc7a0856
commit
b87058c62a
@@ -2,12 +2,11 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
from pytorch3d.ops import points_alignment
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
from pytorch3d.transforms import rotation_conversions
|
||||
@@ -54,18 +53,14 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
# generate random rotation matrices with orthogonalization of
|
||||
# random normal square matrices, followed by a transformation
|
||||
# that ensures determinant(R)==1
|
||||
H = torch.randn(
|
||||
batch_size, dim, dim, dtype=torch.float32, device=device
|
||||
)
|
||||
H = torch.randn(batch_size, dim, dim, dtype=torch.float32, device=device)
|
||||
U, _, V = torch.svd(H)
|
||||
E = torch.eye(dim, dtype=torch.float32, device=device)[None].repeat(
|
||||
batch_size, 1, 1
|
||||
)
|
||||
E[:, -1, -1] = torch.det(torch.bmm(U, V.transpose(2, 1)))
|
||||
R = torch.bmm(torch.bmm(U, E), V.transpose(2, 1))
|
||||
assert torch.allclose(
|
||||
torch.det(R), R.new_ones(batch_size), atol=1e-4
|
||||
)
|
||||
assert torch.allclose(torch.det(R), R.new_ones(batch_size), atol=1e-4)
|
||||
|
||||
return R
|
||||
|
||||
@@ -94,19 +89,13 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
dtype=torch.int64,
|
||||
)
|
||||
X_list = [
|
||||
torch.randn(
|
||||
int(n_pt), dim, device=device, dtype=torch.float32
|
||||
)
|
||||
torch.randn(int(n_pt), dim, device=device, dtype=torch.float32)
|
||||
for n_pt in n_points_per_batch
|
||||
]
|
||||
X = Pointclouds(X_list)
|
||||
else:
|
||||
X = torch.randn(
|
||||
batch_size,
|
||||
n_points,
|
||||
dim,
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
batch_size, n_points, dim, device=device, dtype=torch.float32
|
||||
)
|
||||
X = Pointclouds(list(X))
|
||||
else:
|
||||
@@ -143,11 +132,7 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
# randomly select one of the dimensions to reflect for each
|
||||
# element in the batch
|
||||
dim_to_reflect = torch.randint(
|
||||
low=0,
|
||||
high=dim,
|
||||
size=(batch_size,),
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
low=0, high=dim, size=(batch_size,), device=device, dtype=torch.int64
|
||||
)
|
||||
|
||||
# convert dim_to_reflect to a batch of reflection matrices M
|
||||
@@ -211,8 +196,7 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
weights *= (weights * template.size()[1] > 0.3).to(weights)
|
||||
if use_pointclouds: # convert to List[Tensor]
|
||||
weights = [
|
||||
w[:npts]
|
||||
for w, npts in zip(weights, X.num_points_per_cloud())
|
||||
w[:npts] for w, npts in zip(weights, X.num_points_per_cloud())
|
||||
]
|
||||
|
||||
torch.cuda.synchronize()
|
||||
@@ -255,7 +239,7 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
use_point_clouds_cases = (
|
||||
(True, False) if dim == 3 and n_points > 3 else (False,)
|
||||
)
|
||||
for random_weights in (False, True,):
|
||||
for random_weights in (False, True):
|
||||
for use_pointclouds in use_point_clouds_cases:
|
||||
for estimate_scale in (False, True):
|
||||
for reflect in (False, True):
|
||||
@@ -325,8 +309,7 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
weights *= (weights * template.size()[1] > 0.3).to(weights)
|
||||
if use_pointclouds: # convert to List[Tensor]
|
||||
weights = [
|
||||
w[:npts]
|
||||
for w, npts in zip(weights, X.num_points_per_cloud())
|
||||
w[:npts] for w, npts in zip(weights, X.num_points_per_cloud())
|
||||
]
|
||||
|
||||
# apply the generated transformation to the generated
|
||||
@@ -374,9 +357,9 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
X_t_est = _apply_pcl_transformation(X_noisy, R_n, T_n, s=s_n)
|
||||
|
||||
return (
|
||||
((X_t_est - X_t) * weights[..., None]) ** 2
|
||||
).sum(dim=(1, 2)) / weights.sum(dim=-1)
|
||||
return (((X_t_est - X_t) * weights[..., None]) ** 2).sum(
|
||||
dim=(1, 2)
|
||||
) / weights.sum(dim=-1)
|
||||
|
||||
# check that using weights leads to lower weighted_MSE(X_noisy, X_t)
|
||||
self.assertTrue(
|
||||
@@ -386,9 +369,7 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
if reflect and not allow_reflection:
|
||||
# check that all rotations have det=1
|
||||
self._assert_all_close(
|
||||
torch.det(R_est),
|
||||
R_est.new_ones(batch_size),
|
||||
assert_error_message,
|
||||
torch.det(R_est), R_est.new_ones(batch_size), assert_error_message
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -430,6 +411,4 @@ class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
if weights is None:
|
||||
self.assertClose(a_, b_, atol=atol, msg=err_message)
|
||||
else:
|
||||
self.assertClose(
|
||||
a_ * weights, b_ * weights, atol=atol, msg=err_message
|
||||
)
|
||||
self.assertClose(a_ * weights, b_ * weights, atol=atol, msg=err_message)
|
||||
|
||||
Reference in New Issue
Block a user