mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 23:30:35 +08:00
Weighted Umeyama.
Summary: 1. Introduced weights to Umeyama implementation. This will be needed for weighted ePnP but is useful on its own. 2. Refactored to use the same code for the Pointclouds mask and passed weights. 3. Added test cases with random weights. 4. Fixed a bug in tests that calls the function with 0 points (fails randomly in Pytorch 1.3, will be fixed in the next release: https://github.com/pytorch/pytorch/issues/31421 ). Reviewed By: gkioxari Differential Revision: D20070293 fbshipit-source-id: e9f549507ef6dcaa0688a0f17342e6d7a9a4336c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e5b1d6d3a3
commit
e37085d999
75
tests/test_ops_utils.py
Normal file
75
tests/test_ops_utils.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from common_testing import TestCaseMixin
|
||||
|
||||
from pytorch3d.ops import utils as oputil
|
||||
|
||||
class TestOpsUtils(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
super().setUp()
|
||||
torch.manual_seed(42)
|
||||
np.random.seed(42)
|
||||
|
||||
def test_wmean(self):
|
||||
device = torch.device("cuda:0")
|
||||
n_points = 20
|
||||
|
||||
x = torch.rand(n_points, 3, device=device)
|
||||
weight = torch.rand(n_points, device=device)
|
||||
x_np = x.cpu().data.numpy()
|
||||
weight_np = weight.cpu().data.numpy()
|
||||
|
||||
# test unweighted
|
||||
mean = oputil.wmean(x, keepdim=False)
|
||||
mean_gt = np.average(x_np, axis=-2)
|
||||
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
||||
|
||||
# test weighted
|
||||
mean = oputil.wmean(x, weight=weight, keepdim=False)
|
||||
mean_gt = np.average(x_np, axis=-2, weights=weight_np)
|
||||
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
||||
|
||||
# test keepdim
|
||||
mean = oputil.wmean(x, weight=weight, keepdim=True)
|
||||
self.assertClose(mean[0].cpu().data.numpy(), mean_gt)
|
||||
|
||||
# test binary weigths
|
||||
mean = oputil.wmean(x, weight=weight > 0.5, keepdim=False)
|
||||
mean_gt = np.average(x_np, axis=-2, weights=weight_np > 0.5)
|
||||
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
||||
|
||||
# test broadcasting
|
||||
x = torch.rand(10, n_points, 3, device=device)
|
||||
x_np = x.cpu().data.numpy()
|
||||
mean = oputil.wmean(x, weight=weight, keepdim=False)
|
||||
mean_gt = np.average(x_np, axis=-2, weights=weight_np)
|
||||
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
||||
|
||||
weight = weight[None, None, :].repeat(3, 1, 1)
|
||||
mean = oputil.wmean(x, weight=weight, keepdim=False)
|
||||
self.assertClose(mean[0].cpu().data.numpy(), mean_gt)
|
||||
|
||||
# test failing broadcasting
|
||||
weight = torch.rand(x.shape[0], device=device)
|
||||
with self.assertRaises(ValueError) as context:
|
||||
oputil.wmean(x, weight=weight, keepdim=False)
|
||||
self.assertTrue("weights are not compatible" in str(context.exception))
|
||||
|
||||
# test dim
|
||||
weight = torch.rand(x.shape[0], n_points, device=device)
|
||||
weight_np = np.tile(
|
||||
weight[:, :, None].cpu().data.numpy(),
|
||||
(1, 1, x_np.shape[-1]),
|
||||
)
|
||||
mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
|
||||
mean_gt = np.average(x_np, axis=0, weights=weight_np)
|
||||
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
||||
|
||||
# test dim tuple
|
||||
mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
|
||||
mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
|
||||
self.assertClose(mean.cpu().data.numpy(), mean_gt)
|
||||
Reference in New Issue
Block a user