mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
Fix parameters not wrapped with nn.Parameter, antialiasing compatibility
Summary: Some things fail if a parameter is not wraped; in particular, it prevented other tensors moving to GPU. Reviewed By: bottler Differential Revision: D40819932 fbshipit-source-id: a23b38ceacd7f0dc131cb0355fef1178e3e2f7fd
This commit is contained in:
parent
88620b6847
commit
f711c4bfe9
@ -75,7 +75,7 @@ class ElementwiseDecoder(DecoderFunctionBase):
|
||||
shift: a scalar which is added to the scaled input before performing
|
||||
the operation. Defaults to 0.
|
||||
operation: which operation to perform on the transformed input. Options are:
|
||||
`relu`, `softplus`, `sigmoid` and `identity`. Defaults to `identity`.
|
||||
`RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`. Defaults to `IDENTITY`.
|
||||
"""
|
||||
|
||||
scale: float = 1
|
||||
@ -91,7 +91,7 @@ class ElementwiseDecoder(DecoderFunctionBase):
|
||||
DecoderActivation.IDENTITY,
|
||||
]:
|
||||
raise ValueError(
|
||||
"`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
|
||||
"`operation` can only be `RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
|
||||
)
|
||||
|
||||
def forward(
|
||||
@ -165,22 +165,18 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
|
||||
if self.last_activation not in [
|
||||
DecoderActivation.RELU,
|
||||
DecoderActivation.SOFTPLUS,
|
||||
DecoderActivation.SIGMOID,
|
||||
DecoderActivation.IDENTITY,
|
||||
]:
|
||||
try:
|
||||
last_activation = {
|
||||
DecoderActivation.RELU: torch.nn.ReLU(True),
|
||||
DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
|
||||
DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
|
||||
DecoderActivation.IDENTITY: torch.nn.Identity(),
|
||||
}[self.last_activation]
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
"`last_activation` can only be `relu`,"
|
||||
" `softplus`, `sigmoid` or identity."
|
||||
)
|
||||
last_activation = {
|
||||
DecoderActivation.RELU: torch.nn.ReLU(True),
|
||||
DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
|
||||
DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
|
||||
DecoderActivation.IDENTITY: torch.nn.Identity(),
|
||||
}[self.last_activation]
|
||||
"`last_activation` can only be `RELU`,"
|
||||
" `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
|
||||
) from e
|
||||
|
||||
layers = []
|
||||
skip_affine_layers = []
|
||||
|
@ -15,9 +15,12 @@ these classes.
|
||||
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
@ -67,7 +70,9 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
padding: str = "zeros"
|
||||
mode: str = "bilinear"
|
||||
n_features: int = 1
|
||||
resolution_changes: Dict[int, List[int]] = field(
|
||||
# return the line below once we drop OmegaConf 2.1 support
|
||||
# resolution_changes: Dict[int, List[int]] = field(
|
||||
resolution_changes: Dict[int, Any] = field(
|
||||
default_factory=lambda: {0: [128, 128, 128]}
|
||||
)
|
||||
|
||||
@ -212,6 +217,13 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
+ "| 'bicubic' | 'linear' | 'area' | 'nearest-exact'"
|
||||
)
|
||||
|
||||
interpolate_has_antialias = LooseVersion(torch.__version__) >= "1.11"
|
||||
|
||||
if antialias and not interpolate_has_antialias:
|
||||
warnings.warn("Antialiased interpolation requires PyTorch 1.11+; ignoring")
|
||||
|
||||
interp_kwargs = {"antialias": antialias} if interpolate_has_antialias else {}
|
||||
|
||||
def change_individual_resolution(tensor, wanted_resolution):
|
||||
if mode == "linear":
|
||||
n_dim = len(wanted_resolution)
|
||||
@ -223,8 +235,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
size=wanted_resolution,
|
||||
mode=new_mode,
|
||||
align_corners=align_corners,
|
||||
antialias=antialias,
|
||||
recompute_scale_factor=False,
|
||||
**interp_kwargs,
|
||||
)
|
||||
|
||||
if epoch is not None:
|
||||
@ -880,7 +892,14 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
"""
|
||||
if self.hold_voxel_grid_as_parameters:
|
||||
# pyre-ignore [16]
|
||||
self.params = torch.nn.ParameterDict(vars(params))
|
||||
# Nones are converted to empty tensors by Parameter()
|
||||
self.params = torch.nn.ParameterDict(
|
||||
{
|
||||
k: torch.nn.Parameter(val)
|
||||
for k, val in vars(params).items()
|
||||
if val is not None
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Torch Module to hold parameters since they can only be registered
|
||||
# at object level.
|
||||
@ -1011,7 +1030,11 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
)
|
||||
# pyre-ignore [16]
|
||||
self.params = torch.nn.ParameterDict(
|
||||
{k: v for k, v in vars(grid_values).items()}
|
||||
{
|
||||
k: torch.nn.Parameter(val)
|
||||
for k, val in vars(grid_values).items()
|
||||
if val is not None
|
||||
}
|
||||
)
|
||||
# New center of voxel grid is the middle point between max and min points.
|
||||
self.translation = tuple((max_point + min_point) / 2)
|
||||
|
@ -527,10 +527,10 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def decoder_density_tweak_args(cls, type, args: DictConfig) -> None:
|
||||
def decoder_density_tweak_args(cls, type_, args: DictConfig) -> None:
|
||||
args.pop("input_dim", None)
|
||||
|
||||
def create_decoder_density_impl(self, type, args: DictConfig) -> None:
|
||||
def create_decoder_density_impl(self, type_, args: DictConfig) -> None:
|
||||
"""
|
||||
Decoding functions come after harmonic embedding and voxel grid. In order to not
|
||||
calculate the input dimension of the decoder in the config file this function
|
||||
@ -548,7 +548,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
embedder_args["append_input"],
|
||||
)
|
||||
|
||||
cls = registry.get(DecoderFunctionBase, type)
|
||||
cls = registry.get(DecoderFunctionBase, type_)
|
||||
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
|
||||
if need_input_dim:
|
||||
self.decoder_density = cls(input_dim=input_dim, **args)
|
||||
@ -556,10 +556,10 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
self.decoder_density = cls(**args)
|
||||
|
||||
@classmethod
|
||||
def decoder_color_tweak_args(cls, type, args: DictConfig) -> None:
|
||||
def decoder_color_tweak_args(cls, type_, args: DictConfig) -> None:
|
||||
args.pop("input_dim", None)
|
||||
|
||||
def create_decoder_color_impl(self, type, args: DictConfig) -> None:
|
||||
def create_decoder_color_impl(self, type_, args: DictConfig) -> None:
|
||||
"""
|
||||
Decoding functions come after harmonic embedding and voxel grid. In order to not
|
||||
calculate the input dimension of the decoder in the config file this function
|
||||
@ -587,7 +587,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
|
||||
|
||||
input_dim = input_dim0 + input_dim1
|
||||
|
||||
cls = registry.get(DecoderFunctionBase, type)
|
||||
cls = registry.get(DecoderFunctionBase, type_)
|
||||
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
|
||||
if need_input_dim:
|
||||
self.decoder_color = cls(input_dim=input_dim, **args)
|
||||
|
Loading…
x
Reference in New Issue
Block a user