mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Fix parameters not wrapped with nn.Parameter, antialiasing compatibility
Summary: Some things fail if a parameter is not wraped; in particular, it prevented other tensors moving to GPU. Reviewed By: bottler Differential Revision: D40819932 fbshipit-source-id: a23b38ceacd7f0dc131cb0355fef1178e3e2f7fd
This commit is contained in:
		
							parent
							
								
									88620b6847
								
							
						
					
					
						commit
						f711c4bfe9
					
				@ -75,7 +75,7 @@ class ElementwiseDecoder(DecoderFunctionBase):
 | 
			
		||||
        shift: a scalar which is added to the scaled input before performing
 | 
			
		||||
            the operation. Defaults to 0.
 | 
			
		||||
        operation: which operation to perform on the transformed input. Options are:
 | 
			
		||||
            `relu`, `softplus`, `sigmoid` and `identity`. Defaults to `identity`.
 | 
			
		||||
            `RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`. Defaults to `IDENTITY`.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    scale: float = 1
 | 
			
		||||
@ -91,7 +91,7 @@ class ElementwiseDecoder(DecoderFunctionBase):
 | 
			
		||||
            DecoderActivation.IDENTITY,
 | 
			
		||||
        ]:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "`operation` can only be `relu`, `softplus`, `sigmoid` or identity."
 | 
			
		||||
                "`operation` can only be `RELU`, `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
@ -165,22 +165,18 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        if self.last_activation not in [
 | 
			
		||||
            DecoderActivation.RELU,
 | 
			
		||||
            DecoderActivation.SOFTPLUS,
 | 
			
		||||
            DecoderActivation.SIGMOID,
 | 
			
		||||
            DecoderActivation.IDENTITY,
 | 
			
		||||
        ]:
 | 
			
		||||
        try:
 | 
			
		||||
            last_activation = {
 | 
			
		||||
                DecoderActivation.RELU: torch.nn.ReLU(True),
 | 
			
		||||
                DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
 | 
			
		||||
                DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
 | 
			
		||||
                DecoderActivation.IDENTITY: torch.nn.Identity(),
 | 
			
		||||
            }[self.last_activation]
 | 
			
		||||
        except KeyError as e:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "`last_activation` can only be `relu`,"
 | 
			
		||||
                " `softplus`, `sigmoid` or identity."
 | 
			
		||||
            )
 | 
			
		||||
        last_activation = {
 | 
			
		||||
            DecoderActivation.RELU: torch.nn.ReLU(True),
 | 
			
		||||
            DecoderActivation.SOFTPLUS: torch.nn.Softplus(),
 | 
			
		||||
            DecoderActivation.SIGMOID: torch.nn.Sigmoid(),
 | 
			
		||||
            DecoderActivation.IDENTITY: torch.nn.Identity(),
 | 
			
		||||
        }[self.last_activation]
 | 
			
		||||
                "`last_activation` can only be `RELU`,"
 | 
			
		||||
                " `SOFTPLUS`, `SIGMOID` or `IDENTITY`."
 | 
			
		||||
            ) from e
 | 
			
		||||
 | 
			
		||||
        layers = []
 | 
			
		||||
        skip_affine_layers = []
 | 
			
		||||
 | 
			
		||||
@ -15,9 +15,12 @@ these classes.
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import warnings
 | 
			
		||||
from collections.abc import Mapping
 | 
			
		||||
from dataclasses import dataclass, field
 | 
			
		||||
from typing import Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
 | 
			
		||||
 | 
			
		||||
from distutils.version import LooseVersion
 | 
			
		||||
from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from omegaconf import DictConfig
 | 
			
		||||
