mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Summary: 1. Introduced weights to the Umeyama implementation. This will be needed for weighted ePnP but is useful on its own. 2. Refactored to use the same code for the Pointclouds mask and the passed weights. 3. Added test cases with random weights. 4. Fixed a bug in tests that called the function with 0 points (this fails randomly in PyTorch 1.3 and will be fixed in the next release: https://github.com/pytorch/pytorch/issues/31421). Reviewed By: gkioxari Differential Revision: D20070293 fbshipit-source-id: e9f549507ef6dcaa0688a0f17342e6d7a9a4336c
64 lines
2.0 KiB
Python
64 lines
2.0 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
|
|
|
from typing import Optional
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
class TestCaseMixin(unittest.TestCase):
    """
    ``unittest.TestCase`` base class adding assertions for tensors and arrays.

    Test classes inherit from it to get storage-sharing checks
    (``assertSeparate``, ``assertNotSeparate``, ``assertAllSeparate``) and a
    shape-aware closeness check (``assertClose``).
    """

    @staticmethod
    def _storage_ptr(tensor) -> int:
        """
        Return the address of the storage backing ``tensor``.

        This is the start of the whole storage, not of the tensor's own data,
        so views at different offsets into the same storage report the same
        address.
        """
        # Newer PyTorch provides untyped_storage(), which avoids the
        # TypedStorage deprecation warning emitted by storage(); older
        # releases only provide storage(). Both expose data_ptr().
        get_storage = getattr(tensor, "untyped_storage", None)
        if get_storage is None:
            get_storage = tensor.storage
        return get_storage().data_ptr()

    @staticmethod
    def _max_abs_diff(input, other) -> Optional[float]:
        """
        Best-effort maximum absolute elementwise difference of two
        same-shaped tensors or arrays, used only to enrich failure messages.

        Returns None when the value cannot be computed (for example empty
        inputs, or tensor dtypes that NumPy cannot represent).
        """
        try:
            if torch.is_tensor(input):
                input = input.detach().cpu().numpy()
            if torch.is_tensor(other):
                other = other.detach().cpu().numpy()
            a = np.asarray(input)
            b = np.asarray(other)
            # Widen to at least float64 so unsigned integers do not wrap
            # around and bools can be subtracted.
            dtype = np.promote_types(np.result_type(a, b), np.float64)
            return float(np.max(np.abs(a.astype(dtype) - b.astype(dtype))))
        except (TypeError, ValueError, RuntimeError):
            # Diagnostics must never mask the real assertion outcome.
            return None

    def assertSeparate(self, tensor1, tensor2) -> None:
        """
        Verify that tensor1 and tensor2 have their data in distinct locations,
        i.e. their backing storages start at different addresses.
        """
        self.assertNotEqual(self._storage_ptr(tensor1), self._storage_ptr(tensor2))

    def assertNotSeparate(self, tensor1, tensor2) -> None:
        """
        Verify that tensor1 and tensor2 have their data in the same location,
        i.e. they are backed by the same storage.
        """
        self.assertEqual(self._storage_ptr(tensor1), self._storage_ptr(tensor2))

    def assertAllSeparate(self, tensor_list) -> None:
        """
        Verify that all tensors in tensor_list have their data in
        distinct locations (no two share a backing storage).
        """
        ptrs = [self._storage_ptr(t) for t in tensor_list]
        # Any duplicated address appears more often in ptrs than in the set.
        self.assertCountEqual(ptrs, set(ptrs))

    def assertClose(
        self,
        input,
        other,
        *,
        rtol: float = 1e-05,
        atol: float = 1e-08,
        equal_nan: bool = False,
        msg: Optional[str] = None,
    ) -> None:
        """
        Verify that two tensors or arrays are the same shape and close.

        Args:
            input, other: two tensors or two arrays.
            rtol, atol, equal_nan: as for torch.allclose.
            msg: message in case the assertion is violated. When omitted and
                the values are not close, a message with the maximum
                absolute difference is generated where possible.

        Note:
            Optional arguments here are all keyword-only, to avoid confusion
            with msg arguments on other assert functions.
        """
        self.assertEqual(np.shape(input), np.shape(other), msg)

        if torch.is_tensor(input):
            close = torch.allclose(
                input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
            )
        else:
            close = np.allclose(
                input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
            )

        if not close and msg is None:
            max_diff = self._max_abs_diff(input, other)
            if max_diff is not None:
                msg = f"Not close. Max absolute difference: {max_diff}."
        self.assertTrue(close, msg)