mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Weighted Umeyama.
Summary: 1. Introduced weights to Umeyama implementation. This will be needed for weighted ePnP but is useful on its own. 2. Refactored to use the same code for the Pointclouds mask and passed weights. 3. Added test cases with random weights. 4. Fixed a bug in tests that calls the function with 0 points (fails randomly in Pytorch 1.3, will be fixed in the next release: https://github.com/pytorch/pytorch/issues/31421 ). Reviewed By: gkioxari Differential Revision: D20070293 fbshipit-source-id: e9f549507ef6dcaa0688a0f17342e6d7a9a4336c
This commit is contained in:
parent
e5b1d6d3a3
commit
e37085d999
@ -1,16 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import warnings
|
||||
from typing import Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
import torch
|
||||
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
from pytorch3d.structures import utils as strutil
|
||||
from pytorch3d.ops import utils as oputil
|
||||
|
||||
|
||||
def corresponding_points_alignment(
|
||||
X: Union[torch.Tensor, Pointclouds],
|
||||
Y: Union[torch.Tensor, Pointclouds],
|
||||
weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
|
||||
estimate_scale: bool = False,
|
||||
allow_reflection: bool = False,
|
||||
eps: float = 1e-8,
|
||||
@ -28,9 +30,14 @@ def corresponding_points_alignment(
|
||||
|
||||
Args:
|
||||
X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
|
||||
or a `Pointclouds` object.
|
||||
or a `Pointclouds` object.
|
||||
Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
|
||||
or a `Pointclouds` object.
|
||||
or a `Pointclouds` object.
|
||||
weights: Batch of non-negative weights of
|
||||
shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
|
||||
tensors that may have different shapes; in that case, the length of
|
||||
i-th tensor should be equal to the number of points in X_i and Y_i.
|
||||
Passing `None` means uniform weights.
|
||||
estimate_scale: If `True`, also estimates a scaling component `s`
|
||||
of the transformation. Otherwise assumes an identity
|
||||
scale and returns a tensor of ones.
|
||||
@ -59,25 +66,45 @@ def corresponding_points_alignment(
|
||||
"Point sets X and Y have to have the same \
|
||||
number of batches, points and dimensions."
|
||||
)
|
||||
if weights is not None:
|
||||
if isinstance(weights, list):
|
||||
if any(np != w.shape[0] for np, w in zip(num_points, weights)):
|
||||
raise ValueError(
|
||||
"number of weights should equal to the "
|
||||
+ "number of points in the point cloud."
|
||||
)
|
||||
weights = [w[..., None] for w in weights]
|
||||
weights = strutil.list_to_padded(weights)[..., 0]
|
||||
|
||||
if Xt.shape[:2] != weights.shape:
|
||||
raise ValueError(
|
||||
"weights should have the same first two dimensions as X."
|
||||
)
|
||||
|
||||
b, n, dim = Xt.shape
|
||||
|
||||
# compute the centroids of the point sets
|
||||
Xmu = Xt.sum(1) / torch.clamp(num_points[:, None], 1)
|
||||
Ymu = Yt.sum(1) / torch.clamp(num_points[:, None], 1)
|
||||
|
||||
# mean-center the point sets
|
||||
Xc = Xt - Xmu[:, None]
|
||||
Yc = Yt - Ymu[:, None]
|
||||
|
||||
if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
|
||||
# in case we got Pointclouds as input, mask the unused entries in Xc, Yc
|
||||
mask = (
|
||||
torch.arange(n, dtype=torch.int64, device=Xc.device)[None]
|
||||
torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
|
||||
< num_points[:, None]
|
||||
).type_as(Xc)
|
||||
Xc *= mask[:, :, None]
|
||||
Yc *= mask[:, :, None]
|
||||
).type_as(Xt)
|
||||
weights = mask if weights is None else mask * weights.type_as(Xt)
|
||||
|
||||
# compute the centroids of the point sets
|
||||
Xmu = oputil.wmean(Xt, weights, eps=eps)
|
||||
Ymu = oputil.wmean(Yt, weights, eps=eps)
|
||||
|
||||
# mean-center the point sets
|
||||
Xc = Xt - Xmu
|
||||
Yc = Yt - Ymu
|
||||
|
||||
total_weight = torch.clamp(num_points, 1)
|
||||
# special handling for heterogeneous point clouds and/or input weights
|
||||
if weights is not None:
|
||||
Xc *= weights[:, :, None]
|
||||
Yc *= weights[:, :, None]
|
||||
total_weight = torch.clamp(weights.sum(1), eps)
|
||||
|
||||
if (num_points < (dim + 1)).any():
|
||||
warnings.warn(
|
||||
@ -87,7 +114,7 @@ def corresponding_points_alignment(
|
||||
|
||||
# compute the covariance XYcov between the point sets Xc, Yc
|
||||
XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
|
||||
XYcov = XYcov / torch.clamp(num_points[:, None, None], 1)
|
||||
XYcov = XYcov / total_weight[:, None, None]
|
||||
|
||||
# decompose the covariance matrix XYcov
|
||||
U, S, V = torch.svd(XYcov)
|
||||
@ -111,17 +138,16 @@ def corresponding_points_alignment(
|
||||
if estimate_scale:
|
||||
# estimate the scaling component of the transformation
|
||||
trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
|
||||
Xcov = (Xc * Xc).sum((1, 2)) / torch.clamp(num_points, 1)
|
||||
Xcov = (Xc * Xc).sum((1, 2)) / total_weight
|
||||
|
||||
# the scaling component
|
||||
s = trace_ES / torch.clamp(Xcov, eps)
|
||||
|
||||
# translation component
|
||||
T = Ymu - s[:, None] * torch.bmm(Xmu[:, None], R)[:, 0, :]
|
||||
|
||||
T = Ymu[:, 0, :] - s[:, None] * torch.bmm(Xmu, R)[:, 0, :]
|
||||
else:
|
||||
# translation component
|
||||
T = Ymu - torch.bmm(Xmu[:, None], R)[:, 0]
|
||||
T = Ymu[:, 0, :] - torch.bmm(Xmu, R)[:, 0, :]
|
||||
|
||||
# unit scaling since we do not estimate scale
|
||||
s = T.new_ones(b)
|
||||
|
44
pytorch3d/ops/utils.py
Normal file
44
pytorch3d/ops/utils.py
Normal file
@ -0,0 +1,44 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def wmean(
|
||||
x: torch.Tensor,
|
||||
weight: Optional[torch.Tensor] = None,
|
||||
dim: Union[int, Tuple[int]] = -2,
|
||||
keepdim: bool = True,
|
||||
eps: float = 1e-9,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Finds the mean of the input tensor across the specified dimension.
|
||||
If the `weight` argument is provided, computes weighted mean.
|
||||
Args:
|
||||
x: tensor of shape `(*, D)`, where D is assumed to be spatial;
|
||||
weights: if given, non-negative tensor of shape `(*,)`. It must be
|
||||
broadcastable to `x.shape[:-1]`. Note that the weights for
|
||||
the last (spatial) dimension are assumed same;
|
||||
dim: dimension(s) in `x` to average over;
|
||||
keepdim: tells whether to keep the resulting singleton dimension.
|
||||
eps: minumum clamping value in the denominator.
|
||||
Returns:
|
||||
the mean tensor:
|
||||
* if `weights` is None => `mean(x, dim)`,
|
||||
* otherwise => `sum(x*w, dim) / max{sum(w, dim), eps}`.
|
||||
"""
|
||||
args = dict(dim=dim, keepdim=keepdim)
|
||||
|
||||
if weight is None:
|
||||
return x.mean(**args)
|
||||
|
||||
if any(
|
||||
xd != wd and xd != 1 and wd != 1
|
||||
for xd, wd in zip(x.shape[-2::-1], weight.shape[::-1])
|
||||
):
|
||||
raise ValueError("wmean: weights are not compatible with the tensor")
|
||||
|
||||
return (
|
||||
(x * weight[..., None]).sum(**args)
|
||||
/ weight[..., None].sum(**args).clamp(eps)
|
||||
)
|
@ -16,6 +16,7 @@ def bm_corresponding_points_alignment() -> None:
|
||||
"dim": [3, 20],
|
||||
"estimate_scale": [True, False],
|
||||
"n_points": [100, 10000],
|
||||
"random_weights": [False, True],
|
||||
"use_pointclouds": [False],
|
||||
}
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import unittest
|
||||
|
||||
@ -35,13 +36,15 @@ class TestCaseMixin(unittest.TestCase):
|
||||
*,
|
||||
rtol: float = 1e-05,
|
||||
atol: float = 1e-08,
|
||||
equal_nan: bool = False
|
||||
equal_nan: bool = False,
|
||||
msg: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Verify that two tensors or arrays are the same shape and close.
|
||||
Args:
|
||||
input, other: two tensors or two arrays.
|
||||
rtol, atol, equal_nan: as for torch.allclose.
|
||||
msg: message in case the assertion is violated.
|
||||
Note:
|
||||
Optional arguments here are all keyword-only, to avoid confusion
|
||||
with msg arguments on other assert functions.
|
||||
@ -54,5 +57,7 @@ class TestCaseMixin(unittest.TestCase):
|
||||
input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
|
||||
)
|
||||
else:
|
||||
close = np.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
|
||||
self.assertTrue(close)
|
||||
close = np.allclose(
|
||||
input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
|
||||
)
|
||||
self.assertTrue(close, msg)
|
||||
|
75
tests/test_ops_utils.py
Normal file
75
tests/test_ops_utils.py
Normal file
@ -0,0 +1,75 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
from pytorch3d.ops import utils as oputil
|
||||
|
||||
class TestOpsUtils(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
def test_wmean(self):
|
||||
device = torch.device("cuda:0")
|
||||
n_points = 20
|
||||
|
||||
x = torch.rand(n_points, 3, device=device)
|
||||
weight = torch.rand(n_points, device=device)
|
||||
x_np = x.cpu().data.numpy()
|
||||
weight_np = weight.cpu().data.numpy()
|
||||
|
||||
# test unweighted
|
||||
mean = oputil.wmean(x, keepdim=False)
|
||||
mean_gt = np.average(x_np, axis=-2)
|
||||
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
||||
|
||||
# test weighted
|
||||
mean = oputil.wmean(x, weight=weight, keepdim=False)
|
||||
mean_gt = np.average(x_np, axis=-2, weights=weight_np)
|
||||
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
||||
|
||||
# test keepdim
|
||||
mean = oputil.wmean(x, weight=weight, keepdim=True)
|
||||
self.assertClose(mean[0].cpu().data.numpy(), mean_gt)
|
||||
|
||||
# test binary weigths
|
||||
mean = oputil.wmean(x, weight=weight > 0.5, keepdim=False)
|
||||
mean_gt = np.average(x_np, axis=-2, weights=weight_np > 0.5)
|
||||
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
||||
|
||||
# test broadcasting
|
||||
x = torch.rand(10, n_points, 3, device=device)
|
||||
x_np = x.cpu().data.numpy()
|
||||
mean = oputil.wmean(x, weight=weight, keepdim=False)
|
||||
mean_gt = np.average(x_np, axis=-2, weights=weight_np)
|
||||
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
||||
|
||||
weight = weight[None, None, :].repeat(3, 1, 1)
|
||||
mean = oputil.wmean(x, weight=weight, keepdim=False)
|
||||
self.assertClose(mean[0].cpu().data.numpy(), mean_gt)
|
||||
|
||||
# test failing broadcasting
|
||||
weight = torch.rand(x.shape[0], device=device)
|
||||
with self.assertRaises(ValueError) as context:
|
||||
oputil.wmean(x, weight=weight, keepdim=False)
|
||||
self.assertTrue("weights are not compatible" in str(context.exception))
|
||||
|
||||
# test dim
|
||||
weight = torch.rand(x.shape[0], n_points, device=device)
|
||||
weight_np = np.tile(
|
||||
weight[:, :, None].cpu().data.numpy(),
|
||||
(1, 1, x_np.shape[-1]),
|
||||
)
|
||||
mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
|
||||
mean_gt = np.average(x_np, axis=0, weights=weight_np)
|
||||
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
||||
|
||||
# test dim tuple
|
||||
mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
|
||||
mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
|
||||
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
@ -6,6 +6,8 @@ import numpy as np
|
||||
import unittest
|
||||
import torch
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
from pytorch3d.ops import points_alignment
|
||||
from pytorch3d.structures.pointclouds import Pointclouds
|
||||
from pytorch3d.transforms import rotation_conversions
|
||||
@ -35,7 +37,7 @@ def _apply_pcl_transformation(X, R, T, s=None):
|
||||
return X_t
|
||||
|
||||
|
||||
class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
class TestCorrespondingPointsAlignment(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
@ -171,6 +173,7 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
estimate_scale=False,
|
||||
allow_reflection=False,
|
||||
reflect=False,
|
||||
random_weights=False,
|
||||
):
|
||||
|
||||
device = torch.device("cuda:0")
|
||||
@ -198,12 +201,27 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
# point cloud X
|
||||
X_t = _apply_pcl_transformation(X, R, T, s=s)
|
||||
|
||||
weights = None
|
||||
if random_weights:
|
||||
template = X.points_padded() if use_pointclouds else X
|
||||
weights = torch.rand_like(template[:, :, 0])
|
||||
weights = weights / weights.sum(dim=1, keepdim=True)
|
||||
# zero out some weights as zero weights are a common use case
|
||||
# this guarantees there are no zero weight
|
||||
weights *= (weights * template.size()[1] > 0.3).to(weights)
|
||||
if use_pointclouds: # convert to List[Tensor]
|
||||
weights = [
|
||||
w[:npts]
|
||||
for w, npts in zip(weights, X.num_points_per_cloud())
|
||||
]
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def run_corresponding_points_alignment():
|
||||
points_alignment.corresponding_points_alignment(
|
||||
X,
|
||||
X_t,
|
||||
weights,
|
||||
allow_reflection=allow_reflection,
|
||||
estimate_scale=estimate_scale,
|
||||
)
|
||||
@ -230,26 +248,28 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
"""
|
||||
|
||||
# run this for several different point cloud sizes
|
||||
for n_points in (100, 3, 2, 1, 0):
|
||||
for n_points in (100, 3, 2, 1):
|
||||
# run this for several different dimensionalities
|
||||
for dim in torch.arange(2, 10):
|
||||
for dim in range(2, 10):
|
||||
# switches whether we should use the Pointclouds inputs
|
||||
use_point_clouds_cases = (
|
||||
(True, False) if dim == 3 and n_points > 3 else (False,)
|
||||
)
|
||||
for use_pointclouds in use_point_clouds_cases:
|
||||
for estimate_scale in (False, True):
|
||||
for reflect in (False, True):
|
||||
for allow_reflection in (False, True):
|
||||
self._test_single_corresponding_points_alignment(
|
||||
batch_size=10,
|
||||
n_points=n_points,
|
||||
dim=int(dim),
|
||||
use_pointclouds=use_pointclouds,
|
||||
estimate_scale=estimate_scale,
|
||||
reflect=reflect,
|
||||
allow_reflection=allow_reflection,
|
||||
)
|
||||
for random_weights in (False, True,):
|
||||
for use_pointclouds in use_point_clouds_cases:
|
||||
for estimate_scale in (False, True):
|
||||
for reflect in (False, True):
|
||||
for allow_reflection in (False, True):
|
||||
self._test_single_corresponding_points_alignment(
|
||||
batch_size=10,
|
||||
n_points=n_points,
|
||||
dim=dim,
|
||||
use_pointclouds=use_pointclouds,
|
||||
estimate_scale=estimate_scale,
|
||||
reflect=reflect,
|
||||
allow_reflection=allow_reflection,
|
||||
random_weights=random_weights,
|
||||
)
|
||||
|
||||
def _test_single_corresponding_points_alignment(
|
||||
self,
|
||||
@ -260,6 +280,7 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
estimate_scale=False,
|
||||
reflect=False,
|
||||
allow_reflection=False,
|
||||
random_weights=False,
|
||||
):
|
||||
"""
|
||||
Executes a single test for `corresponding_points_alignment` for a
|
||||
@ -294,6 +315,20 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
)
|
||||
R = torch.bmm(M, R)
|
||||
|
||||
weights = None
|
||||
if random_weights:
|
||||
template = X.points_padded() if use_pointclouds else X
|
||||
weights = torch.rand_like(template[:, :, 0])
|
||||
weights = weights / weights.sum(dim=1, keepdim=True)
|
||||
# zero out some weights as zero weights are a common use case
|
||||
# this guarantees there are no zero weight
|
||||
weights *= (weights * template.size()[1] > 0.3).to(weights)
|
||||
if use_pointclouds: # convert to List[Tensor]
|
||||
weights = [
|
||||
w[:npts]
|
||||
for w, npts in zip(weights, X.num_points_per_cloud())
|
||||
]
|
||||
|
||||
# apply the generated transformation to the generated
|
||||
# point cloud X
|
||||
X_t = _apply_pcl_transformation(X, R, T, s=s)
|
||||
@ -302,6 +337,7 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
R_est, T_est, s_est = points_alignment.corresponding_points_alignment(
|
||||
X,
|
||||
X_t,
|
||||
weights,
|
||||
allow_reflection=allow_reflection,
|
||||
estimate_scale=estimate_scale,
|
||||
)
|
||||
@ -313,9 +349,40 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
f"use_pointclouds={use_pointclouds}, "
|
||||
f"estimate_scale={estimate_scale}, "
|
||||
f"reflect={reflect}, "
|
||||
f"allow_reflection={allow_reflection}."
|
||||
f"allow_reflection={allow_reflection},"
|
||||
f"random_weights={random_weights}."
|
||||
)
|
||||
|
||||
# if we test the weighted case, check that weights help with noise
|
||||
if random_weights and not use_pointclouds and n_points >= (dim + 10):
|
||||
# add noise to 20% points with smallest weight
|
||||
X_noisy = X_t.clone()
|
||||
_, mink_idx = torch.topk(-weights, int(n_points * 0.2), dim=1)
|
||||
mink_idx = mink_idx[:, :, None].expand(-1, -1, X_t.shape[-1])
|
||||
X_noisy.scatter_add_(
|
||||
1, mink_idx, 0.3 * torch.randn_like(mink_idx, dtype=X_t.dtype)
|
||||
)
|
||||
|
||||
def align_and_get_mse(weights_):
|
||||
R_n, T_n, s_n = points_alignment.corresponding_points_alignment(
|
||||
X_noisy,
|
||||
X_t,
|
||||
weights_,
|
||||
allow_reflection=allow_reflection,
|
||||
estimate_scale=estimate_scale,
|
||||
)
|
||||
|
||||
X_t_est = _apply_pcl_transformation(X_noisy, R_n, T_n, s=s_n)
|
||||
|
||||
return (
|
||||
((X_t_est - X_t) * weights[..., None]) ** 2
|
||||
).sum(dim=(1, 2)) / weights.sum(dim=-1)
|
||||
|
||||
# check that using weights leads to lower weighted_MSE(X_noisy, X_t)
|
||||
self.assertTrue(
|
||||
torch.all(align_and_get_mse(weights) <= align_and_get_mse(None))
|
||||
)
|
||||
|
||||
if reflect and not allow_reflection:
|
||||
# check that all rotations have det=1
|
||||
self._assert_all_close(
|
||||
@ -325,34 +392,44 @@ class TestCorrespondingPointsAlignment(unittest.TestCase):
|
||||
)
|
||||
|
||||
else:
|
||||
# mask out inputs with too few non-degenerate points for assertions
|
||||
w = (
|
||||
torch.ones_like(R_est[:, 0, 0])
|
||||
if weights is None or n_points >= dim + 10
|
||||
else (weights > 0.0).all(dim=1).to(R_est)
|
||||
)
|
||||
# check that the estimated tranformation is the same
|
||||
# as the ground truth
|
||||
if n_points >= (dim + 1):
|
||||
# the checks on transforms apply only when
|
||||
# the problem setup is unambiguous
|
||||
self._assert_all_close(R_est, R, assert_error_message)
|
||||
self._assert_all_close(T_est, T, assert_error_message)
|
||||
self._assert_all_close(s_est, s, assert_error_message)
|
||||
msg = assert_error_message
|
||||
self._assert_all_close(R_est, R, msg, w[:, None, None], atol=1e-5)
|
||||
self._assert_all_close(T_est, T, msg, w[:, None])
|
||||
self._assert_all_close(s_est, s, msg, w)
|
||||
|
||||
# check that the orthonormal part of the
|
||||
# transformation has a correct determinant (+1/-1)
|
||||
desired_det = R_est.new_ones(batch_size)
|
||||
if reflect:
|
||||
desired_det *= -1.0
|
||||
self._assert_all_close(
|
||||
torch.det(R_est), desired_det, assert_error_message
|
||||
)
|
||||
self._assert_all_close(torch.det(R_est), desired_det, msg, w)
|
||||
|
||||
# check that the transformed point cloud
|
||||
# X matches X_t
|
||||
X_t_est = _apply_pcl_transformation(X, R_est, T_est, s=s_est)
|
||||
self._assert_all_close(
|
||||
X_t, X_t_est, assert_error_message, atol=1e-5
|
||||
X_t, X_t_est, assert_error_message, w[:, None, None], atol=1e-5
|
||||
)
|
||||
|
||||
def _assert_all_close(self, a_, b_, err_message, atol=1e-6):
|
||||
def _assert_all_close(self, a_, b_, err_message, weights=None, atol=1e-6):
|
||||
if isinstance(a_, Pointclouds):
|
||||
a_ = a_.points_packed()
|
||||
if isinstance(b_, Pointclouds):
|
||||
b_ = b_.points_packed()
|
||||
self.assertTrue(torch.allclose(a_, b_, atol=atol), err_message)
|
||||
if weights is None:
|
||||
self.assertClose(a_, b_, atol=atol, msg=err_message)
|
||||
else:
|
||||
self.assertClose(
|
||||
a_ * weights, b_ * weights, atol=atol, msg=err_message
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user