move LinearWithRepeat to pytorch3d

Summary: Move this simple layer from the NeRF project into pytorch3d.

Reviewed By: shapovalov

Differential Revision: D34126972

fbshipit-source-id: a9c6d6c3c1b662c1b844ea5d1b982007d4df83e6
This commit is contained in:
Jeremy Reizenstein 2022-02-14 04:51:02 -08:00 committed by Facebook GitHub Bot
parent ef21a6f6aa
commit 2a1de3b610
6 changed files with 75 additions and 8 deletions

View File

@ -7,10 +7,9 @@
from typing import Tuple from typing import Tuple
import torch import torch
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
from pytorch3d.renderer import HarmonicEmbedding, RayBundle, ray_bundle_to_ray_points from pytorch3d.renderer import HarmonicEmbedding, RayBundle, ray_bundle_to_ray_points
from .linear_with_repeat import LinearWithRepeat
def _xavier_init(linear): def _xavier_init(linear):
""" """

View File

@ -4,13 +4,15 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import math
from typing import Tuple from typing import Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.nn import Parameter, init
class LinearWithRepeat(torch.nn.Linear): class LinearWithRepeat(torch.nn.Module):
""" """
if x has shape (..., k, n1) if x has shape (..., k, n1)
and y has shape (..., n2) and y has shape (..., n2)
@ -50,6 +52,40 @@ class LinearWithRepeat(torch.nn.Linear):
and sent that through the Linear. and sent that through the Linear.
""" """
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
"""
Copied from torch.nn.Linear.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(
torch.empty((out_features, in_features), **factory_kwargs)
)
if bias:
self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self) -> None:
"""
Copied from torch.nn.Linear.
"""
init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
init.uniform_(self.bias, -bound, bound)
def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: def forward(self, input: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
n1 = input[0].shape[-1] n1 = input[0].shape[-1]
output1 = F.linear(input[0], self.weight[:, :n1], self.bias) output1 = F.linear(input[0], self.weight[:, :n1], self.bias)

View File

@ -73,8 +73,8 @@ from .points import (
from .utils import ( from .utils import (
TensorProperties, TensorProperties,
convert_to_tensors_and_broadcast, convert_to_tensors_and_broadcast,
ndc_to_grid_sample_coords,
ndc_grid_sample, ndc_grid_sample,
ndc_to_grid_sample_coords,
) )

View File

@ -8,7 +8,7 @@
import copy import copy
import inspect import inspect
import warnings import warnings
from typing import Any, Optional, Union, Tuple from typing import Any, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch

View File

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.common.linear_with_repeat import LinearWithRepeat
class TestLinearWithRepeat(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
super().setUp()
torch.manual_seed(42)
def test_simple(self):
x = torch.rand(4, 6, 7, 3)
y = torch.rand(4, 6, 4)
linear = torch.nn.Linear(7, 8)
torch.nn.init.xavier_uniform_(linear.weight.data)
linear.bias.data.uniform_()
equivalent = torch.cat([x, y.unsqueeze(-2).expand(4, 6, 7, 4)], dim=-1)
expected = linear.forward(equivalent)
linear_with_repeat = LinearWithRepeat(7, 8)
linear_with_repeat.load_state_dict(linear.state_dict())
actual = linear_with_repeat.forward((x, y))
self.assertClose(actual, expected, rtol=1e-4)

View File

@ -12,16 +12,16 @@ import torch
from common_testing import TestCaseMixin from common_testing import TestCaseMixin
from pytorch3d.ops import eyes from pytorch3d.ops import eyes
from pytorch3d.renderer import ( from pytorch3d.renderer import (
PerspectiveCameras,
AlphaCompositor, AlphaCompositor,
PointsRenderer, PerspectiveCameras,
PointsRasterizationSettings, PointsRasterizationSettings,
PointsRasterizer, PointsRasterizer,
PointsRenderer,
) )
from pytorch3d.renderer.utils import ( from pytorch3d.renderer.utils import (
TensorProperties, TensorProperties,
ndc_to_grid_sample_coords,
ndc_grid_sample, ndc_grid_sample,
ndc_to_grid_sample_coords,
) )
from pytorch3d.structures import Pointclouds from pytorch3d.structures import Pointclouds