mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Remove unused pyre-ignore or pyre-fixme
				
					
				
			Reviewed By: bottler Differential Revision: D47223471 fbshipit-source-id: 8bdabf2a69dd7aec7202141122a9c69220ba7ef1
This commit is contained in:
		
							parent
							
								
									f68371d398
								
							
						
					
					
						commit
						4e7715ce66
					
				@ -316,7 +316,7 @@ def adjust_camera_to_bbox_crop_(
 | 
			
		||||
 | 
			
		||||
    focal_length_px, principal_point_px = _convert_ndc_to_pixels(
 | 
			
		||||
        camera.focal_length[0],
 | 
			
		||||
        camera.principal_point[0],  # pyre-ignore
 | 
			
		||||
        camera.principal_point[0],
 | 
			
		||||
        image_size_wh,
 | 
			
		||||
    )
 | 
			
		||||
    principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2]
 | 
			
		||||
@ -328,7 +328,7 @@ def adjust_camera_to_bbox_crop_(
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    camera.focal_length = focal_length[None]
 | 
			
		||||
    camera.principal_point = principal_point_cropped[None]  # pyre-ignore
 | 
			
		||||
    camera.principal_point = principal_point_cropped[None]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def adjust_camera_to_image_scale_(
 | 
			
		||||
@ -338,7 +338,7 @@ def adjust_camera_to_image_scale_(
 | 
			
		||||
) -> PerspectiveCameras:
 | 
			
		||||
    focal_length_px, principal_point_px = _convert_ndc_to_pixels(
 | 
			
		||||
        camera.focal_length[0],
 | 
			
		||||
        camera.principal_point[0],  # pyre-ignore
 | 
			
		||||
        camera.principal_point[0],
 | 
			
		||||
        original_size_wh,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -143,7 +143,6 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
 | 
			
		||||
        return (img - self._resnet_mean) / self._resnet_std
 | 
			
		||||
 | 
			
		||||
    def get_feat_dims(self) -> int:
 | 
			
		||||
        # pyre-fixme[29]
 | 
			
		||||
        return sum(self._feat_dim.values())
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
@ -181,13 +180,8 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
 | 
			
		||||
                imgs_normed = self._resnet_normalize_image(imgs_resized)
 | 
			
		||||
            else:
 | 
			
		||||
                imgs_normed = imgs_resized
 | 
			
		||||
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.modules.module.Module]`
 | 
			
		||||
            #  is not a function.
 | 
			
		||||
            feats = self.stem(imgs_normed)
 | 
			
		||||
            # pyre-fixme[6]: For 1st param expected `Iterable[Variable[_T1]]` but
 | 
			
		||||
            #  got `Union[Tensor, Module]`.
 | 
			
		||||
            # pyre-fixme[6]: For 2nd param expected `Iterable[Variable[_T2]]` but
 | 
			
		||||
            #  got `Union[Tensor, Module]`.
 | 
			
		||||
            for stage, (layer, proj) in enumerate(zip(self.layers, self.proj_layers)):
 | 
			
		||||
                feats = layer(feats)
 | 
			
		||||
                # just a sanity check below
 | 
			
		||||
 | 
			
		||||
@ -463,10 +463,6 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
			
		||||
            )
 | 
			
		||||
        custom_args["global_code"] = global_code
 | 
			
		||||
 | 
			
		||||
        # pyre-fixme[29]:
 | 
			
		||||
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self,
 | 
			
		||||
        #  torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], torch.Tensor,
 | 
			
		||||
        #  torch.nn.Module]` is not a function.
 | 
			
		||||
        for func in self._implicit_functions:
 | 
			
		||||
            func.bind_args(**custom_args)
 | 
			
		||||
 | 
			
		||||
@ -489,10 +485,6 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
 | 
			
		||||
        # Unbind the custom arguments to prevent pytorch from storing
 | 
			
		||||
        # large buffers of intermediate results due to points in the
 | 
			
		||||
        # bound arguments.
 | 
			
		||||
        # pyre-fixme[29]:
 | 
			
		||||
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__iter__)[[Named(self,
 | 
			
		||||
        #  torch.Tensor)], typing.Iterator[typing.Any]], torch.Tensor], torch.Tensor,
 | 
			
		||||
        #  torch.nn.Module]` is not a function.
 | 
			
		||||
        for func in self._implicit_functions:
 | 
			
		||||
            func.unbind_args()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -69,7 +69,7 @@ class Autodecoder(Configurable, torch.nn.Module):
 | 
			
		||||
        return key_map
 | 
			
		||||
 | 
			
		||||
    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
 | 
			
		||||
        return (self._autodecoder_codes.weight**2).mean()  # pyre-ignore[16]
 | 
			
		||||
        return (self._autodecoder_codes.weight**2).mean()
 | 
			
		||||
 | 
			
		||||
    def get_encoding_dim(self) -> int:
 | 
			
		||||
        return self.encoding_dim
 | 
			
		||||
@ -93,7 +93,6 @@ class Autodecoder(Configurable, torch.nn.Module):
 | 
			
		||||
                # pyre-fixme[9]: x has type `Union[List[str], LongTensor]`; used as
 | 
			
		||||
                #  `Tensor`.
 | 
			
		||||
                x = torch.tensor(
 | 
			
		||||
                    # pyre-ignore[29]
 | 
			
		||||
                    [self._key_map[elem] for elem in x],
 | 
			
		||||
                    dtype=torch.long,
 | 
			
		||||
                    device=next(self.parameters()).device,
 | 
			
		||||
@ -101,7 +100,6 @@ class Autodecoder(Configurable, torch.nn.Module):
 | 
			
		||||
            except StopIteration:
 | 
			
		||||
                raise ValueError("Not enough n_instances in the autodecoder") from None
 | 
			
		||||
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        return self._autodecoder_codes(x)
 | 
			
		||||
 | 
			
		||||
    def _load_key_map_hook(
 | 
			
		||||
 | 
			
		||||
@ -119,7 +119,7 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
 | 
			
		||||
        if frame_timestamp.shape[-1] != 1:
 | 
			
		||||
            raise ValueError("Frame timestamp's last dimensions should be one.")
 | 
			
		||||
        time = frame_timestamp / self.time_divisor
 | 
			
		||||
        return self._harmonic_embedding(time)  # pyre-ignore: 29
 | 
			
		||||
        return self._harmonic_embedding(time)
 | 
			
		||||
 | 
			
		||||
    def calculate_squared_encoding_norm(self) -> Optional[torch.Tensor]:
 | 
			
		||||
        return None
 | 
			
		||||
 | 
			
		||||
@ -230,14 +230,9 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
 | 
			
		||||
            # if the skip tensor is None, we use `x` instead.
 | 
			
		||||
            z = x
 | 
			
		||||
        skipi = 0
 | 
			
		||||
        # pyre-fixme[6]: For 1st param expected `Iterable[Variable[_T]]` but got
 | 
			
		||||
        #  `Union[Tensor, Module]`.
 | 
			
		||||
        for li, layer in enumerate(self.mlp):
 | 
			
		||||
            # pyre-fixme[58]: `in` is not supported for right operand type
 | 
			
		||||
            #  `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`.
 | 
			
		||||
            if li in self._input_skips:
 | 
			
		||||
                if self._skip_affine_trans:
 | 
			
		||||
                    # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._Te...
 | 
			
		||||
                    y = self._apply_affine_layer(self.skip_affines[skipi], y, z)
 | 
			
		||||
                else:
 | 
			
		||||
                    y = torch.cat((y, z), dim=-1)
 | 
			
		||||
 | 
			
		||||
@ -141,16 +141,11 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
            self.embed_fn is None and fun_viewpool is None and global_code is None
 | 
			
		||||
        ):
 | 
			
		||||
            return torch.tensor(
 | 
			
		||||
                [],
 | 
			
		||||
                device=rays_points_world.device,
 | 
			
		||||
                dtype=rays_points_world.dtype
 | 
			
		||||
                # pyre-fixme[6]: For 2nd param expected `int` but got `Union[Module,
 | 
			
		||||
                #  Tensor]`.
 | 
			
		||||
                [], device=rays_points_world.device, dtype=rays_points_world.dtype
 | 
			
		||||
            ).view(0, self.out_dim)
 | 
			
		||||
 | 
			
		||||
        embeddings = []
 | 
			
		||||
        if self.embed_fn is not None:
 | 
			
		||||
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
            embeddings.append(self.embed_fn(rays_points_world))
 | 
			
		||||
 | 
			
		||||
        if fun_viewpool is not None:
 | 
			
		||||
@ -169,17 +164,13 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
        embedding = torch.cat(embeddings, dim=-1)
 | 
			
		||||
        x = embedding
 | 
			
		||||
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase.__s...
 | 
			
		||||
        for layer_idx in range(self.num_layers - 1):
 | 
			
		||||
            if layer_idx in self.skip_in:
 | 
			
		||||
                x = torch.cat([x, embedding], dim=-1) / 2**0.5
 | 
			
		||||
 | 
			
		||||
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
            x = self.linear_layers[layer_idx](x)
 | 
			
		||||
 | 
			
		||||
            # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase...
 | 
			
		||||
            if layer_idx < self.num_layers - 2:
 | 
			
		||||
                # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
                x = self.softplus(x)
 | 
			
		||||
 | 
			
		||||
        return x
 | 
			
		||||
 | 
			
		||||
@ -113,10 +113,8 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        # Normalize the ray_directions to unit l2 norm.
 | 
			
		||||
        rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
 | 
			
		||||
        # Obtain the harmonic embedding of the normalized ray directions.
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        rays_embedding = self.harmonic_embedding_dir(rays_directions_normed)
 | 
			
		||||
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        return self.color_layer((self.intermediate_linear(features), rays_embedding))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
@ -171,7 +169,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
        embeds = create_embeddings_for_implicit_function(
 | 
			
		||||
            xyz_world=rays_points_world,
 | 
			
		||||
            # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
 | 
			
		||||
            #  for 2nd param but got `Union[None, torch.Tensor, torch.nn.Module]`.
 | 
			
		||||
            xyz_embedding_function=self.harmonic_embedding_xyz
 | 
			
		||||
            if self.input_xyz
 | 
			
		||||
@ -183,14 +180,12 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # embeds.shape = [minibatch x n_src x n_rays x n_pts x self.n_harmonic_functions*6+3]
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        features = self.xyz_encoder(embeds)
 | 
			
		||||
        # features.shape = [minibatch x ... x self.n_hidden_neurons_xyz]
 | 
			
		||||
        # NNs operate on the flattenned rays; reshaping to the correct spatial size
 | 
			
		||||
        # TODO: maybe make the transformer work on non-flattened tensors to avoid this reshape
 | 
			
		||||
        features = features.reshape(*rays_points_world.shape[:-1], -1)
 | 
			
		||||
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        raw_densities = self.density_layer(features)
 | 
			
		||||
        # raw_densities.shape = [minibatch x ... x 1] in [0-1]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -101,8 +101,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
        embeds = create_embeddings_for_implicit_function(
 | 
			
		||||
            xyz_world=rays_points_world,
 | 
			
		||||
            # pyre-fixme[6]: Expected `Optional[typing.Callable[..., typing.Any]]`
 | 
			
		||||
            #  for 2nd param but got `Union[torch.Tensor, torch.nn.Module]`.
 | 
			
		||||
            xyz_embedding_function=self._harmonic_embedding,
 | 
			
		||||
            global_code=global_code,
 | 
			
		||||
            fun_viewpool=fun_viewpool,
 | 
			
		||||
@ -112,7 +110,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
        # Before running the network, we have to resize embeds to ndims=3,
 | 
			
		||||
        # otherwise the SRN layers consume huge amounts of memory.
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        raymarch_features = self._net(
 | 
			
		||||
            embeds.view(embeds.shape[0], -1, embeds.shape[-1])
 | 
			
		||||
        )
 | 
			
		||||
@ -167,9 +164,7 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
 | 
			
		||||
        # Normalize the ray_directions to unit l2 norm.
 | 
			
		||||
        rays_directions_normed = torch.nn.functional.normalize(rays_directions, dim=-1)
 | 
			
		||||
        # Obtain the harmonic embedding of the normalized ray directions.
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        rays_embedding = self._harmonic_embedding(rays_directions_normed)
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        return self._color_layer((features, rays_embedding))
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
@ -198,7 +193,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
 | 
			
		||||
                denoting the color of each ray point.
 | 
			
		||||
        """
 | 
			
		||||
        # raymarch_features.shape = [minibatch x ... x pts_per_ray x 3]
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        features = self._net(raymarch_features)
 | 
			
		||||
        # features.shape = [minibatch x ... x self.n_hidden_units]
 | 
			
		||||
 | 
			
		||||
@ -213,7 +207,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
 | 
			
		||||
        # NNs operate on the flattenned rays; reshaping to the correct spatial size
 | 
			
		||||
        features = features.reshape(*raymarch_features.shape[:-1], -1)
 | 
			
		||||
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        raw_densities = self._density_layer(features)
 | 
			
		||||
 | 
			
		||||
        rays_colors = self._get_colors(features, directions)
 | 
			
		||||
@ -274,7 +267,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
 | 
			
		||||
        srn_raymarch_function.
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        net = self._hypernet(global_code)
 | 
			
		||||
 | 
			
		||||
        # use the hyper-net generated network to instantiate the raymarch module
 | 
			
		||||
@ -310,7 +302,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
 | 
			
		||||
        # across LSTM iterations for the same global_code.
 | 
			
		||||
        if self.cached_srn_raymarch_function is None:
 | 
			
		||||
            # generate the raymarching network from the hypernet
 | 
			
		||||
            # pyre-fixme[16]: `SRNRaymarchHyperNet` has no attribute
 | 
			
		||||
            self.cached_srn_raymarch_function = self._run_hypernet(global_code)
 | 
			
		||||
        (srn_raymarch_function,) = cast(
 | 
			
		||||
            Tuple[SRNRaymarchFunction], self.cached_srn_raymarch_function
 | 
			
		||||
@ -337,7 +328,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
    def create_raymarch_function(self) -> None:
 | 
			
		||||
        self.raymarch_function = SRNRaymarchFunction(
 | 
			
		||||
            latent_dim=self.latent_dim,
 | 
			
		||||
            # pyre-ignore[32]
 | 
			
		||||
            **self.raymarch_function_args,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -395,7 +385,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        self.hypernet = SRNRaymarchHyperNet(
 | 
			
		||||
            latent_dim=self.latent_dim,
 | 
			
		||||
            latent_dim_hypernet=self.latent_dim_hypernet,
 | 
			
		||||
            # pyre-ignore[32]
 | 
			
		||||
            **self.hypernet_args,
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -267,7 +267,6 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
 | 
			
		||||
                for name, tensor in vars(grid_values_with_wanted_resolution).items()
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        return self.values_type(**params), True
 | 
			
		||||
 | 
			
		||||
    def get_resolution_change_epochs(self) -> Tuple[int, ...]:
 | 
			
		||||
@ -881,8 +880,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
 | 
			
		||||
            torch.Tensor of shape (..., n_features)
 | 
			
		||||
        """
 | 
			
		||||
        locator = self._get_volume_locator()
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
 | 
			
		||||
        #  torch.nn.modules.module.Module]` is not a function.
 | 
			
		||||
        grid_values = self.voxel_grid.values_type(**self.params)
 | 
			
		||||
        # voxel grids operate with extra n_grids dimension, which we fix to one
 | 
			
		||||
        return self.voxel_grid.evaluate_world(points[None], grid_values, locator)[0]
 | 
			
		||||
@ -896,8 +893,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
 | 
			
		||||
                replace current parameters
 | 
			
		||||
        """
 | 
			
		||||
        if self.hold_voxel_grid_as_parameters:
 | 
			
		||||
            # pyre-ignore [16]
 | 
			
		||||
            # Nones are converted to empty tensors by Parameter()
 | 
			
		||||
            self.params = torch.nn.ParameterDict(
 | 
			
		||||
                {
 | 
			
		||||
                    k: torch.nn.Parameter(val)
 | 
			
		||||
@ -948,7 +943,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
 | 
			
		||||
        Returns:
 | 
			
		||||
            True if parameter change has happened else False.
 | 
			
		||||
        """
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        grid_values = self.voxel_grid.values_type(**self.params)
 | 
			
		||||
        grid_values, change = self.voxel_grid.change_resolution(
 | 
			
		||||
            grid_values, epoch=epoch
 | 
			
		||||
@ -996,19 +990,16 @@ class VoxelGridModule(Configurable, torch.nn.Module):
 | 
			
		||||
        """
 | 
			
		||||
        '''
 | 
			
		||||
        new_params = {}
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        for name in self.params:
 | 
			
		||||
            key = prefix + "params." + name
 | 
			
		||||
            if key in state_dict:
 | 
			
		||||
                new_params[name] = torch.zeros_like(state_dict[key])
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        self.set_voxel_grid_parameters(self.voxel_grid.values_type(**new_params))
 | 
			
		||||
 | 
			
		||||
    def get_device(self) -> torch.device:
 | 
			
		||||
        """
 | 
			
		||||
        Returns torch.device on which module parameters are located
 | 
			
		||||
        """
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        return next(val for val in self.params.values() if val is not None).device
 | 
			
		||||
 | 
			
		||||
    def crop_self(self, min_point: torch.Tensor, max_point: torch.Tensor) -> None:
 | 
			
		||||
@ -1024,7 +1015,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
 | 
			
		||||
            nothing
 | 
			
		||||
        """
 | 
			
		||||
        locator = self._get_volume_locator()
 | 
			
		||||
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
 | 
			
		||||
        #  torch.nn.modules.module.Module]` is not a function.
 | 
			
		||||
        old_grid_values = self.voxel_grid.values_type(**self.params)
 | 
			
		||||
        new_grid_values = self.voxel_grid.crop_world(
 | 
			
		||||
@ -1033,7 +1023,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
 | 
			
		||||
        grid_values, _ = self.voxel_grid.change_resolution(
 | 
			
		||||
            new_grid_values, grid_values_with_wanted_resolution=old_grid_values
 | 
			
		||||
        )
 | 
			
		||||
        # pyre-ignore [16]
 | 
			
		||||
        self.params = torch.nn.ParameterDict(
 | 
			
		||||
            {
 | 
			
		||||
                k: torch.nn.Parameter(val)
 | 
			
		||||
 | 
			
		||||
@ -187,21 +187,16 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self) -> None:
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self.harmonic_embedder_xyz_density = HarmonicEmbedding(
 | 
			
		||||
            **self.harmonic_embedder_xyz_density_args
 | 
			
		||||
        )
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self.harmonic_embedder_xyz_color = HarmonicEmbedding(
 | 
			
		||||
            **self.harmonic_embedder_xyz_color_args
 | 
			
		||||
        )
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self.harmonic_embedder_dir_color = HarmonicEmbedding(
 | 
			
		||||
            **self.harmonic_embedder_dir_color_args
 | 
			
		||||
        )
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self._scaffold_ready = False
 | 
			
		||||
 | 
			
		||||
    def forward(
 | 
			
		||||
@ -252,7 +247,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        # ########## filter the points using the scaffold ########## #
 | 
			
		||||
        if self._scaffold_ready and self.scaffold_filter_points:
 | 
			
		||||
            with torch.no_grad():
 | 
			
		||||
                # pyre-ignore[29]
 | 
			
		||||
                non_empty_points = self.voxel_grid_scaffold(points)[..., 0] > 0
 | 
			
		||||
            points = points[non_empty_points]
 | 
			
		||||
            if len(points) == 0:
 | 
			
		||||
@ -364,7 +358,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
                feature dimensionality which `decoder_density` returns
 | 
			
		||||
        """
 | 
			
		||||
        embeds_density = self.voxel_grid_density(points)
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        harmonic_embedding_density = self.harmonic_embedder_xyz_density(embeds_density)
 | 
			
		||||
        # shape = [..., density_dim]
 | 
			
		||||
        return self.decoder_density(harmonic_embedding_density)
 | 
			
		||||
@ -407,13 +400,11 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
 | 
			
		||||
        # ########## embed with the harmonic function ########## #
 | 
			
		||||
        # Obtain the harmonic embedding of the voxel grid output.
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        harmonic_embedding_color = self.harmonic_embedder_xyz_color(embeds_color)
 | 
			
		||||
 | 
			
		||||
        # Normalize the ray_directions to unit l2 norm.
 | 
			
		||||
        rays_directions_normed = torch.nn.functional.normalize(directions, dim=-1)
 | 
			
		||||
        # Obtain the harmonic embedding of the normalized ray directions.
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        harmonic_embedding_dir = self.harmonic_embedder_dir_color(
 | 
			
		||||
            rays_directions_normed
 | 
			
		||||
        )
 | 
			
		||||
@ -482,10 +473,8 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
            an object inside, else False.
 | 
			
		||||
        """
 | 
			
		||||
        # find bounding box
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
 | 
			
		||||
        assert self._scaffold_ready, "Scaffold has to be calculated before cropping."
 | 
			
		||||
        # pyre-ignore[29]
 | 
			
		||||
        occupancy = self.voxel_grid_scaffold(points)[..., 0] > 0
 | 
			
		||||
        non_zero_idxs = torch.nonzero(occupancy)
 | 
			
		||||
        if len(non_zero_idxs) == 0:
 | 
			
		||||
@ -517,7 +506,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
        planes = []
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        points = self.voxel_grid_scaffold.get_grid_points(epoch=epoch)
 | 
			
		||||
 | 
			
		||||
        chunk_size = (
 | 
			
		||||
@ -537,9 +525,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
            stride=1,
 | 
			
		||||
        )
 | 
			
		||||
        occupancy_cube = density_cube > self.scaffold_empty_space_threshold
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self.voxel_grid_scaffold.params["voxel_grid"] = occupancy_cube.float()
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        self._scaffold_ready = True
 | 
			
		||||
 | 
			
		||||
        return False
 | 
			
		||||
@ -556,7 +542,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        decoding function to this value.
 | 
			
		||||
        """
 | 
			
		||||
        grid_args = self.voxel_grid_density_args
 | 
			
		||||
        # pyre-ignore[6]
 | 
			
		||||
        grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
 | 
			
		||||
 | 
			
		||||
        embedder_args = self.harmonic_embedder_xyz_density_args
 | 
			
		||||
@ -585,7 +570,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
        decoding function to this value.
 | 
			
		||||
        """
 | 
			
		||||
        grid_args = self.voxel_grid_color_args
 | 
			
		||||
        # pyre-ignore[6]
 | 
			
		||||
        grid_output_dim = VoxelGridModule.get_output_dim(grid_args)
 | 
			
		||||
 | 
			
		||||
        embedder_args = self.harmonic_embedder_xyz_color_args
 | 
			
		||||
@ -619,9 +603,7 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
 | 
			
		||||
                    `self.voxel_grid_density`
 | 
			
		||||
        """
 | 
			
		||||
        return VoxelGridModule(
 | 
			
		||||
            # pyre-ignore[29]
 | 
			
		||||
            extents=self.voxel_grid_density_args["extents"],
 | 
			
		||||
            # pyre-ignore[29]
 | 
			
		||||
            translation=self.voxel_grid_density_args["translation"],
 | 
			
		||||
            voxel_grid_class_type="FullResolutionVoxelGrid",
 | 
			
		||||
            hold_voxel_grid_as_parameters=False,
 | 
			
		||||
 | 
			
		||||
@ -134,7 +134,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
            # run the lstm marcher
 | 
			
		||||
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
            state_h, state_c = self._lstm(
 | 
			
		||||
                raymarch_features.view(-1, raymarch_features.shape[-1]),
 | 
			
		||||
                states[-1],
 | 
			
		||||
@ -142,7 +141,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module):
 | 
			
		||||
            if state_h.requires_grad:
 | 
			
		||||
                state_h.register_hook(lambda x: x.clamp(min=-10, max=10))
 | 
			
		||||
            # predict the next step size
 | 
			
		||||
            # pyre-fixme[29]: `Union[torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
            signed_distance = self._out_layer(state_h).view(ray_bundle_t.lengths.shape)
 | 
			
		||||
            # log the lstm states
 | 
			
		||||
            states.append((state_h, state_c))
 | 
			
		||||
 | 
			
		||||
@ -171,7 +171,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
 | 
			
		||||
        """
 | 
			
		||||
        sample_mask = None
 | 
			
		||||
        if (
 | 
			
		||||
            # pyre-fixme[29]
 | 
			
		||||
            self._sampling_mode[evaluation_mode] == RenderSamplingMode.MASK_SAMPLE
 | 
			
		||||
            and mask is not None
 | 
			
		||||
        ):
 | 
			
		||||
@ -188,7 +187,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module):
 | 
			
		||||
            EvaluationMode.EVALUATION: self._evaluation_raysampler,
 | 
			
		||||
        }[evaluation_mode]
 | 
			
		||||
 | 
			
		||||
        # pyre-fixme[29]:
 | 
			
		||||
        ray_bundle = raysampler(
 | 
			
		||||
            cameras=cameras,
 | 
			
		||||
            mask=sample_mask,
 | 
			
		||||
 | 
			
		||||
@ -170,9 +170,9 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
 | 
			
		||||
            rays_densities = torch.relu(rays_densities)
 | 
			
		||||
 | 
			
		||||
        weighted_densities = deltas * rays_densities
 | 
			
		||||
        capped_densities = self._capping_function(weighted_densities)  # pyre-ignore: 29
 | 
			
		||||
        capped_densities = self._capping_function(weighted_densities)
 | 
			
		||||
 | 
			
		||||
        rays_opacities = self._capping_function(  # pyre-ignore: 29
 | 
			
		||||
        rays_opacities = self._capping_function(
 | 
			
		||||
            torch.cumsum(weighted_densities, dim=-1)
 | 
			
		||||
        )
 | 
			
		||||
        opacities = rays_opacities[..., -1:]
 | 
			
		||||
@ -181,9 +181,7 @@ class AccumulativeRaymarcherBase(RaymarcherBase, torch.nn.Module):
 | 
			
		||||
        )
 | 
			
		||||
        absorption_shifted[..., : self.surface_thickness] = 1.0
 | 
			
		||||
 | 
			
		||||
        weights = self._weight_function(  # pyre-ignore: 29
 | 
			
		||||
            capped_densities, absorption_shifted
 | 
			
		||||
        )
 | 
			
		||||
        weights = self._weight_function(capped_densities, absorption_shifted)
 | 
			
		||||
        features = (weights[..., None] * rays_features).sum(dim=-2)
 | 
			
		||||
        depth = (weights * ray_lengths)[..., None].sum(dim=-2)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -58,7 +58,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):  # pyre-ign
 | 
			
		||||
 | 
			
		||||
    def create_ray_tracer(self) -> None:
 | 
			
		||||
        self.ray_tracer = RayTracing(
 | 
			
		||||
            # pyre-ignore[32]
 | 
			
		||||
            **self.ray_tracer_args,
 | 
			
		||||
            object_bounding_sphere=self.object_bounding_sphere,
 | 
			
		||||
        )
 | 
			
		||||
@ -146,7 +145,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):  # pyre-ign
 | 
			
		||||
            eikonal_points = torch.empty(
 | 
			
		||||
                n_eik_points,
 | 
			
		||||
                3,
 | 
			
		||||
                # pyre-fixme[6]: For 3rd param expected `Union[None, str, device]`
 | 
			
		||||
                #  but got `Union[device, Tensor, Module]`.
 | 
			
		||||
                device=self._bg_color.device,
 | 
			
		||||
            ).uniform_(-eik_bounding_box, eik_bounding_box)
 | 
			
		||||
@ -205,7 +203,7 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module):  # pyre-ign
 | 
			
		||||
            normals_full.view(-1, 3)[surface_mask] = normals
 | 
			
		||||
            render_full.view(-1, self.render_features_dimensions)[
 | 
			
		||||
                surface_mask
 | 
			
		||||
            ] = self._rgb_network(  # pyre-fixme[29]:
 | 
			
		||||
            ] = self._rgb_network(
 | 
			
		||||
                features,
 | 
			
		||||
                differentiable_surface_points[None],
 | 
			
		||||
                normals,
 | 
			
		||||
 | 
			
		||||
@ -530,11 +530,6 @@ def _get_ray_dir_dot_prods(camera: CamerasBase, pts: torch.Tensor):
 | 
			
		||||
 | 
			
		||||
    # does not produce nans randomly unlike get_camera_center() below
 | 
			
		||||
    cam_centers_rep = -torch.bmm(
 | 
			
		||||
        # pyre-fixme[29]:
 | 
			
		||||
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
 | 
			
		||||
        #  torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
 | 
			
		||||
        #  torch.Tensor, torch.nn.modules.module.Module]` is not a function.
 | 
			
		||||
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch.Tensor.permute)[[N...
 | 
			
		||||
        camera_rep.T[:, None],
 | 
			
		||||
        camera_rep.R.permute(0, 2, 1),
 | 
			
		||||
    ).reshape(-1, *([1] * (pts.ndim - 2)), 3)
 | 
			
		||||
 | 
			
		||||
@ -120,16 +120,12 @@ def corresponding_cameras_alignment(
 | 
			
		||||
 | 
			
		||||
    # create a new cameras object and set the R and T accordingly
 | 
			
		||||
    cameras_src_aligned = cameras_src.clone()
 | 
			
		||||
    # pyre-fixme[6]: For 2nd param expected `Tensor` but got `Union[Tensor, Module]`.
 | 
			
		||||
    cameras_src_aligned.R = torch.bmm(align_t_R.expand_as(cameras_src.R), cameras_src.R)
 | 
			
		||||
    cameras_src_aligned.T = (
 | 
			
		||||
        torch.bmm(
 | 
			
		||||
            align_t_T[:, None].repeat(cameras_src.R.shape[0], 1, 1),
 | 
			
		||||
            # pyre-fixme[6]: For 2nd param expected `Tensor` but got `Union[Tensor,
 | 
			
		||||
            #  Module]`.
 | 
			
		||||
            cameras_src.R,
 | 
			
		||||
        )[:, 0]
 | 
			
		||||
        # pyre-fixme[29]: `Union[BoundMethod[typing.Callable(torch._C._TensorBase.__m...
 | 
			
		||||
        + cameras_src.T * align_t_s
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -177,7 +173,6 @@ def _align_camera_extrinsics(
 | 
			
		||||
        R_A = (U V^T)^T
 | 
			
		||||
        ```
 | 
			
		||||
    """
 | 
			
		||||
    # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Tensor, Module]`.
 | 
			
		||||
    RRcov = torch.bmm(cameras_src.R, cameras_tgt.R.transpose(2, 1)).mean(0)
 | 
			
		||||
    U, _, V = torch.svd(RRcov)
 | 
			
		||||
    align_t_R = V @ U.t()
 | 
			
		||||
@ -207,17 +202,7 @@ def _align_camera_extrinsics(
 | 
			
		||||
        T_A = mean(B) - mean(A) * s_A
 | 
			
		||||
        ```
 | 
			
		||||
    """
 | 
			
		||||
    # pyre-fixme[29]:
 | 
			
		||||
    #  `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
 | 
			
		||||
    #  torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
 | 
			
		||||
    #  torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
    # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Tensor, Module]`.
 | 
			
		||||
    A = torch.bmm(cameras_src.R, cameras_src.T[:, :, None])[:, :, 0]
 | 
			
		||||
    # pyre-fixme[29]:
 | 
			
		||||
    #  `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
 | 
			
		||||
    #  torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
 | 
			
		||||
    #  torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
    # pyre-fixme[6]: For 1st param expected `Tensor` but got `Union[Tensor, Module]`.
 | 
			
		||||
    B = torch.bmm(cameras_src.R, cameras_tgt.T[:, :, None])[:, :, 0]
 | 
			
		||||
    Amu = A.mean(0, keepdim=True)
 | 
			
		||||
    Bmu = B.mean(0, keepdim=True)
 | 
			
		||||
 | 
			
		||||
@ -63,8 +63,8 @@ def _opencv_from_cameras_projection(
 | 
			
		||||
    cameras: PerspectiveCameras,
 | 
			
		||||
    image_size: torch.Tensor,
 | 
			
		||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
 | 
			
		||||
    R_pytorch3d = cameras.R.clone()  # pyre-ignore
 | 
			
		||||
    T_pytorch3d = cameras.T.clone()  # pyre-ignore
 | 
			
		||||
    R_pytorch3d = cameras.R.clone()
 | 
			
		||||
    T_pytorch3d = cameras.T.clone()
 | 
			
		||||
    focal_pytorch3d = cameras.focal_length
 | 
			
		||||
    p0_pytorch3d = cameras.principal_point
 | 
			
		||||
    T_pytorch3d[:, :2] *= -1
 | 
			
		||||
 | 
			
		||||
@ -201,8 +201,8 @@ class CamerasBase(TensorProperties):
 | 
			
		||||
        """
 | 
			
		||||
        R: torch.Tensor = kwargs.get("R", self.R)
 | 
			
		||||
        T: torch.Tensor = kwargs.get("T", self.T)
 | 
			
		||||
        self.R = R  # pyre-ignore[16]
 | 
			
		||||
        self.T = T  # pyre-ignore[16]
 | 
			
		||||
        self.R = R
 | 
			
		||||
        self.T = T
 | 
			
		||||
        world_to_view_transform = get_world_to_view_transform(R=R, T=T)
 | 
			
		||||
        return world_to_view_transform
 | 
			
		||||
 | 
			
		||||
@ -226,8 +226,8 @@ class CamerasBase(TensorProperties):
 | 
			
		||||
            a Transform3d object which represents a batch of transforms
 | 
			
		||||
            of shape (N, 3, 3)
 | 
			
		||||
        """
 | 
			
		||||
        self.R: torch.Tensor = kwargs.get("R", self.R)  # pyre-ignore[16]
 | 
			
		||||
        self.T: torch.Tensor = kwargs.get("T", self.T)  # pyre-ignore[16]
 | 
			
		||||
        self.R: torch.Tensor = kwargs.get("R", self.R)
 | 
			
		||||
        self.T: torch.Tensor = kwargs.get("T", self.T)
 | 
			
		||||
        world_to_view_transform = self.get_world_to_view_transform(R=self.R, T=self.T)
 | 
			
		||||
        view_to_proj_transform = self.get_projection_transform(**kwargs)
 | 
			
		||||
        return world_to_view_transform.compose(view_to_proj_transform)
 | 
			
		||||
 | 
			
		||||
@ -264,9 +264,7 @@ class PointLights(TensorProperties):
 | 
			
		||||
        shape (P, 3) or (N, H, W, K, 3).
 | 
			
		||||
        """
 | 
			
		||||
        if self.location.ndim == points.ndim:
 | 
			
		||||
            # pyre-fixme[7]
 | 
			
		||||
            return self.location
 | 
			
		||||
        # pyre-fixme[29]
 | 
			
		||||
        return self.location[:, None, None, None, :]
 | 
			
		||||
 | 
			
		||||
    def diffuse(self, normals, points) -> torch.Tensor:
 | 
			
		||||
 | 
			
		||||
@ -585,21 +585,9 @@ def _add_struct_from_batch(
 | 
			
		||||
    if isinstance(batched_struct, CamerasBase):
 | 
			
		||||
        # we can't index directly into camera batches
 | 
			
		||||
        R, T = batched_struct.R, batched_struct.T
 | 
			
		||||
        # pyre-fixme[6]: Expected `Sized` for 1st param but got `Union[torch.Tensor,
 | 
			
		||||
        #  torch.nn.Module]`.
 | 
			
		||||
        r_idx = min(scene_num, len(R) - 1)
 | 
			
		||||
        # pyre-fixme[6]: Expected `Sized` for 1st param but got `Union[torch.Tensor,
 | 
			
		||||
        #  torch.nn.Module]`.
 | 
			
		||||
        t_idx = min(scene_num, len(T) - 1)
 | 
			
		||||
        # pyre-fixme[29]:
 | 
			
		||||
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
 | 
			
		||||
        #  torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
 | 
			
		||||
        #  torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        R = R[r_idx].unsqueeze(0)
 | 
			
		||||
        # pyre-fixme[29]:
 | 
			
		||||
        #  `Union[BoundMethod[typing.Callable(torch.Tensor.__getitem__)[[Named(self,
 | 
			
		||||
        #  torch.Tensor), Named(item, typing.Any)], typing.Any], torch.Tensor],
 | 
			
		||||
        #  torch.Tensor, torch.nn.Module]` is not a function.
 | 
			
		||||
        T = T[t_idx].unsqueeze(0)
 | 
			
		||||
        struct = CamerasBase(device=batched_struct.device, R=R, T=T)
 | 
			
		||||
    elif _is_ray_bundle(batched_struct) and not _is_heterogeneous_ray_bundle(
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user