mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-04-30 02:28:56 +08:00
voxel grids with interpolation
Summary: Added voxel grid classes from TensoRF, both in their factorized (CP and VM) and full form. Reviewed By: bottler Differential Revision: D38465419 fbshipit-source-id: 8b306338af58dc50ef47a682616022a0512c0047
This commit is contained in:
committed by
Facebook GitHub Bot
parent
af799facdd
commit
edee25a1e5
@@ -7,6 +7,8 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.common.compat import prod
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
@@ -88,3 +90,98 @@ def create_embeddings_for_implicit_function(
|
||||
embeds = broadcast_global_code(embeds, global_code)
|
||||
|
||||
return embeds
|
||||
|
||||
|
||||
def interpolate_line(
|
||||
points: torch.Tensor,
|
||||
source: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Linearly interpolates values of source grids. The first dimension of points represents
|
||||
number of points and the second coordinate, for example ([[x0], [x1], ...]). The first
|
||||
dimension of argument source represents feature and ones after that the spatial
|
||||
dimension.
|
||||
|
||||
Arguments:
|
||||
points: shape (n_grids, n_points, 1),
|
||||
source: tensor of shape (n_grids, features, width),
|
||||
Returns:
|
||||
interpolated tensor of shape (n_grids, n_points, features)
|
||||
"""
|
||||
# To enable sampling of the source using the torch.functional.grid_sample
|
||||
# points need to have 2 coordinates.
|
||||
expansion = points.new_zeros(points.shape)
|
||||
points = torch.cat((points, expansion), dim=-1)
|
||||
|
||||
source = source[:, :, None, :]
|
||||
points = points[:, :, None, :]
|
||||
|
||||
out = F.grid_sample(
|
||||
grid=points,
|
||||
input=source,
|
||||
**kwargs,
|
||||
)
|
||||
return out[:, :, :, 0].permute(0, 2, 1)
|
||||
|
||||
|
||||
def interpolate_plane(
|
||||
points: torch.Tensor,
|
||||
source: torch.Tensor,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Bilinearly interpolates values of source grids. The first dimension of points represents
|
||||
number of points and the second coordinates, for example ([[x0, y0], [x1, y1], ...]).
|
||||
The first dimension of argument source represents feature and ones after that the
|
||||
spatial dimension.
|
||||
|
||||
Arguments:
|
||||
points: shape (n_grids, n_points, 2),
|
||||
source: tensor of shape (n_grids, features, width, height),
|
||||
Returns:
|
||||
interpolated tensor of shape (n_grids, n_points, features)
|
||||
"""
|
||||
# permuting because torch.nn.functional.grid_sample works with
|
||||
# (features, height, width) and not
|
||||
# (features, width, height)
|
||||
source = source.permute(0, 1, 3, 2)
|
||||
points = points[:, :, None, :]
|
||||
|
||||
out = F.grid_sample(
|
||||
grid=points,
|
||||
input=source,
|
||||
**kwargs,
|
||||
)
|
||||
return out[:, :, :, 0].permute(0, 2, 1)
|
||||
|
||||
|
||||
def interpolate_volume(
|
||||
points: torch.Tensor, source: torch.Tensor, **kwargs
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Interpolates values of source grids. The first dimension of points represents
|
||||
number of points and the second coordinates, for example
|
||||
[[x0, y0, z0], [x1, y1, z1], ...]. The first dimension of a source represents features
|
||||
and ones after that the spatial dimension.
|
||||
|
||||
Arguments:
|
||||
points: shape (n_grids, n_points, 3),
|
||||
source: tensor of shape (n_grids, features, width, height, depth),
|
||||
Returns:
|
||||
interpolated tensor of shape (n_grids, n_points, features)
|
||||
"""
|
||||
if "mode" in kwargs and kwargs["mode"] == "trilinear":
|
||||
kwargs = kwargs.copy()
|
||||
kwargs["mode"] = "bilinear"
|
||||
# permuting because torch.nn.functional.grid_sample works with
|
||||
# (features, depth, height, width) and not (features, width, height, depth)
|
||||
source = source.permute(0, 1, 4, 3, 2)
|
||||
grid = points[:, :, None, None, :]
|
||||
|
||||
out = F.grid_sample(
|
||||
grid=grid,
|
||||
input=source,
|
||||
**kwargs,
|
||||
)
|
||||
return out[:, :, :, 0, 0].permute(0, 2, 1)
|
||||
|
||||
428
pytorch3d/implicitron/models/implicit_function/voxel_grid.py
Normal file
428
pytorch3d/implicitron/models/implicit_function/voxel_grid.py
Normal file
@@ -0,0 +1,428 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
This file contains classes that implement Voxel grids, both in their full resolution
|
||||
as in the factorized form. There are two factorized forms implemented, Tensor rank decomposition
|
||||
or CANDECOMP/PARAFAC (here CP) and Vector Matrix (here VM) factorization from the
|
||||
https://arxiv.org/abs/2203.09517.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Dict, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
|
||||
from pytorch3d.structures.volumes import VolumeLocator
|
||||
|
||||
from .utils import interpolate_line, interpolate_plane, interpolate_volume
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoxelGridValuesBase:
|
||||
pass
|
||||
|
||||
|
||||
class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
"""
|
||||
Base class for all the voxel grid variants whith added trilinear interpolation between
|
||||
voxels (for example if voxel (0.333, 1, 3) is queried that would return the result
|
||||
2/3*voxel[0, 1, 3] + 1/3*voxel[1, 1, 3])
|
||||
|
||||
Internally voxel grids are indexed by (features, x, y, z). If queried the point is not
|
||||
inside the voxel grid the vector that will be returned is determined by padding.
|
||||
|
||||
Members:
|
||||
align_corners: parameter used in torch.functional.grid_sample. For details go to
|
||||
https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html by
|
||||
default is True
|
||||
padding: padding mode for outside grid values 'zeros' | 'border' | 'reflection'.
|
||||
Default is 'zeros'
|
||||
mode: interpolation mode to calculate output values :
|
||||
'bilinear' | 'nearest' | 'bicubic' | 'trilinear'.
|
||||
Default: 'bilinear' Note: mode='bicubic' supports only FullResolutionVoxelGrid.
|
||||
When mode='bilinear' and the input is 5-D, the interpolation mode used internally
|
||||
will actually be trilinear.
|
||||
n_features: number of dimensions of base feature vector. Determines how many features
|
||||
the grid returns.
|
||||
resolution: 3-tuple containing x, y, z grid sizes corresponding to each axis.
|
||||
"""
|
||||
|
||||
align_corners: bool = True
|
||||
padding: str = "zeros"
|
||||
mode: str = "bilinear"
|
||||
n_features: int = 1
|
||||
resolution: Tuple[int, int, int] = (64, 64, 64)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__()
|
||||
|
||||
def evaluate_world(
|
||||
self,
|
||||
points: torch.Tensor,
|
||||
grid_values: VoxelGridValuesBase,
|
||||
locator: VolumeLocator,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Evaluates the voxel grid at points in the world coordinate frame.
|
||||
The interpolation type is determined by the `mode` member.
|
||||
|
||||
Arguments:
|
||||
points (torch.Tensor): tensor of points that you want to query
|
||||
of a form (n_grids, n_points, 3)
|
||||
grid_values: an object of type Class.values_type which has tensors as
|
||||
members which have shapes derived from the get_shapes() method
|
||||
locator: a VolumeLocator object
|
||||
Returns:
|
||||
torch.Tensor: shape (n_grids, n_points, n_features)
|
||||
"""
|
||||
points_local = locator.world_to_local_coords(points)
|
||||
# pyre-ignore[29]
|
||||
return self.evaluate_local(points_local, grid_values)
|
||||
|
||||
def evaluate_local(
|
||||
self, points: torch.Tensor, grid_values: VoxelGridValuesBase
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Evaluates the voxel grid at points in the local coordinate frame,
|
||||
The interpolation type is determined by the `mode` member.
|
||||
|
||||
Arguments:
|
||||
points (torch.Tensor): tensor of points that you want to query
|
||||
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
|
||||
grid_values: an object of type VMFactorizedVoxelGrid.values_type which has tensors
|
||||
as members which have shapes derived from the get_shapes() method
|
||||
Returns:
|
||||
torch.Tensor: shape (n_grids, n_points, n_features)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple]:
|
||||
"""
|
||||
Using parameters from the __init__ method, this method returns the
|
||||
shapes of individual tensors needed to run the evaluate method.
|
||||
|
||||
Returns:
|
||||
a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
|
||||
replace the shapes in the dictionary with tensors of those shapes and add the
|
||||
first 'batch' dimension. If the required shape is (a, b) and you want to
|
||||
have g grids than the tensor that replaces the shape should have the
|
||||
shape (g, a, b).
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass
|
||||
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
|
||||
voxel_grid: torch.Tensor
|
||||
|
||||
|
||||
@registry.register
|
||||
class FullResolutionVoxelGrid(VoxelGridBase):
|
||||
"""
|
||||
Full resolution voxel grid equivalent to 4D tensor where shape is
|
||||
(features, width, height, depth) with linear interpolation between voxels.
|
||||
"""
|
||||
|
||||
# the type of grid_values argument needed to run evaluate_local()
|
||||
values_type: ClassVar[Type[VoxelGridValuesBase]] = FullResolutionVoxelGridValues
|
||||
|
||||
def evaluate_local(
|
||||
self, points: torch.Tensor, grid_values: FullResolutionVoxelGridValues
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Evaluates the voxel grid at points in the local coordinate frame,
|
||||
The interpolation type is determined by the `mode` member.
|
||||
|
||||
Arguments:
|
||||
points (torch.Tensor): tensor of points that you want to query
|
||||
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
|
||||
grid_values: an object of type values_type which has tensors as
|
||||
members which have shapes derived from the get_shapes() method
|
||||
Returns:
|
||||
torch.Tensor: shape (n_grids, n_points, n_features)
|
||||
"""
|
||||
return interpolate_volume(
|
||||
points,
|
||||
grid_values.voxel_grid,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding,
|
||||
mode=self.mode,
|
||||
)
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple]:
|
||||
return {"voxel_grid": (self.n_features, *self.resolution)}
|
||||
|
||||
|
||||
@dataclass
|
||||
class CPFactorizedVoxelGridValues(VoxelGridValuesBase):
|
||||
vector_components_x: torch.Tensor
|
||||
vector_components_y: torch.Tensor
|
||||
vector_components_z: torch.Tensor
|
||||
basis_matrix: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@registry.register
|
||||
class CPFactorizedVoxelGrid(VoxelGridBase):
|
||||
"""
|
||||
Canonical Polyadic (CP/CANDECOMP/PARAFAC) Factorization factorizes the 3d grid into three
|
||||
vectors (x, y, z). For n_components=n, the 3d grid is a sum of the two outer products
|
||||
(call it ⊗) of each vector type (x, y, z):
|
||||
|
||||
3d_grid = x0 ⊗ y0 ⊗ z0 + x1 ⊗ y1 ⊗ z1 + ... + xn ⊗ yn ⊗ zn
|
||||
|
||||
These tensors are passed in a object of CPFactorizedVoxelGridValues (here obj) as
|
||||
obj.vector_components_x, obj.vector_components_y, obj.vector_components_z. Their shapes are
|
||||
`(n_components, r)` where `r` is the relevant resolution.
|
||||
|
||||
Each element of this sum has an extra dimension, which gets matrix-multiplied by an
|
||||
appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
|
||||
brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
|
||||
of different components are summed together to create (n_grids, n_components, 1) tensor.
|
||||
With some notation abuse, ignoring the interpolation operation, simplifying and denoting
|
||||
n_features as F, n_components as C and n_grids as G:
|
||||
|
||||
3d_grid = (x ⊗ y ⊗ z) @ basis # GWHDC x GCF -> GWHDF
|
||||
|
||||
The basis feature vectors are passed as obj.basis_matrix.
|
||||
|
||||
Members:
|
||||
n_components: number of vector triplets, higher number gives better approximation.
|
||||
matrix_reduction: how to transform components. If matrix_reduction=True result
|
||||
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by the
|
||||
basis_matrix of shape (n_grids, n_components, n_features). If
|
||||
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
|
||||
is summed along the rows to get (n_grids, n_points, 1).
|
||||
"""
|
||||
|
||||
# the type of grid_values argument needed to run evaluate_local()
|
||||
values_type: ClassVar[Type[VoxelGridValuesBase]] = CPFactorizedVoxelGridValues
|
||||
|
||||
n_components: int = 24
|
||||
matrix_reduction: bool = True
|
||||
|
||||
def evaluate_local(
|
||||
self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues
|
||||
) -> torch.Tensor:
|
||||
def factor(i):
|
||||
axis = ["x", "y", "z"][i]
|
||||
index = points[..., i, None]
|
||||
vector = getattr(grid_values, "vector_components_" + axis)
|
||||
return interpolate_line(
|
||||
index,
|
||||
vector,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding,
|
||||
mode=self.mode,
|
||||
)
|
||||
|
||||
# collect points from all the vectors and multipy them out
|
||||
mult = factor(0) * factor(1) * factor(2)
|
||||
|
||||
# reduce the result from
|
||||
# (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
|
||||
if grid_values.basis_matrix is not None:
|
||||
# (n_grids, n_points, n_features) =
|
||||
# (n_grids, n_points, total_n_components) x (total_n_components, n_features)
|
||||
return torch.bmm(mult, grid_values.basis_matrix)
|
||||
|
||||
return mult.sum(axis=-1, keepdim=True)
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple[int, int]]:
|
||||
if self.matrix_reduction is False and self.n_features != 1:
|
||||
raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")
|
||||
|
||||
shape_dict = {
|
||||
"vector_components_x": (self.n_components, self.resolution[0]),
|
||||
"vector_components_y": (self.n_components, self.resolution[1]),
|
||||
"vector_components_z": (self.n_components, self.resolution[2]),
|
||||
}
|
||||
if self.matrix_reduction:
|
||||
shape_dict["basis_matrix"] = (self.n_components, self.n_features)
|
||||
return shape_dict
|
||||
|
||||
|
||||
@dataclass
|
||||
class VMFactorizedVoxelGridValues(VoxelGridValuesBase):
|
||||
vector_components_x: torch.Tensor
|
||||
vector_components_y: torch.Tensor
|
||||
vector_components_z: torch.Tensor
|
||||
matrix_components_xy: torch.Tensor
|
||||
matrix_components_yz: torch.Tensor
|
||||
matrix_components_xz: torch.Tensor
|
||||
basis_matrix: Optional[torch.Tensor] = None
|
||||
|
||||
|
||||
@registry.register
|
||||
class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
"""
|
||||
Implementation of Vector-Matrix Factorization of a tensor from
|
||||
https://arxiv.org/abs/2203.09517.
|
||||
|
||||
Vector-Matrix Factorization factorizes the 3d grid into three matrices
|
||||
(xy, xz, yz) and three vectors (x, y, z). For n_components=1, the 3d grid
|
||||
is a sum of the outer products (call it ⊗) of each matrix with its
|
||||
complementary vector:
|
||||
|
||||
3d_grid = xy ⊗ z + xz ⊗ y + yz ⊗ x.
|
||||
|
||||
These tensors are passed in a VMFactorizedVoxelGridValues object (here obj)
|
||||
as obj.matrix_components_xy, obj.matrix_components_xy, obj.vector_components_y, etc.
|
||||
|
||||
Their shapes are `(n_grids, n_components, r0, r1)` for matrix_components and
|
||||
(n_grids, n_components, r2)` for vector_componenets. Each of `r0, r1 and r2` coresponds
|
||||
to one resolution in (width, height and depth).
|
||||
|
||||
Each element of this sum has an extra dimension, which gets matrix-multiplied by an
|
||||
appropriate "basis matrix" of shape (n_grids, n_components, n_features). This multiplication
|
||||
brings us to the desired "n_features" dimensionality. If matrix_reduction=False the elements
|
||||
of different components are summed together to create (n_grids, n_components, 1) tensor.
|
||||
With some notation abuse, ignoring the interpolation operation, simplifying and denoting
|
||||
n_features as F, n_components as C (which can differ for each dimension) and n_grids as G:
|
||||
|
||||
3d_grid = concat((xy ⊗ z), (xz ⊗ y).permute(0, 2, 1),
|
||||
(yz ⊗ x).permute(2, 0, 1)) @ basis_matrix # GWHDC x GCF -> GWHDF
|
||||
|
||||
Members:
|
||||
n_components: total number of matrix vector pairs, this must be divisible by 3. Set
|
||||
this if you want to have equal representational power in all 3 directions. You
|
||||
must specify either n_components or distribution_of_components, you cannot
|
||||
specify both.
|
||||
distribution_of_components: if you do not want equal representational power in
|
||||
all 3 directions specify a tuple of numbers of matrix_vector pairs for each
|
||||
coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
|
||||
either n_components or distribution_of_components, you cannot specify both.
|
||||
matrix_reduction: how to transform components. If matrix_reduction=True result
|
||||
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by
|
||||
the basis_matrix of shape (n_grids, n_components, n_features). If
|
||||
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
|
||||
is summed along the rows to get (n_grids, n_points, 1).
|
||||
"""
|
||||
|
||||
# the type of grid_values argument needed to run evaluate_local()
|
||||
values_type: ClassVar[Type[VoxelGridValuesBase]] = VMFactorizedVoxelGridValues
|
||||
|
||||
n_components: Optional[int] = None
|
||||
distribution_of_components: Optional[Tuple[int, int, int]] = None
|
||||
matrix_reduction: bool = True
|
||||
|
||||
def evaluate_local(
|
||||
self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues
|
||||
) -> torch.Tensor:
|
||||
# collect points from matrices and vectors and multiply them
|
||||
a = interpolate_plane(
|
||||
points[..., :2],
|
||||
grid_values.matrix_components_xy,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding,
|
||||
mode=self.mode,
|
||||
) * interpolate_line(
|
||||
points[..., 2:],
|
||||
grid_values.vector_components_z,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding,
|
||||
mode=self.mode,
|
||||
)
|
||||
b = interpolate_plane(
|
||||
points[..., [0, 2]],
|
||||
grid_values.matrix_components_xz,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding,
|
||||
mode=self.mode,
|
||||
) * interpolate_line(
|
||||
points[..., 1:2],
|
||||
grid_values.vector_components_y,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding,
|
||||
mode=self.mode,
|
||||
)
|
||||
c = interpolate_plane(
|
||||
points[..., 1:],
|
||||
grid_values.matrix_components_yz,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding,
|
||||
mode=self.mode,
|
||||
) * interpolate_line(
|
||||
points[..., :1],
|
||||
grid_values.vector_components_x,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding,
|
||||
mode=self.mode,
|
||||
)
|
||||
# pyre-ignore[28]
|
||||
feats = torch.cat((a, b, c), axis=-1)
|
||||
|
||||
# reduce the result from
|
||||
# (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
|
||||
if grid_values.basis_matrix is not None:
|
||||
# (n_grids, n_points, n_features) =
|
||||
# (n_grids, n_points, total_n_components) x
|
||||
# (n_grids, total_n_components, n_features)
|
||||
return torch.bmm(feats, grid_values.basis_matrix)
|
||||
# pyre-ignore[28]
|
||||
return feats.sum(axis=-1, keepdim=True)
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple]:
|
||||
if self.matrix_reduction is False and self.n_features != 1:
|
||||
raise ValueError("Cannot set matrix_reduction=False and n_features to != 1")
|
||||
if self.distribution_of_components is None and self.n_components is None:
|
||||
raise ValueError(
|
||||
"You need to provide n_components or distribution_of_components"
|
||||
)
|
||||
if (
|
||||
self.distribution_of_components is not None
|
||||
and self.n_components is not None
|
||||
):
|
||||
raise ValueError(
|
||||
"You cannot define n_components and distribution_of_components"
|
||||
)
|
||||
# pyre-ignore[58]
|
||||
if self.distribution_of_components is None and self.n_components % 3 != 0:
|
||||
raise ValueError("n_components must be divisible by 3")
|
||||
if self.distribution_of_components is None:
|
||||
# pyre-ignore[58]
|
||||
calculated_distribution_of_components = [
|
||||
self.n_components // 3 for _ in range(3)
|
||||
]
|
||||
else:
|
||||
calculated_distribution_of_components = self.distribution_of_components
|
||||
|
||||
shape_dict = {
|
||||
"vector_components_x": (
|
||||
calculated_distribution_of_components[1],
|
||||
self.resolution[0],
|
||||
),
|
||||
"vector_components_y": (
|
||||
calculated_distribution_of_components[2],
|
||||
self.resolution[1],
|
||||
),
|
||||
"vector_components_z": (
|
||||
calculated_distribution_of_components[0],
|
||||
self.resolution[2],
|
||||
),
|
||||
"matrix_components_xy": (
|
||||
calculated_distribution_of_components[0],
|
||||
self.resolution[0],
|
||||
self.resolution[1],
|
||||
),
|
||||
"matrix_components_yz": (
|
||||
calculated_distribution_of_components[1],
|
||||
self.resolution[1],
|
||||
self.resolution[2],
|
||||
),
|
||||
"matrix_components_xz": (
|
||||
calculated_distribution_of_components[2],
|
||||
self.resolution[0],
|
||||
self.resolution[2],
|
||||
),
|
||||
}
|
||||
if self.matrix_reduction:
|
||||
shape_dict["basis_matrix"] = (
|
||||
sum(calculated_distribution_of_components),
|
||||
self.n_features,
|
||||
)
|
||||
|
||||
return shape_dict
|
||||
Reference in New Issue
Block a user