Fix parameters not wrapped with nn.Parameter, antialiasing compatibility

Summary: Several things fail if a parameter is not wrapped in nn.Parameter; in particular, it prevented tensors from moving to the GPU along with the rest of the module.

Reviewed By: bottler

Differential Revision: D40819932

fbshipit-source-id: a23b38ceacd7f0dc131cb0355fef1178e3e2f7fd
Roman Shapovalov 2022-10-31 01:43:00 -07:00 committed by Facebook GitHub Bot
parent 88620b6847
commit f711c4bfe9
3 changed files with 47 additions and 28 deletions
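For context, a minimal sketch (not part of the commit) of the failure mode described in the summary: a tensor stored as a plain module attribute is invisible to torch.nn.Module bookkeeping, so .cuda() leaves it on the CPU, whereas an nn.Parameter moves with the module.

import torch

class Grid(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Wrapped: registered with the module, moved by .to()/.cuda(),
        # and visible to optimizers via .parameters().
        self.wrapped = torch.nn.Parameter(torch.zeros(3))
        # Not wrapped: a plain attribute that Module machinery ignores;
        # it stays on the CPU after .cuda().
        self.plain = torch.zeros(3)

module = Grid().cuda()  # needs a CUDA device
print(module.wrapped.device)  # cuda:0
print(module.plain.device)    # cpu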

pytorch3d/implicitron/models/implicit_function/decoding_functions.py

@@ -75,7 +75,7 @@ class ElementwiseDecoder(DecoderFunctionBase):
shift: a scalar which is added to the scaled input before performing
the operation. Defaults to 0.
operation: which operation to perform on the transformed input. Options are:
`relu`, `softplus`, `sigmoid` and `identity`. Defaults to `identity`.
`RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`. Defaults to `IDENTITY`.
"""
scale: float = 1
@@ -91,7 +91,7 @@ class ElementwiseDecoder(DecoderFunctionBase):
DecoderActivation.IDENTITY,
]:
raise ValueError(
"`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
"`operation` can only be `RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
)
def forward(
@@ -165,22 +165,18 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
def __post_init__(self):
super().__init__()
if self.last_activation not in [
DecoderActivation.RELU,
DecoderActivation.SOFTPLUS,
DecoderActivation.SIGMOID,
DecoderActivation.IDENTITY,
]:
raise ValueError(
"`last_activation` can only be `relu`,"
" `softplus`, `sigmoid` or identity."
)
try:
last_activation = {
DecoderActivation.RELU: torch.nn.ReLU(True),
DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
DecoderActivation.IDENTITY: torch.nn.Identity(),
}[self.last_activation]
except KeyError as e:
raise ValueError(
"`last_activation` can only be `RELU`,"
" `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
) from e
layers = []
skip_affine_layers = []
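Aside: the new code collapses the separate membership check into a single dict lookup, so validation and dispatch share one source of truth. The same pattern in isolation (make_activation is a made-up name):

import torch

def make_activation(name: str) -> torch.nn.Module:
    # The dict both validates and dispatches: an unknown key raises
    # KeyError, which is re-raised as the user-facing ValueError.
    try:
        return {
            "relu": torch.nn.ReLU(inplace=True),
            "softplus": torch.nn.Softplus(),
            "sigmoid": torch.nn.Sigmoid(),
            "identity": torch.nn.Identity(),
        }[name]
    except KeyError as e:
        raise ValueError(f"unknown activation {name!r}") from e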

pytorch3d/implicitron/models/implicit_function/voxel_grid.py

@@ -15,9 +15,12 @@ these classes.
"""
import warnings
from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
from distutils.version import LooseVersion
from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
import torch
from omegaconf import DictConfig
@@ -67,7 +70,9 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
padding: str = "zeros"
mode: str = "bilinear"
n_features: int = 1
resolution_changes: Dict[int, List[int]] = field(
# restore the line below once we drop OmegaConf 2.1 support
# resolution_changes: Dict[int, List[int]] = field(
resolution_changes: Dict[int, Any] = field(
default_factory=lambda: {0: [128, 128, 128]}
)
@@ -212,6 +217,13 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
+ "| 'bicubic' | 'linear' | 'area' | 'nearest-exact'"
)
interpolate_has_antialias = LooseVersion(torch.__version__) >= "1.11"
if antialias and not interpolate_has_antialias:
warnings.warn("Antialiased interpolation requires PyTorch 1.11+; ignoring")
interp_kwargs = {"antialias": antialias} if interpolate_has_antialias else {}
def change_individual_resolution(tensor, wanted_resolution):
if mode == "linear":
n_dim = len(wanted_resolution)
@@ -223,8 +235,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
size=wanted_resolution,
mode=new_mode,
align_corners=align_corners,
antialias=antialias,
recompute_scale_factor=False,
**interp_kwargs,
)
if epoch is not None:
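Aside: a self-contained sketch of the version gate above (resize is a made-up helper). interpolate() only accepts antialias from PyTorch 1.11 on, so on older versions the kwarg is dropped with a warning instead of raising TypeError:

import warnings
import torch
from distutils.version import LooseVersion

def resize(tensor: torch.Tensor, size, antialias: bool = False) -> torch.Tensor:
    has_antialias = LooseVersion(torch.__version__) >= "1.11"
    if antialias and not has_antialias:
        warnings.warn("Antialiased interpolation requires PyTorch 1.11+; ignoring")
    # Build kwargs conditionally so old PyTorch never sees the unknown keyword.
    kwargs = {"antialias": antialias} if has_antialias else {}
    return torch.nn.functional.interpolate(
        tensor, size=size, mode="bilinear", align_corners=False, **kwargs
    )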
@@ -880,7 +892,14 @@ class VoxelGridModule(Configurable, torch.nn.Module):
"""
if self.hold_voxel_grid_as_parameters:
# pyre-ignore [16]
self.params = torch.nn.ParameterDict(vars(params))
# Nones are converted to empty tensors by Parameter()
self.params = torch.nn.ParameterDict(
{
k: torch.nn.Parameter(val)
for k, val in vars(params).items()
if val is not None
}
)
else:
# Torch Module to hold parameters since they can only be registered
# at object level.
@@ -1011,7 +1030,11 @@ class VoxelGridModule(Configurable, torch.nn.Module):
)
# pyre-ignore [16]
self.params = torch.nn.ParameterDict(
{k: v for k, v in vars(grid_values).items()}
{
k: torch.nn.Parameter(val)
for k, val in vars(grid_values).items()
if val is not None
}
)
# New center of voxel grid is the middle point between max and min points.
self.translation = tuple((max_point + min_point) / 2)
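Aside: the wrap-and-filter idiom above in isolation (the params dict here is invented). torch.nn.Parameter(None) silently yields an empty tensor, so None entries are dropped rather than wrapped:

import torch

params = {"voxel_grid": torch.rand(1, 4, 8, 8, 8), "colors": None}

param_dict = torch.nn.ParameterDict(
    # Explicit nn.Parameter wrapping registers each tensor with the module;
    # the filter keeps Nones from turning into empty tensors.
    {k: torch.nn.Parameter(v) for k, v in params.items() if v is not None}
)
print(list(param_dict.keys()))  # ['voxel_grid']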

pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py

@@ -527,10 +527,10 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
return False
@classmethod
def decoder_density_tweak_args(cls, type, args: DictConfig) -> None:
def decoder_density_tweak_args(cls, type_, args: DictConfig) -> None:
args.pop("input_dim", None)
def create_decoder_density_impl(self, type, args: DictConfig) -> None:
def create_decoder_density_impl(self, type_, args: DictConfig) -> None:
"""
Decoding functions come after the harmonic embedding and voxel grid. To avoid having to
calculate the input dimension of the decoder in the config file, this function
@@ -548,7 +548,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
embedder_args["append_input"],
)
cls = registry.get(DecoderFunctionBase, type)
cls = registry.get(DecoderFunctionBase, type_)
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
if need_input_dim:
self.decoder_density = cls(input_dim=input_dim, **args)
@@ -556,10 +556,10 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
self.decoder_density = cls(**args)
@classmethod
def decoder_color_tweak_args(cls, type, args: DictConfig) -> None:
def decoder_color_tweak_args(cls, type_, args: DictConfig) -> None:
args.pop("input_dim", None)
def create_decoder_color_impl(self, type, args: DictConfig) -> None:
def create_decoder_color_impl(self, type_, args: DictConfig) -> None:
"""
Decoding functions come after the harmonic embedding and voxel grid. To avoid having to
calculate the input dimension of the decoder in the config file, this function
@@ -587,7 +587,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
input_dim = input_dim0 + input_dim1
cls = registry.get(DecoderFunctionBase, type)
cls = registry.get(DecoderFunctionBase, type_)
need_input_dim = any(field.name == "input_dim" for field in fields(cls))
if need_input_dim:
self.decoder_color = cls(input_dim=input_dim, **args)
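Aside: the type -> type_ renames above follow the PEP 8 trailing-underscore convention for avoiding a shadowed builtin; a toy illustration of the bug class this prevents (function names invented):

def create_impl(type, args):
    # The parameter shadows the builtin: type(args) here would try to
    # call whatever was passed in as `type`, not builtins.type.
    ...

def create_impl_fixed(type_, args):
    return type(args)  # the builtin works as expected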