Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)
Fix parameters not wrapped with nn.Parameter, antialiasing compatibility
Summary: Some things fail if a parameter is not wrapped; in particular, it prevented other tensors moving to GPU.

Reviewed By: bottler

Differential Revision: D40819932

fbshipit-source-id: a23b38ceacd7f0dc131cb0355fef1178e3e2f7fd
parent 88620b6847
commit f711c4bfe9
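Context for the fix: a plain tensor attribute is not registered on its torch.nn.Module, so Module.to(device) will not move it; wrapping it in torch.nn.Parameter registers it so it moves with the module. A minimal sketch of the failure mode (the Broken/Fixed module names are illustrative, not from this commit):

import torch

class Broken(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A plain tensor attribute is invisible to .to() / .parameters()
        self.w = torch.rand(3)

class Fixed(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Wrapping registers the tensor as a parameter
        self.w = torch.nn.Parameter(torch.rand(3))

if torch.cuda.is_available():
    assert Broken().to("cuda").w.device.type == "cpu"  # left behind
    assert Fixed().to("cuda").w.device.type == "cuda"  # moved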
@@ -75,7 +75,7 @@ class ElementwiseDecoder(DecoderFunctionBase):
         shift: a scalar which is added to the scaled input before performing
             the operation. Defaults to 0.
         operation: which operation to perform on the transformed input. Options are:
-            `relu`, `softplus`, `sigmoid` and `identity`. Defaults to `identity`.
+            `RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`. Defaults to `IDENTITY`.
     """
 
     scale: float = 1
@@ -91,7 +91,7 @@ class ElementwiseDecoder(DecoderFunctionBase):
             DecoderActivation.IDENTITY,
         ]:
             raise ValueError(
-                "`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
+                "`operation` can only be `RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
             )
 
     def forward(
@@ -165,22 +165,18 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     def __post_init__(self):
         super().__init__()
 
-        if self.last_activation not in [
-            DecoderActivation.RELU,
-            DecoderActivation.SOFTPLUS,
-            DecoderActivation.SIGMOID,
-            DecoderActivation.IDENTITY,
-        ]:
+        try:
+            last_activation = {
+                DecoderActivation.RELU: torch.nn.ReLU(True),
+                DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
+                DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
+                DecoderActivation.IDENTITY: torch.nn.Identity(),
+            }[self.last_activation]
+        except KeyError as e:
             raise ValueError(
-                "`last_activation` can only be `relu`,"
-                " `softplus`, `sigmoid` or identity."
-            )
-        last_activation = {
-            DecoderActivation.RELU: torch.nn.ReLU(True),
-            DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
-            DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
-            DecoderActivation.IDENTITY: torch.nn.Identity(),
-        }[self.last_activation]
+                "`last_activation` can only be `RELU`,"
+                " `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
+            ) from e
 
         layers = []
         skip_affine_layers = []
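The hunk above folds validation and construction into one dict lookup: an unsupported last_activation raises KeyError, which is re-raised as ValueError. A standalone sketch of the same pattern (the Activation enum and make_activation name are illustrative):

import torch
from enum import Enum

class Activation(Enum):  # stand-in for DecoderActivation
    RELU = "relu"
    SIGMOID = "sigmoid"

def make_activation(choice: Activation) -> torch.nn.Module:
    try:
        # The lookup both validates the choice and builds the module.
        return {
            Activation.RELU: torch.nn.ReLU(True),
            Activation.SIGMOID: torch.nn.Sigmoid(),
        }[choice]
    except KeyError as e:
        raise ValueError(f"unsupported activation: {choice}") from e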
@@ -15,9 +15,12 @@ these classes.
 
 """
 
+import warnings
 from collections.abc import Mapping
 from dataclasses import dataclass, field
-from typing import Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
+from distutils.version import LooseVersion
+from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
 
 import torch
 from omegaconf import DictConfig
@@ -67,7 +70,9 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
     padding: str = "zeros"
     mode: str = "bilinear"
     n_features: int = 1
-    resolution_changes: Dict[int, List[int]] = field(
+    # return the line below once we drop OmegaConf 2.1 support
+    # resolution_changes: Dict[int, List[int]] = field(
+    resolution_changes: Dict[int, Any] = field(
         default_factory=lambda: {0: [128, 128, 128]}
     )
 
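Per the comment in the hunk above, the field is loosened to Dict[int, Any] because the structured Dict[int, List[int]] annotation is not supported under OmegaConf 2.1; the stricter type is kept as a comment to be restored later. A sketch of the workaround pattern, with runtime validation standing in for the lost static type (GridConfig is hypothetical):

from dataclasses import dataclass, field
from typing import Any, Dict

@dataclass
class GridConfig:
    # Loosely typed for OmegaConf 2.1; logically Dict[int, List[int]].
    resolution_changes: Dict[int, Any] = field(
        default_factory=lambda: {0: [128, 128, 128]}
    )

    def __post_init__(self):
        # Enforce the intended shape at runtime instead.
        for epoch, res in self.resolution_changes.items():
            assert isinstance(epoch, int) and len(res) == 3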
@@ -212,6 +217,13 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
                 + "| 'bicubic' | 'linear' | 'area' | 'nearest-exact'"
             )
 
+        interpolate_has_antialias = LooseVersion(torch.__version__) >= "1.11"
+
+        if antialias and not interpolate_has_antialias:
+            warnings.warn("Antialiased interpolation requires PyTorch 1.11+; ignoring")
+
+        interp_kwargs = {"antialias": antialias} if interpolate_has_antialias else {}
+
         def change_individual_resolution(tensor, wanted_resolution):
             if mode == "linear":
                 n_dim = len(wanted_resolution)
@@ -223,8 +235,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
                     size=wanted_resolution,
                     mode=new_mode,
                     align_corners=align_corners,
-                    antialias=antialias,
                     recompute_scale_factor=False,
+                    **interp_kwargs,
                 )
 
         if epoch is not None:
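The two hunks above gate the antialias argument on the running PyTorch version: torch.nn.functional.interpolate only accepts antialias from 1.11 on, so the kwarg is forwarded conditionally via **interp_kwargs. A minimal standalone sketch of the same guard:

from distutils.version import LooseVersion

import torch
import torch.nn.functional as F

# `antialias` was added to F.interpolate in PyTorch 1.11.
has_antialias = LooseVersion(torch.__version__) >= "1.11"
interp_kwargs = {"antialias": True} if has_antialias else {}

x = torch.rand(1, 3, 64, 64)
y = F.interpolate(
    x, size=(32, 32), mode="bilinear", align_corners=False, **interp_kwargs
)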
@@ -880,7 +892,14 @@ class VoxelGridModule(Configurable, torch.nn.Module):
         """
         if self.hold_voxel_grid_as_parameters:
             # pyre-ignore [16]
-            self.params = torch.nn.ParameterDict(vars(params))
+            # Nones are converted to empty tensors by Parameter()
+            self.params = torch.nn.ParameterDict(
+                {
+                    k: torch.nn.Parameter(val)
+                    for k, val in vars(params).items()
+                    if val is not None
+                }
+            )
         else:
             # Torch Module to hold parameters since they can only be registered
             # at object level.
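The rewrite above wraps each tensor in torch.nn.Parameter explicitly and drops None entries, since (per the added comment) Parameter() turns None into an empty tensor. A brief sketch of the resulting behaviour (the `values` dict is illustrative):

import torch

values = {"density": torch.rand(4, 4, 4), "color": None}

# Wrap tensors so they are registered as parameters; skip Nones,
# which torch.nn.Parameter(None) would turn into empty tensors.
params = torch.nn.ParameterDict(
    {k: torch.nn.Parameter(v) for k, v in values.items() if v is not None}
)
assert set(params.keys()) == {"density"}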
@@ -1011,7 +1030,11 @@ class VoxelGridModule(Configurable, torch.nn.Module):
         )
         # pyre-ignore [16]
         self.params = torch.nn.ParameterDict(
-            {k: v for k, v in vars(grid_values).items()}
+            {
+                k: torch.nn.Parameter(val)
+                for k, val in vars(grid_values).items()
+                if val is not None
+            }
         )
         # New center of voxel grid is the middle point between max and min points.
         self.translation = tuple((max_point + min_point) / 2)
@@ -527,10 +527,10 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
         return False
 
     @classmethod
-    def decoder_density_tweak_args(cls, type, args: DictConfig) -> None:
+    def decoder_density_tweak_args(cls, type_, args: DictConfig) -> None:
         args.pop("input_dim", None)
 
-    def create_decoder_density_impl(self, type, args: DictConfig) -> None:
+    def create_decoder_density_impl(self, type_, args: DictConfig) -> None:
         """
         Decoding functions come after harmonic embedding and voxel grid. In order to not
         calculate the input dimension of the decoder in the config file this function
@@ -548,7 +548,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
             embedder_args["append_input"],
         )
 
-        cls = registry.get(DecoderFunctionBase, type)
+        cls = registry.get(DecoderFunctionBase, type_)
         need_input_dim = any(field.name == "input_dim" for field in fields(cls))
         if need_input_dim:
             self.decoder_density = cls(input_dim=input_dim, **args)
@@ -556,10 +556,10 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
             self.decoder_density = cls(**args)
 
     @classmethod
-    def decoder_color_tweak_args(cls, type, args: DictConfig) -> None:
+    def decoder_color_tweak_args(cls, type_, args: DictConfig) -> None:
         args.pop("input_dim", None)
 
-    def create_decoder_color_impl(self, type, args: DictConfig) -> None:
+    def create_decoder_color_impl(self, type_, args: DictConfig) -> None:
         """
         Decoding functions come after harmonic embedding and voxel grid. In order to not
         calculate the input dimension of the decoder in the config file this function
@@ -587,7 +587,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 
         input_dim = input_dim0 + input_dim1
 
-        cls = registry.get(DecoderFunctionBase, type)
+        cls = registry.get(DecoderFunctionBase, type_)
         need_input_dim = any(field.name == "input_dim" for field in fields(cls))
         if need_input_dim:
             self.decoder_color = cls(input_dim=input_dim, **args)
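The type -> type_ renames in the last four hunks follow the PEP 8 convention of suffixing an underscore to avoid shadowing the builtin; in the old signatures, calling type() inside those bodies would have hit the argument instead. Illustration (the bad/good function names are hypothetical):

def bad(type, args):
    # The builtin is shadowed; this calls the parameter, not builtins.type,
    # and fails with TypeError when `type` is e.g. a string.
    return type(args)

def good(type_, args):
    return type(args)  # builtin `type` still available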