pytorch3d/tests/test_ops_utils.py
Roman Shapovalov e37085d999 Weighted Umeyama.
Summary:
1. Introduced weights to Umeyama implementation. This will be needed for weighted ePnP but is useful on its own.
2. Refactored so that the Pointclouds mask and explicitly passed weights share the same code path.
3. Added test cases with random weights.
4. Fixed a bug in the tests that called the function with 0 points (this fails randomly in PyTorch 1.3 and will be fixed in the next release: https://github.com/pytorch/pytorch/issues/31421 ).

Reviewed By: gkioxari

Differential Revision: D20070293

fbshipit-source-id: e9f549507ef6dcaa0688a0f17342e6d7a9a4336c
2020-04-03 02:59:11 -07:00

76 lines
2.7 KiB
Python

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import utils as oputil
class TestOpsUtils(TestCaseMixin, unittest.TestCase):
    """Tests for helpers in ``pytorch3d.ops.utils``."""

    def setUp(self) -> None:
        super().setUp()
        # Seed both RNGs so the random inputs are reproducible across runs.
        torch.manual_seed(42)
        np.random.seed(42)

    @staticmethod
    def _devices():
        """Return the devices to test on: CPU always, plus CUDA if present.

        The original test hard-coded ``cuda:0`` and therefore errored on
        CPU-only machines.
        """
        devices = [torch.device("cpu")]
        if torch.cuda.is_available():
            devices.append(torch.device("cuda:0"))
        return devices

    @staticmethod
    def _to_numpy(t: torch.Tensor) -> np.ndarray:
        """Detach ``t`` from the autograd graph and copy it to a host array."""
        return t.detach().cpu().numpy()

    def test_wmean(self):
        """Compare ``oputil.wmean`` against ``np.average`` on every device."""
        for device in self._devices():
            with self.subTest(device=str(device)):
                self._check_wmean(device)

    def _check_wmean(self, device: torch.device) -> None:
        """Run every ``wmean`` scenario with tensors allocated on ``device``."""
        n_points = 20
        x = torch.rand(n_points, 3, device=device)
        weight = torch.rand(n_points, device=device)
        x_np = self._to_numpy(x)
        weight_np = self._to_numpy(weight)

        # test unweighted
        mean = oputil.wmean(x, keepdim=False)
        mean_gt = np.average(x_np, axis=-2)
        self.assertClose(self._to_numpy(mean), mean_gt)

        # test weighted
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
        self.assertClose(self._to_numpy(mean), mean_gt)

        # test keepdim: the reduced axis is kept with size 1, so index it away
        mean = oputil.wmean(x, weight=weight, keepdim=True)
        self.assertClose(self._to_numpy(mean[0]), mean_gt)

        # test binary weights
        mean = oputil.wmean(x, weight=weight > 0.5, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np > 0.5)
        self.assertClose(self._to_numpy(mean), mean_gt)

        # test broadcasting: an (n_points,) weight against a (10, n_points, 3) batch
        x = torch.rand(10, n_points, 3, device=device)
        x_np = self._to_numpy(x)
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
        self.assertClose(self._to_numpy(mean), mean_gt)
        # extra leading weight dims (3, 1, n_points) broadcast as well; each of
        # the 3 identical copies must reproduce the same mean
        weight = weight[None, None, :].repeat(3, 1, 1)
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        self.assertClose(self._to_numpy(mean[0]), mean_gt)

        # test failing broadcasting: a (10,) weight cannot match the n_points axis
        weight = torch.rand(x.shape[0], device=device)
        with self.assertRaises(ValueError) as context:
            oputil.wmean(x, weight=weight, keepdim=False)
        self.assertIn("weights are not compatible", str(context.exception))

        # test dim: reduce over the batch axis; np.average needs the weight
        # tiled to the full shape of x when reducing over a non-trailing axis
        weight = torch.rand(x.shape[0], n_points, device=device)
        weight_np = np.tile(
            self._to_numpy(weight[:, :, None]),
            (1, 1, x_np.shape[-1]),
        )
        mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=0, weights=weight_np)
        self.assertClose(self._to_numpy(mean), mean_gt)

        # test dim tuple
        mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
        self.assertClose(self._to_numpy(mean), mean_gt)