move LinearWithRepeat to pytorch3d

Summary: Move this simple layer from the NeRF project into pytorch3d.

Reviewed By: shapovalov

Differential Revision: D34126972

fbshipit-source-id: a9c6d6c3c1b662c1b844ea5d1b982007d4df83e6
This commit is contained in:
Jeremy Reizenstein
2022-02-14 04:51:02 -08:00
committed by Facebook GitHub Bot
parent ef21a6f6aa
commit 2a1de3b610
6 changed files with 75 additions and 8 deletions

View File

@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
class TestLinearWithRepeat(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
def test_simple(self):
x = torch.rand(4, 6, 7, 3)
y = torch.rand(4, 6, 4)
linear = torch.nn.Linear(7, 8)
torch.nn.init.xavier_uniform_(linear.weight.data)
linear.bias.data.uniform_()
equivalent = torch.cat([x, y.unsqueeze(-2).expand(4, 6, 7, 4)], dim=-1)
expected = linear.forward(equivalent)
linear_with_repeat = LinearWithRepeat(7, 8)
linear_with_repeat.load_state_dict(linear.state_dict())
actual = linear_with_repeat.forward((x, y))
self.assertClose(actual, expected, rtol=1e-4)

View File

@@ -12,16 +12,16 @@ import torch
from common_testing import TestCaseMixin
from pytorch3d.ops import eyes
from pytorch3d.renderer import (
PerspectiveCameras,
AlphaCompositor,
PointsRenderer,
PerspectiveCameras,
PointsRasterizationSettings,
PointsRasterizer,
PointsRenderer,
)
from pytorch3d.renderer.utils import (
TensorProperties,
ndc_to_grid_sample_coords,
ndc_grid_sample,
ndc_to_grid_sample_coords,
)
from pytorch3d.structures import Pointclouds