mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-04-30 02:28:56 +08:00
arbitrary shape input to voxel_grids
Summary: Add the ability to process arbitrary point shapes `[n_grids, ..., 3]` instead of only `[n_grids, n_points, 3]`. Reviewed By: bottler Differential Revision: D39574373 fbshipit-source-id: 0a9ecafe9ea58cd8f909644de43a1185ecf934f4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
6ae6ff9cf7
commit
db3c12abfb
@@ -19,6 +19,7 @@ from dataclasses import dataclass
|
||||
from typing import ClassVar, Dict, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
Configurable,
|
||||
registry,
|
||||
@@ -81,12 +82,12 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
|
||||
Arguments:
|
||||
points (torch.Tensor): tensor of points that you want to query
|
||||
of a form (n_grids, n_points, 3)
|
||||
of a form (n_grids, ..., 3)
|
||||
grid_values: an object of type Class.values_type which has tensors as
|
||||
members which have shapes derived from the get_shapes() method
|
||||
locator: a VolumeLocator object
|
||||
Returns:
|
||||
torch.Tensor: shape (n_grids, n_points, n_features)
|
||||
torch.Tensor: shape (n_grids, ..., n_features)
|
||||
"""
|
||||
points_local = locator.world_to_local_coords(points)
|
||||
return self.evaluate_local(points_local, grid_values)
|
||||
@@ -100,11 +101,11 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
|
||||
Arguments:
|
||||
points (torch.Tensor): tensor of points that you want to query
|
||||
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
|
||||
of a form (n_grids, ..., 3), in a normalized form (coordinates are in [-1, 1])
|
||||
grid_values: an object of type VMFactorizedVoxelGrid.values_type which has tensors
|
||||
as members which have shapes derived from the get_shapes() method
|
||||
Returns:
|
||||
torch.Tensor: shape (n_grids, n_points, n_features)
|
||||
torch.Tensor: shape (n_grids, ..., n_features)
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -117,11 +118,28 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
|
||||
a dictionary of needed shapes. To use the evaluate_local and evaluate_world methods
|
||||
replace the shapes in the dictionary with tensors of those shapes and add the
|
||||
first 'batch' dimension. If the required shape is (a, b) and you want to
|
||||
have g grids than the tensor that replaces the shape should have the
|
||||
have g grids then the tensor that replaces the shape should have the
|
||||
shape (g, a, b).
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@staticmethod
|
||||
def get_output_dim(args: DictConfig) -> int:
|
||||
"""
|
||||
Given all the arguments of the grid's __init__, returns output's last dimension length.
|
||||
|
||||
In particular, if self.evaluate_world or self.evaluate_local
|
||||
are called with `points` of shape (n_grids, n_points, 3),
|
||||
their output will be of shape
|
||||
(n_grids, n_points, grid.get_output_dim()).
|
||||
|
||||
Args:
|
||||
args: DictConfig which would be used to initialize the object
|
||||
Returns:
|
||||
output's last dimension length
|
||||
"""
|
||||
return args["n_features"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class FullResolutionVoxelGridValues(VoxelGridValuesBase):
|
||||
@@ -149,19 +167,23 @@ class FullResolutionVoxelGrid(VoxelGridBase):
|
||||
|
||||
Arguments:
|
||||
points (torch.Tensor): tensor of points that you want to query
|
||||
of a form (n_points, 3), in a normalized form (coordinates are in [-1, 1])
|
||||
of a form (..., 3), in a normalized form (coordinates are in [-1, 1])
|
||||
grid_values: an object of type values_type which has tensors as
|
||||
members which have shapes derived from the get_shapes() method
|
||||
Returns:
|
||||
torch.Tensor: shape (n_grids, n_points, n_features)
|
||||
torch.Tensor: shape (n_grids, ..., n_features)
|
||||
"""
|
||||
return interpolate_volume(
|
||||
# (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
|
||||
recorded_shape = points.shape
|
||||
points = points.view(points.shape[0], -1, points.shape[-1])
|
||||
interpolated = interpolate_volume(
|
||||
points,
|
||||
grid_values.voxel_grid,
|
||||
align_corners=self.align_corners,
|
||||
padding_mode=self.padding,
|
||||
mode=self.mode,
|
||||
)
|
||||
return interpolated.view(*recorded_shape[:-1], -1)
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple]:
|
||||
return {"voxel_grid": (self.n_features, *self.resolution)}
|
||||
@@ -202,10 +224,11 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
||||
Members:
|
||||
n_components: number of vector triplets, higher number gives better approximation.
|
||||
matrix_reduction: how to transform components. If matrix_reduction=True result
|
||||
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by the
|
||||
basis_matrix of shape (n_grids, n_components, n_features). If
|
||||
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
|
||||
is summed along the rows to get (n_grids, n_points, 1).
|
||||
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
|
||||
by the basis_matrix of shape (n_grids, n_components, n_features). If
|
||||
matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
|
||||
is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
|
||||
to return to starting shape (n_grids, ..., 1).
|
||||
"""
|
||||
|
||||
# the type of grid_values argument needed to run evaluate_local()
|
||||
@@ -219,8 +242,8 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
||||
def evaluate_local(
|
||||
self, points: torch.Tensor, grid_values: CPFactorizedVoxelGridValues
|
||||
) -> torch.Tensor:
|
||||
def factor(i):
|
||||
axis = ["x", "y", "z"][i]
|
||||
def factor(axis):
|
||||
i = {"x": 0, "y": 1, "z": 2}[axis]
|
||||
index = points[..., i, None]
|
||||
vector = getattr(grid_values, "vector_components_" + axis)
|
||||
return interpolate_line(
|
||||
@@ -231,17 +254,25 @@ class CPFactorizedVoxelGrid(VoxelGridBase):
|
||||
mode=self.mode,
|
||||
)
|
||||
|
||||
# (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
|
||||
recorded_shape = points.shape
|
||||
points = points.view(points.shape[0], -1, points.shape[-1])
|
||||
|
||||
# collect points from all the vectors and multipy them out
|
||||
mult = factor(0) * factor(1) * factor(2)
|
||||
mult = factor("x") * factor("y") * factor("z")
|
||||
|
||||
# reduce the result from
|
||||
# (n_grids, n_points, n_components) to (n_grids, n_points, n_features)
|
||||
# (n_grids, n_points_total, n_components) to (n_grids, n_points_total, n_features)
|
||||
if grid_values.basis_matrix is not None:
|
||||
# (n_grids, n_points, n_features) =
|
||||
# (n_grids, n_points, total_n_components) x (total_n_components, n_features)
|
||||
return torch.bmm(mult, grid_values.basis_matrix)
|
||||
|
||||
return mult.sum(axis=-1, keepdim=True)
|
||||
# (n_grids, n_points_total, n_features) =
|
||||
# (n_grids, n_points_total, total_n_components) @
|
||||
# (n_grids, total_n_components, n_features)
|
||||
result = torch.bmm(mult, grid_values.basis_matrix)
|
||||
else:
|
||||
# (n_grids, n_points_total, 1) from (n_grids, n_points_total, n_features)
|
||||
result = mult.sum(axis=-1, keepdim=True)
|
||||
# (n_grids, ..., n_features)
|
||||
return result.view(*recorded_shape[:-1], -1)
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple[int, int]]:
|
||||
if self.matrix_reduction is False and self.n_features != 1:
|
||||
@@ -308,10 +339,11 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
coordinate of a form (n_xy_planes, n_yz_planes, n_xz_planes). You must specify
|
||||
either n_components or distribution_of_components, you cannot specify both.
|
||||
matrix_reduction: how to transform components. If matrix_reduction=True result
|
||||
matrix of shape (n_grids, n_points, n_components) is batch matrix multiplied by
|
||||
the basis_matrix of shape (n_grids, n_components, n_features). If
|
||||
matrix_reduction=False, the result tensor of (n_grids, n_points, n_components)
|
||||
is summed along the rows to get (n_grids, n_points, 1).
|
||||
matrix of shape (n_grids, n_points_total, n_components) is batch matrix multiplied
|
||||
by the basis_matrix of shape (n_grids, n_components, n_features). If
|
||||
matrix_reduction=False, the result tensor of (n_grids, n_points_total, n_components)
|
||||
is summed along the rows to get (n_grids, n_points_total, 1), which is then viewed
|
||||
to return to starting shape (n_grids, ..., 1).
|
||||
"""
|
||||
|
||||
# the type of grid_values argument needed to run evaluate_local()
|
||||
@@ -326,6 +358,10 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
def evaluate_local(
|
||||
self, points: torch.Tensor, grid_values: VMFactorizedVoxelGridValues
|
||||
) -> torch.Tensor:
|
||||
# (n_grids, n_points_total, n_features) from (n_grids, ..., n_features)
|
||||
recorded_shape = points.shape
|
||||
points = points.view(points.shape[0], -1, points.shape[-1])
|
||||
|
||||
# collect points from matrices and vectors and multiply them
|
||||
a = interpolate_plane(
|
||||
points[..., :2],
|
||||
@@ -375,9 +411,13 @@ class VMFactorizedVoxelGrid(VoxelGridBase):
|
||||
# (n_grids, n_points, n_features) =
|
||||
# (n_grids, n_points, total_n_components) x
|
||||
# (n_grids, total_n_components, n_features)
|
||||
return torch.bmm(feats, grid_values.basis_matrix)
|
||||
# pyre-ignore[28]
|
||||
return feats.sum(axis=-1, keepdim=True)
|
||||
result = torch.bmm(feats, grid_values.basis_matrix)
|
||||
else:
|
||||
# pyre-ignore[28]
|
||||
# (n_grids, n_points, 1) from (n_grids, n_points, n_features)
|
||||
result = feats.sum(axis=-1, keepdim=True)
|
||||
# (n_grids, ..., n_features)
|
||||
return result.view(*recorded_shape[:-1], -1)
|
||||
|
||||
def get_shapes(self) -> Dict[str, Tuple]:
|
||||
if self.matrix_reduction is False and self.n_features != 1:
|
||||
@@ -494,9 +534,9 @@ class VoxelGridModule(Configurable, torch.nn.Module):
|
||||
|
||||
Args:
|
||||
points (torch.Tensor): tensor of points that you want to query
|
||||
of a form (n_points, 3)
|
||||
of a form (..., 3)
|
||||
Returns:
|
||||
torch.Tensor of shape (n_points, n_features)
|
||||
torch.Tensor of shape (..., n_features)
|
||||
"""
|
||||
locator = VolumeLocator(
|
||||
batch_size=1,
|
||||
|
||||
Reference in New Issue
Block a user