mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Summary: License lint codebase Reviewed By: theschnitz Differential Revision: D29001799 fbshipit-source-id: 5c59869911785b0181b1663bbf430bc8b7fb2909
79 lines
2.8 KiB
Python
79 lines
2.8 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import torch
|
|
from common_testing import TestCaseMixin
|
|
from pytorch3d.ops import utils as oputil
|
|
|
|
|
|
class TestOpsUtils(TestCaseMixin, unittest.TestCase):
    """Tests for helpers in ``pytorch3d.ops.utils``."""

    def setUp(self) -> None:
        super().setUp()
        # Seed both RNGs so the random inputs are reproducible across runs.
        torch.manual_seed(42)
        np.random.seed(42)

    def test_wmean(self):
        """Check ``oputil.wmean`` against ``np.average``.

        Covers unweighted and weighted means, ``keepdim``, boolean weights,
        broadcasting of weights against batched inputs, rejection of
        incompatible weights, and reduction over an explicit ``dim``
        (int and tuple).
        """
        # Fall back to CPU so the test also runs on machines without CUDA;
        # a hard-coded "cuda:0" fails outright on CPU-only hosts.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        n_points = 20

        x = torch.rand(n_points, 3, device=device)
        weight = torch.rand(n_points, device=device)
        x_np = x.cpu().numpy()
        weight_np = weight.cpu().numpy()

        # test unweighted
        mean = oputil.wmean(x, keepdim=False)
        mean_gt = np.average(x_np, axis=-2)
        self.assertClose(mean.cpu().numpy(), mean_gt)

        # test weighted
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
        self.assertClose(mean.cpu().numpy(), mean_gt)

        # test keepdim: the reduced dim is kept, so index it away to compare
        mean = oputil.wmean(x, weight=weight, keepdim=True)
        self.assertClose(mean[0].cpu().numpy(), mean_gt)

        # test binary weights
        mean = oputil.wmean(x, weight=weight > 0.5, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np > 0.5)
        self.assertClose(mean.cpu().numpy(), mean_gt)

        # test broadcasting: x is now (10, n_points, 3) while weight stays
        # (n_points,), so the same weights apply to every batch element.
        x = torch.rand(10, n_points, 3, device=device)
        x_np = x.cpu().numpy()
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
        self.assertClose(mean.cpu().numpy(), mean_gt)

        # weight becomes (3, 1, n_points): three identical copies with an
        # extra leading dim; the first copy must match the broadcast result.
        weight = weight[None, None, :].repeat(3, 1, 1)
        mean = oputil.wmean(x, weight=weight, keepdim=False)
        self.assertClose(mean[0].cpu().numpy(), mean_gt)

        # test failing broadcasting: a weight of length x.shape[0] (10)
        # cannot be aligned with the n_points dim of x.
        weight = torch.rand(x.shape[0], device=device)
        with self.assertRaises(ValueError) as context:
            oputil.wmean(x, weight=weight, keepdim=False)
        self.assertIn("weights are not compatible", str(context.exception))

        # test dim: weight is (10, n_points); tile it over the coordinate
        # dim so it matches x_np's full shape, as np.average requires here.
        weight = torch.rand(x.shape[0], n_points, device=device)
        weight_np = np.tile(
            weight[:, :, None].cpu().numpy(), (1, 1, x_np.shape[-1])
        )
        mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=0, weights=weight_np)
        self.assertClose(mean.cpu().numpy(), mean_gt)

        # test dim tuple
        mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
        mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
        self.assertClose(mean.cpu().numpy(), mean_gt)
|