@ -67,7 +70,9 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
    padding: str = "zeros"
 | 
			
		||||
    mode: str = "bilinear"
 | 
			
		||||
    n_features: int = 1
 | 
			
		||||
    resolution_changes: Dict[int, List[int]] = field(
 | 
			
		||||
    # return the line below once we drop OmegaConf 2.1 support
 | 
			
		||||
    # resolution_changes: Dict[int, List[int]] = field(
 | 
			
		||||
    resolution_changes: Dict[int, Any] = field(
 | 
			
		||||
        default_factory=lambda: {0: [128, 128, 128]}
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -212,6 +217,13 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
                + "| 'bicubic' | 'linear' | 'area' | 'nearest-exact'"
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        interpolate_has_antialias = LooseVersion(torch.__version__) >= "1.11"
 | 
			
		||||
 | 
			
		||||
        if antialias and not interpolate_has_antialias:
 | 
			
		||||
            warnings.warn("Antialiased interpolation requires PyTorch 1.11+; ignoring")
 | 
			
		||||
 | 
			
		||||
        interp_kwargs = {"antialias": antialias} if interpolate_has_antialias else {}
 | 
			
		||||
 | 
			
		||||
        def change_individual_resolution(tensor, wanted_resolution):
 | 
			
		||||
            if mode == "linear":
 | 
			
		||||
                n_dim = len(wanted_resolution)
 | 
			
		||||
@ -223,8 +235,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
                size=wanted_resolution,
 | 
			
		||||
                mode=new_mode,
 | 
			
		||||
                align_corners=align_corners,
 | 
			
		||||
                antialias=antialias,
 | 
			
		||||
                recompute_scale_factor=False,
 | 
			
		||||
                **interp_kwargs,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if epoch is not None:
 | 
			
		||||
@ -880,7 +892,14 @@ class VoxelGridModule(Configurable, torch.nn.Module):
 | 
			
		||||
        """
 | 
			
		||||
        if self.hold_voxel_grid_as_parameters:
 | 
			
		||||
            # pyre-ignore [16]
 | 
			
		||||
            self.params = torch.nn.ParameterDict(vars(params))
 | 
			
		||||
            # Nones are converted to empty tensors by Parameter()
 | 
			
		||||
            self.params = torch.nn.ParameterDict(
 | 
			
		||||
                {
 | 
			
		||||
                    k: torch.nn.Parameter(val)
 | 
			
		||||
                    for k, val in vars(params).items()
 | 
			
		||||
                    if val is not None
 | 
			
		||||
                }
 | 
			
		||||
            )
 | 
			
		||||
        else:
 | 
			
		||||
            # Torch Module to hold parameters since they can only be registered
 | 
			
		||||
            # at object level.
 | 
			
		||||
@ -1011,7 +1030,11 @@ class VoxelGridModule(Configurable, torch.nn.Module):
 | 
			
		||||
        )
 | 
			
		||||
        # pyre-ignore [16]
 | 
			
		||||
        self.params = torch.nn.ParameterDict(
 | 
			
		||||
            {k: v for k, v in vars(grid_values).items()}
 | 
			
		||||
            {
 | 
			
		||||
                k: torch.nn.Parameter(val)
 | 
			
		||||
                for k, val in vars(grid_values).items()
 | 
			
		||||
                if val is not None
 | 
			
		||||
            }
 | 
			
		||||
        )
 | 
			
		||||
        # New center of voxel grid is the middle point between max and min points.
 | 
			
		||||
        self.translation = tuple((max_point + min_point) / 2)
 | 
			
		||||
 | 
			
		||||
@ -527,10 +527,10 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        return False
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def decoder_density_tweak_args(cls, type, args: DictConfig) -> None:
 | 
			
		||||
    def decoder_density_tweak_args(cls, type_, args: DictConfig) -> None:
 | 
			
		||||
        args.pop("input_dim", None)
 | 
			
		||||
 | 
			
		||||
    def create_decoder_density_impl(self, type, args: DictConfig) -> None:
 | 
			
		||||
    def create_decoder_density_impl(self, type_, args: DictConfig) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Decoding functions come after harmonic embedding and voxel grid. In order to not
 | 
			
		||||
        calculate the input dimension of the decoder in the config file this function
 | 
			
		||||
@ -548,7 +548,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
            embedder_args["append_input"],
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        cls = registry.get(DecoderFunctionBase, type)
 | 
			
		||||
        cls = registry.get(DecoderFunctionBase, type_)
 | 
			
		||||
        need_input_dim = any(field.name == "input_dim" for field in fields(cls))
 | 
			
		||||
        if need_input_dim:
 | 
			
		||||
            self.decoder_density = cls(input_dim=input_dim, **args)
 | 
			
		||||
@ -556,10 +556,10 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
            self.decoder_density = cls(**args)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def decoder_color_tweak_args(cls, type, args: DictConfig) -> None:
 | 
			
		||||
    def decoder_color_tweak_args(cls, type_, args: DictConfig) -> None:
 | 
			
		||||
        args.pop("input_dim", None)
 | 
			
		||||
 | 
			
		||||
    def create_decoder_color_impl(self, type, args: DictConfig) -> None:
 | 
			
		||||
    def create_decoder_color_impl(self, type_, args: DictConfig) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Decoding functions come after harmonic embedding and voxel grid. In order to not
 | 
			
		||||
        calculate the input dimension of the decoder in the config file this function
 | 
			
		||||
@ -587,7 +587,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
        input_dim = input_dim0 + input_dim1
 | 
			
		||||
 | 
			
		||||
        cls = registry.get(DecoderFunctionBase, type)
 | 
			
		||||
        cls = registry.get(DecoderFunctionBase, type_)
 | 
			
		||||
        need_input_dim = any(field.name == "input_dim" for field in fields(cls))
 | 
			
		||||
        if need_input_dim:
 | 
			
		||||
            self.decoder_color = cls(input_dim=input_dim, **args)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user