Initial commit

fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
facebook-github-bot
2020-01-23 11:53:41 -08:00
commit dbf06b504b
211 changed files with 47362 additions and 0 deletions

pytorch3d/ops/__init__.py Normal file

@@ -0,0 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from .cubify import cubify
from .graph_conv import GraphConv
from .nearest_neighbor_points import nn_points_idx
from .sample_points_from_meshes import sample_points_from_meshes
from .subdivide_meshes import SubdivideMeshes
from .vert_align import vert_align
__all__ = [k for k in globals().keys() if not k.startswith("_")]
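The `__all__` computed from globals() exports every public name imported above, so all ops are importable directly from the package. A minimal usage sketch (hypothetical):
from pytorch3d.ops import cubify, GraphConv, sample_points_from_meshes, vert_align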

pytorch3d/ops/cubify.py Normal file

@@ -0,0 +1,208 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn.functional as F
from pytorch3d.structures import Meshes
def unravel_index(idx, dims) -> torch.Tensor:
r"""
Equivalent to np.unravel_index
Args:
idx: A LongTensor whose elements are indices into the
flattened version of an array of dimensions dims.
dims: The shape of the array to be indexed.
Implemented only for dims=(N, H, W, D)
"""
if len(dims) != 4:
raise ValueError("Expects a 4-element list.")
N, H, W, D = dims
n = idx // (H * W * D)
h = (idx - n * H * W * D) // (W * D)
w = (idx - n * H * W * D - h * W * D) // D
d = idx - n * H * W * D - h * W * D - w * D
return torch.stack((n, h, w, d), dim=1)
def ravel_index(idx, dims) -> torch.Tensor:
"""
Computes the linear index into an array of shape dims.
It is the inverse of unravel_index.
Args:
idx: A LongTensor of shape (N, 3). Each row corresponds to indices into an
array of dimensions dims.
dims: The shape of the array to be indexed.
Implemented only for dims=(H, W, D)
"""
if len(dims) != 3:
raise ValueError("Expects a 3-element list")
if idx.shape[1] != 3:
raise ValueError("Expects an index tensor of shape Nx3")
H, W, D = dims
linind = idx[:, 0] * W * D + idx[:, 1] * D + idx[:, 2]
return linind
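# Worked example (hypothetical values): the two functions invert each other.
# unravel_index(torch.tensor([5]), (1, 2, 2, 2)) -> tensor([[0, 1, 0, 1]])
# ravel_index(torch.tensor([[1, 0, 1]]), (2, 2, 2)) -> tensor([5])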
@torch.no_grad()
def cubify(voxels, thresh, device=None) -> Meshes:
r"""
Converts a voxel grid to a mesh by replacing each occupied voxel with a cube
consisting of 12 faces and 8 vertices. Shared vertices are merged, and
internal faces are removed.
Args:
voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
thresh: A scalar threshold. If a voxel occupancy is larger than
thresh, the voxel is considered occupied.
device: The device on which to allocate the output mesh. If None,
defaults to the device of voxels.
Returns:
meshes: A Meshes object of the corresponding meshes.
"""
if device is None:
device = voxels.device
if len(voxels) == 0:
return Meshes(verts=[], faces=[])
N, D, H, W = voxels.size()
# vertices corresponding to a unit cube: 8x3
cube_verts = torch.tensor(
[
[0, 0, 0],
[0, 0, 1],
[0, 1, 0],
[0, 1, 1],
[1, 0, 0],
[1, 0, 1],
[1, 1, 0],
[1, 1, 1],
],
dtype=torch.int64,
device=device,
)
# faces corresponding to a unit cube: 12x3
cube_faces = torch.tensor(
[
[0, 1, 2],
[1, 3, 2], # left face: 0, 1
[2, 3, 6],
[3, 7, 6], # bottom face: 2, 3
[0, 2, 6],
[0, 6, 4], # front face: 4, 5
[0, 5, 1],
[0, 4, 5], # up face: 6, 7
[6, 7, 5],
[6, 5, 4], # right face: 8, 9
[1, 7, 3],
[1, 5, 7], # back face: 10, 11
],
dtype=torch.int64,
device=device,
)
wx = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 1, 2)
wy = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 2, 1)
wz = torch.tensor([0.5, 0.5], device=device).view(1, 1, 2, 1, 1)
voxelt = voxels.ge(thresh).float()
# N x 1 x D x H x W
voxelt = voxelt.view(N, 1, D, H, W)
# N x 1 x (D-1) x (H-1) x (W-1)
voxelt_x = F.conv3d(voxelt, wx).gt(0.5).float()
voxelt_y = F.conv3d(voxelt, wy).gt(0.5).float()
voxelt_z = F.conv3d(voxelt, wz).gt(0.5).float()
# 12 x N x 1 x D x H x W
faces_idx = torch.ones((cube_faces.size(0), N, 1, D, H, W), device=device)
# add left face
faces_idx[0, :, :, :, :, 1:] = 1 - voxelt_x
faces_idx[1, :, :, :, :, 1:] = 1 - voxelt_x
# add bottom face
faces_idx[2, :, :, :, :-1, :] = 1 - voxelt_y
faces_idx[3, :, :, :, :-1, :] = 1 - voxelt_y
# add front face
faces_idx[4, :, :, 1:, :, :] = 1 - voxelt_z
faces_idx[5, :, :, 1:, :, :] = 1 - voxelt_z
# add up face
faces_idx[6, :, :, :, 1:, :] = 1 - voxelt_y
faces_idx[7, :, :, :, 1:, :] = 1 - voxelt_y
# add right face
faces_idx[8, :, :, :, :, :-1] = 1 - voxelt_x
faces_idx[9, :, :, :, :, :-1] = 1 - voxelt_x
# add back face
faces_idx[10, :, :, :-1, :, :] = 1 - voxelt_z
faces_idx[11, :, :, :-1, :, :] = 1 - voxelt_z
faces_idx *= voxelt
# N x H x W x D x 12
faces_idx = faces_idx.permute(1, 2, 4, 5, 3, 0).squeeze(1)
# (NHWD) x 12
faces_idx = faces_idx.contiguous()
faces_idx = faces_idx.view(-1, cube_faces.size(0))
# boolean to linear index
# NF x 2
linind = torch.nonzero(faces_idx)
# NF x 4
nyxz = unravel_index(linind[:, 0], (N, H, W, D))
# NF x 3: faces
faces = torch.index_select(cube_faces, 0, linind[:, 1])
grid_faces = []
for d in range(cube_faces.size(1)):
# NF x 3
xyz = torch.index_select(cube_verts, 0, faces[:, d])
permute_idx = torch.tensor([1, 0, 2], device=device)
yxz = torch.index_select(xyz, 1, permute_idx)
yxz += nyxz[:, 1:]
# NF x 1
temp = ravel_index(yxz, (H + 1, W + 1, D + 1))
grid_faces.append(temp)
# NF x 3
grid_faces = torch.stack(grid_faces, dim=1)
y, x, z = torch.meshgrid(
torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
)
y = y.to(device=device, dtype=torch.float32)
y = y * 2.0 / (H - 1.0) - 1.0
x = x.to(device=device, dtype=torch.float32)
x = x * 2.0 / (W - 1.0) - 1.0
z = z.to(device=device, dtype=torch.float32)
z = z * 2.0 / (D - 1.0) - 1.0
# ((H+1)(W+1)(D+1)) x 3
grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3)
if len(nyxz) == 0:
verts_list = [torch.tensor([], dtype=torch.float32, device=device)] * N
faces_list = [torch.tensor([], dtype=torch.int64, device=device)] * N
return Meshes(verts=verts_list, faces=faces_list)
num_verts = grid_verts.size(0)
grid_faces += nyxz[:, 0].view(-1, 1) * num_verts
idleverts = torch.ones(num_verts * N, dtype=torch.uint8, device=device)
idleverts.scatter_(0, grid_faces.flatten(), 0)
grid_faces -= nyxz[:, 0].view(-1, 1) * num_verts
split_size = torch.bincount(nyxz[:, 0], minlength=N)
faces_list = list(torch.split(grid_faces, split_size.tolist(), 0))
idleverts = idleverts.view(N, num_verts)
idlenum = idleverts.cumsum(1)
verts_list = [
grid_verts.index_select(0, (idleverts[n] == 0).nonzero()[:, 0])
for n in range(N)
]
faces_list = [
nface - idlenum[n][nface] for n, nface in enumerate(faces_list)
]
return Meshes(verts=verts_list, faces=faces_list)
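A minimal sketch of calling cubify (hypothetical values): a 2x2x2 grid with a single occupied voxel should produce one isolated cube.
import torch
from pytorch3d.ops import cubify
voxels = torch.zeros(1, 2, 2, 2)  # (N, D, H, W) occupancy probabilities
voxels[0, 0, 0, 0] = 0.9          # one voxel above the threshold
meshes = cubify(voxels, thresh=0.5)
# Expected: 8 vertices and 12 triangular faces for the single cube.
print(meshes.verts_list()[0].shape, meshes.faces_list()[0].shape)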

pytorch3d/ops/graph_conv.py Normal file

@@ -0,0 +1,174 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from pytorch3d import _C
class GraphConv(nn.Module):
"""A single graph convolution layer."""
def __init__(
self,
input_dim: int,
output_dim: int,
init: str = "normal",
directed: bool = False,
):
"""
Args:
input_dim: Number of input features per vertex.
output_dim: Number of output features per vertex.
init: Weight initialization method. Can be one of ['zero', 'normal'].
directed: Bool indicating if edges in the graph are directed.
"""
super().__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.directed = directed
self.w0 = nn.Linear(input_dim, output_dim)
self.w1 = nn.Linear(input_dim, output_dim)
if init == "normal":
nn.init.normal_(self.w0.weight, mean=0, std=0.01)
nn.init.normal_(self.w1.weight, mean=0, std=0.01)
self.w0.bias.data.zero_()
self.w1.bias.data.zero_()
elif init == "zero":
self.w0.weight.data.zero_()
self.w1.weight.data.zero_()
else:
raise ValueError('Invalid GraphConv initialization "%s"' % init)
def forward(self, verts, edges):
"""
Args:
verts: FloatTensor of shape (V, input_dim) where V is the number of
vertices and input_dim is the number of input features
per vertex. input_dim has to match the input_dim specified
in __init__.
edges: LongTensor of shape (E, 2) where E is the number of edges
where each edge has the indices of the two vertices which
form the edge.
Returns:
out: FloatTensor of shape (V, output_dim) where output_dim is the
number of output features per vertex.
"""
if verts.is_cuda != edges.is_cuda:
raise ValueError(
"verts and edges tensors must be on the same device."
)
if verts.shape[0] == 0:
# empty graph.
return verts.sum() * 0.0
verts_w0 = self.w0(verts) # (V, output_dim)
verts_w1 = self.w1(verts) # (V, output_dim)
if torch.cuda.is_available() and verts.is_cuda and edges.is_cuda:
neighbor_sums = gather_scatter(verts_w1, edges, self.directed)
else:
neighbor_sums = gather_scatter_python(
verts_w1, edges, self.directed
) # (V, output_dim)
# Add neighbor features to each vertex's features.
out = verts_w0 + neighbor_sums
return out
def __repr__(self):
Din, Dout, directed = self.input_dim, self.output_dim, self.directed
return "GraphConv(%d -> %d, directed=%r)" % (Din, Dout, directed)
def gather_scatter_python(input, edges, directed: bool = False):
"""
Python implementation of gather_scatter for aggregating features of
neighbor nodes in a graph.
Given a directed graph v0 -> v1 -> v2, the updated feature for v1 depends
on v2 in order to be consistent with Morris et al. AAAI 2019
(https://arxiv.org/abs/1810.02244). This only affects
directed graphs; for undirected graphs v1 will depend on both v0 and v2,
no matter which way the edges are physically stored.
Args:
input: Tensor of shape (num_vertices, input_dim).
edges: Tensor of edge indices of shape (num_edges, 2).
directed: bool indicating if edges are directed.
Returns:
output: Tensor of same shape as input.
"""
if not (input.dim() == 2):
raise ValueError("input can only have 2 dimensions.")
if not (edges.dim() == 2):
raise ValueError("edges can only have 2 dimensions.")
if not (edges.shape[1] == 2):
raise ValueError("edges must be of shape (num_edges, 2).")
num_vertices, input_feature_dim = input.shape
num_edges = edges.shape[0]
output = torch.zeros_like(input)
idx0 = edges[:, 0].view(num_edges, 1).expand(num_edges, input_feature_dim)
idx1 = edges[:, 1].view(num_edges, 1).expand(num_edges, input_feature_dim)
output = output.scatter_add(0, idx0, input.gather(0, idx1))
if not directed:
output = output.scatter_add(0, idx1, input.gather(0, idx0))
return output
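# Worked example (hypothetical values): with input rows x0, x1 and a single
# edge [[0, 1]], the undirected output is [x1, x0]; with directed=True,
# only output[0] = x1 and output[1] stays zero.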
class GatherScatter(Function):
"""
Torch autograd Function wrapper for gather_scatter C++/CUDA implementations.
"""
@staticmethod
def forward(ctx, input, edges, directed=False):
"""
Args:
ctx: Context object used to calculate gradients.
input: Tensor of shape (num_vertices, input_dim)
edges: Tensor of edge indices of shape (num_edges, 2)
directed: Bool indicating if edges are directed.
Returns:
output: Tensor of same shape as input.
"""
if not (input.dim() == 2):
raise ValueError("input can only have 2 dimensions.")
if not (edges.dim() == 2):
raise ValueError("edges can only have 2 dimensions.")
if not (edges.shape[1] == 2):
raise ValueError("edges must be of shape (num_edges, 2).")
if not (input.dtype == torch.float32):
raise ValueError("input has to be of type torch.float32.")
ctx.directed = directed
input, edges = input.contiguous(), edges.contiguous()
ctx.save_for_backward(edges)
backward = False
output = _C.gather_scatter(input, edges, directed, backward)
return output
@staticmethod
@once_differentiable
def backward(ctx, grad_output):
grad_output = grad_output.contiguous()
edges = ctx.saved_tensors[0]
directed = ctx.directed
backward = True
grad_input = _C.gather_scatter(grad_output, edges, directed, backward)
grad_edges = None
grad_directed = None
return grad_input, grad_edges, grad_directed
gather_scatter = GatherScatter.apply
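A short sketch of GraphConv on a triangle graph (hypothetical sizes; the C++/CUDA gather_scatter path is used automatically for CUDA tensors, otherwise the Python fallback above):
import torch
from pytorch3d.ops import GraphConv
verts = torch.randn(3, 16)                      # 3 vertices, 16 features each
edges = torch.tensor([[0, 1], [1, 2], [2, 0]])  # undirected triangle
conv = GraphConv(input_dim=16, output_dim=32)
out = conv(verts, edges)                        # (3, 32): w0(v_i) + sum of w1 over neighbors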

pytorch3d/ops/nearest_neighbor_points.py Normal file

@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from typing import Tuple, Union
import torch
from pytorch3d import _C
def nn_points_idx(
p1, p2, p2_normals=None
) -> Tuple[torch.Tensor, torch.Tensor, Union[torch.Tensor, list]]:
"""
Compute the coordinates of nearest neighbors in pointcloud p2 to points in p1.
Args:
p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
containing P1 points of dimension D.
p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
containing P2 points of dimension D.
p2_normals: [optional] FloatTensor of shape (N, P2, D) giving
normals for p2. Default: None.
Returns:
3-element tuple containing
- **p1_nn_points**: FloatTensor of shape (N, P1, D) where
p1_neighbors[n, i] is the point in p2[n] which is
the nearest neighbor to p1[n, i].
- **p1_nn_idx**: LongTensor of shape (N, P1) giving the indices of
the neighbors.
- **p1_nn_normals**: Normal vectors for each point in p1_neighbors;
only returned if p2_normals is passed,
otherwise an empty list is returned.
"""
N, P1, D = p1.shape
with torch.no_grad():
p1_nn_idx = _C.nn_points_idx(
p1.contiguous(), p2.contiguous()
) # (N, P1)
p1_nn_idx_expanded = p1_nn_idx.view(N, P1, 1).expand(N, P1, D)
p1_nn_points = p2.gather(1, p1_nn_idx_expanded)
if p2_normals is None:
p1_nn_normals = []
else:
if p2_normals.shape != p2.shape:
raise ValueError("p2_normals has incorrect shape.")
p1_nn_normals = p2_normals.gather(1, p1_nn_idx_expanded)
return p1_nn_points, p1_nn_idx, p1_nn_normals
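For intuition, a brute-force pure-PyTorch sketch of the index computation done by the _C.nn_points_idx kernel (a reference under the assumption of squared Euclidean distance, not the shipped implementation):
import torch
def nn_points_idx_reference(p1, p2):
    # (N, P1, P2) pairwise squared distances, then argmin over the P2 axis.
    dists = ((p1[:, :, None, :] - p2[:, None, :, :]) ** 2).sum(dim=-1)
    return dists.argmin(dim=2)  # (N, P1) index into p2 of each nearest neighbor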

pytorch3d/ops/sample_points_from_meshes.py Normal file

@@ -0,0 +1,127 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
"""
This module implements utility functions for sampling points from
batches of meshes.
"""
import sys
from typing import Tuple, Union
import torch
from pytorch3d import _C
def sample_points_from_meshes(
meshes, num_samples: int = 10000, return_normals: bool = False
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Convert a batch of meshes to a pointcloud by uniformly sampling points on
the surface of the mesh with probability proportional to the face area.
Args:
meshes: A Meshes object with a batch of N meshes.
num_samples: Integer giving the number of point samples per mesh.
return_normals: If True, return normals for the sampled points.
Returns:
2-element tuple containing
- **samples**: FloatTensor of shape (N, num_samples, 3) giving the
coordinates of sampled points for each mesh in the batch. For empty
meshes the corresponding row in the samples array will be filled with 0.
- **normals**: FloatTensor of shape (N, num_samples, 3) giving a normal vector
to each sampled point. Only returned if return_normals is True.
For empty meshes the corresponding row in the normals array will
be filled with 0.
"""
if meshes.isempty():
raise ValueError("Meshes are empty.")
verts = meshes.verts_packed()
faces = meshes.faces_packed()
mesh_to_face = meshes.mesh_to_faces_packed_first_idx()
num_meshes = len(meshes)
num_valid_meshes = torch.sum(meshes.valid) # Non empty meshes.
# Initialize samples tensor with fill value 0 for empty meshes.
samples = torch.zeros((num_meshes, num_samples, 3), device=meshes.device)
# Only compute samples for non empty meshes
with torch.no_grad():
areas, _ = _C.face_areas_normals(
verts, faces
) # Face areas can be zero.
max_faces = meshes.num_faces_per_mesh().max().item()
areas_padded = _C.packed_to_padded_tensor(
areas, mesh_to_face[meshes.valid], max_faces
) # (N, F)
# TODO (gkioxari) Confirm multinomial bug is not present with real data.
sample_face_idxs = areas_padded.multinomial(
num_samples, replacement=True
) # (N, num_samples)
sample_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)
# Get the vertex coordinates of the sampled faces.
face_verts = verts[faces.long()]
v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
# Randomly generate barycentric coords.
w0, w1, w2 = _rand_barycentric_coords(
num_valid_meshes, num_samples, verts.dtype, verts.device
)
# Use the barycentric coords to get a point on each sampled face.
a = v0[sample_face_idxs] # (N, num_samples, 3)
b = v1[sample_face_idxs]
c = v2[sample_face_idxs]
samples[meshes.valid] = (
w0[:, :, None] * a + w1[:, :, None] * b + w2[:, :, None] * c
)
if return_normals:
# Initialize normals tensor with fill value 0 for empty meshes.
# Normals for the sampled points are face normals computed from
# the vertices of the face in which the sampled point lies.
normals = torch.zeros(
(num_meshes, num_samples, 3), device=meshes.device
)
face_normals = (v1 - v0).cross(v2 - v1, dim=1)
face_normals = face_normals / face_normals.norm(
dim=1, p=2, keepdim=True
).clamp(min=sys.float_info.epsilon)
face_normals = face_normals[sample_face_idxs]
normals[meshes.valid] = face_normals
return samples, normals
else:
return samples
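# Usage sketch (hypothetical): for a Meshes batch `meshes`,
#   samples, normals = sample_points_from_meshes(meshes, 500, return_normals=True)
# returns (N, 500, 3) surface points and the unit normals of the faces
# they were sampled from.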
def _rand_barycentric_coords(
size1, size2, dtype, device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Helper function to generate random barycentric coordinates which are uniformly
distributed over a triangle.
Args:
size1, size2: The number of coordinates generated will be size1*size2.
Output tensors will each be of shape (size1, size2).
dtype: Datatype to generate.
device: A torch.device object on which the outputs will be allocated.
Returns:
w0, w1, w2: Tensors of shape (size1, size2) giving random barycentric
coordinates
"""
uv = torch.rand(2, size1, size2, dtype=dtype, device=device)
u, v = uv[0], uv[1]
u_sqrt = u.sqrt()
w0 = 1.0 - u_sqrt
w1 = u_sqrt * (1.0 - v)
w2 = u_sqrt * v
return w0, w1, w2
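The square-root warp above is what makes the samples uniform over a triangle; a quick sketch checking that the outputs are valid barycentric coordinates (assuming it is run inside this module):
import torch
w0, w1, w2 = _rand_barycentric_coords(4, 1000, torch.float32, torch.device("cpu"))
assert torch.allclose(w0 + w1 + w2, torch.ones(4, 1000))  # weights sum to one
assert (w0 >= 0).all() and (w1 >= 0).all() and (w2 >= 0).all()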

pytorch3d/ops/subdivide_meshes.py Normal file

@@ -0,0 +1,479 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn as nn
from pytorch3d.structures import Meshes
class SubdivideMeshes(nn.Module):
"""
Subdivide a triangle mesh by adding a new vertex at the center of each edge
and dividing each face into four new faces. Vectors of vertex
attributes can also be subdivided by averaging the values of the attributes
at the two vertices which form each edge. This implementation
preserves face orientation - if the vertices of a face are all ordered
counter-clockwise, then the faces in the subdivided meshes will also have
their vertices ordered counter-clockwise.
If meshes is provided as an input, the initializer performs the relatively
expensive computation of determining the new face indices. This one-time
computation can be reused for all meshes with the same face topology
but different vertex positions.
"""
def __init__(self, meshes=None):
"""
Args:
meshes: Meshes object or None. If a meshes object is provided,
the first mesh is used to compute the new faces of the
subdivided topology which can be reused for meshes with
the same input topology.
"""
super(SubdivideMeshes, self).__init__()
self.precomputed = False
self._N = -1
if meshes is not None:
# This computation is on indices, so gradients do not need to be
# tracked.
mesh = meshes[0]
with torch.no_grad():
subdivided_faces = self.subdivide_faces(mesh)
if subdivided_faces.shape[1] != 3:
raise ValueError("faces can only have three vertices")
self.register_buffer("_subdivided_faces", subdivided_faces)
self.precomputed = True
def subdivide_faces(self, meshes):
r"""
Args:
meshes: a Meshes object.
Returns:
subdivided_faces_packed: (4*sum(F_n), 3) shape LongTensor of
original and new faces.
Refer to pytorch3d.structures.meshes.py for more details on packed
representations of faces.
Each face is split into 4 faces e.g. Input face
::
v0
/\
/ \
/ \
e1 / \ e0
/ \
/ \
/ \
/______________\
v2 e2 v1
faces_packed = [[0, 1, 2]]
faces_packed_to_edges_packed = [[2, 1, 0]]
`faces_packed_to_edges_packed` is used to represent all the new
vertex indices corresponding to the mid-points of edges in the mesh.
The actual vertex coordinates will be computed in the forward function.
To get the indices of the new vertices, offset
`faces_packed_to_edges_packed` by the total number of vertices.
::
faces_packed_to_edges_packed = [[2, 1, 0]] + 3 = [[5, 4, 3]]
e.g. subdivided face
::
v0
/\
/ \
/ f0 \
v4 /______\ v3
/\ /\
/ \ f3 / \
/ f2 \ / f1 \
/______\/______\
v2 v5 v1
f0 = [0, 3, 4]
f1 = [1, 5, 3]
f2 = [2, 4, 5]
f3 = [5, 4, 3]
"""
verts_packed = meshes.verts_packed()
with torch.no_grad():
faces_packed = meshes.faces_packed()
faces_packed_to_edges_packed = meshes.faces_packed_to_edges_packed()
faces_packed_to_edges_packed += verts_packed.shape[0]
f0 = torch.stack(
[
faces_packed[:, 0],
faces_packed_to_edges_packed[:, 2],
faces_packed_to_edges_packed[:, 1],
],
dim=1,
)
f1 = torch.stack(
[
faces_packed[:, 1],
faces_packed_to_edges_packed[:, 0],
faces_packed_to_edges_packed[:, 2],
],
dim=1,
)
f2 = torch.stack(
[
faces_packed[:, 2],
faces_packed_to_edges_packed[:, 1],
faces_packed_to_edges_packed[:, 0],
],
dim=1,
)
f3 = faces_packed_to_edges_packed
subdivided_faces_packed = torch.cat(
[f0, f1, f2, f3], dim=0
) # (4*sum(F_n), 3)
return subdivided_faces_packed
def forward(self, meshes, feats=None):
"""
Subdivide a batch of meshes by adding a new vertex on each edge, and
dividing each face into four new faces. New meshes contains two types
of vertices:
1) Vertices that appear in the input meshes.
Data for these vertices are copied from the input meshes.
2) New vertices at the midpoint of each edge.
Data for these vertices is the average of the data for the two
vertices that make up the edge.
Args:
meshes: Meshes object representing a batch of meshes.
feats: Per-vertex features to be subdivided along with the verts.
Should be parallel to the packed vert representation of the
input meshes; so it should have shape (V, D) where V is the
total number of verts in the input meshes. Default: None.
Returns:
2-element tuple containing
- **new_meshes**: Meshes object of a batch of subdivided meshes.
- **new_feats**: (optional) Tensor of subdivided feats, parallel to the
(packed) vertices of the subdivided meshes. Only returned
if feats is not None.
"""
self._N = len(meshes)
if self.precomputed:
return self.subdivide_homogeneous(meshes, feats)
else:
return self.subdivide_heterogeneous(meshes, feats)
def subdivide_homogeneous(self, meshes, feats=None):
"""
Subdivide verts (and optionally features) of a batch of meshes
where each mesh has the same topology of faces. The subdivided faces
are precomputed in the initializer.
Args:
meshes: Meshes object representing a batch of meshes.
feats: Per-vertex features to be subdivided along with the verts.
Returns:
2-element tuple containing
- **new_meshes**: Meshes object of a batch of subdivided meshes.
- **new_feats**: (optional) Tensor of subdivided feats, parallel to the
(packed) vertices of the subdivided meshes. Only returned
if feats is not None.
"""
verts = meshes.verts_padded() # (N, V, D)
edges = meshes[0].edges_packed()
# The set of faces is the same across the different meshes.
new_faces = self._subdivided_faces.view(1, -1, 3).expand(
self._N, -1, -1
)
# Add one new vertex at the midpoint of each edge by taking the average
# of the vertices that form each edge.
new_verts = verts[:, edges].mean(dim=2)
new_verts = torch.cat(
[verts, new_verts], dim=1
) # (N, V+E, 3)
new_feats = None
# Calculate features for new vertices.
if feats is not None:
if feats.dim() == 2:
# feats is in packed format, transform it from packed to
# padded, i.e. (N*V, D) to (N, V, D).
feats = feats.view(verts.size(0), verts.size(1), feats.size(1))
if feats.dim() != 3:
raise ValueError(
"features need to be of shape (N, V, D) or (N*V, D)"
)
# Take average of the features at the vertices that form each edge.
new_feats = feats[:, edges].mean(dim=2)
new_feats = torch.cat(
[feats, new_feats], dim=1
) # (N, V+E, D)
new_meshes = Meshes(verts=new_verts, faces=new_faces)
if feats is None:
return new_meshes
else:
return new_meshes, new_feats
def subdivide_heterogeneous(self, meshes, feats=None):
"""
Subdivide faces, verts (and optionally features) of a batch of meshes
where each mesh can have different face topologies.
Args:
meshes: Meshes object representing a batch of meshes.
feats: Per-vertex features to be subdivided along with the verts.
Returns:
2-element tuple containing
- **new_meshes**: Meshes object of a batch of subdivided meshes.
- **new_feats**: (optional) Tensor of subdivided feats, parallel to the
(packed) vertices of the subdivided meshes. Only returned
if feats is not None.
"""
# The computation of new faces is on face indices, so gradients do not
# need to be tracked.
verts = meshes.verts_packed()
with torch.no_grad():
new_faces = self.subdivide_faces(meshes)
edges = meshes.edges_packed()
face_to_mesh_idx = meshes.faces_packed_to_mesh_idx()
edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx()
num_edges_per_mesh = edge_to_mesh_idx.bincount(minlength=self._N)
num_verts_per_mesh = meshes.num_verts_per_mesh()
num_faces_per_mesh = meshes.num_faces_per_mesh()
# Add one new vertex at the midpoint of each edge.
new_verts_per_mesh = num_verts_per_mesh + num_edges_per_mesh # (N,)
new_face_to_mesh_idx = torch.cat([face_to_mesh_idx] * 4, dim=0)
# Calculate the indices needed to group the new and existing verts
# for each mesh.
verts_sort_idx = create_verts_index(
num_verts_per_mesh, num_edges_per_mesh, meshes.device
) # (sum(V_n)+sum(E_n),)
verts_ordered_idx_init = torch.zeros(
new_verts_per_mesh.sum(),
dtype=torch.int64,
device=meshes.device,
) # (sum(V_n)+sum(E_n),)
# Reassign vertex indices so that existing and new vertices for each
# mesh are sequential.
verts_ordered_idx = verts_ordered_idx_init.scatter_add(
0,
verts_sort_idx,
torch.arange(new_verts_per_mesh.sum(), device=meshes.device),
)
# Retrieve vertex indices for each face.
new_faces = verts_ordered_idx[new_faces]
# Calculate the indices needed to group the existing and new faces
# for each mesh.
face_sort_idx = create_faces_index(
num_faces_per_mesh, device=meshes.device
)
# Reorder the faces to sequentially group existing and new faces
# for each mesh.
new_faces = new_faces[face_sort_idx]
new_face_to_mesh_idx = new_face_to_mesh_idx[face_sort_idx]
new_faces_per_mesh = new_face_to_mesh_idx.bincount(
minlength=self._N
) # (N,)
# Add one new vertex at the midpoint of each edge by taking the average
# of the verts that form each edge.
new_verts = verts[edges].mean(dim=1)
new_verts = torch.cat([verts, new_verts], dim=0)
# Reorder the verts to sequentially group existing and new verts for
# each mesh.
new_verts = new_verts[verts_sort_idx]
if feats is not None:
new_feats = feats[edges].mean(dim=1)
new_feats = torch.cat([feats, new_feats], dim=0)
new_feats = new_feats[verts_sort_idx]
verts_list = list(new_verts.split(new_verts_per_mesh.tolist(), 0))
faces_list = list(new_faces.split(new_faces_per_mesh.tolist(), 0))
new_verts_per_mesh_cumsum = torch.cat(
[
new_verts_per_mesh.new_full(size=(1,), fill_value=0),
new_verts_per_mesh.cumsum(0)[:-1],
],
dim=0,
)
faces_list = [
faces_list[n] - new_verts_per_mesh_cumsum[n] for n in range(self._N)
]
if feats is not None:
feats_list = new_feats.split(new_verts_per_mesh.tolist(), 0)
new_meshes = Meshes(verts=verts_list, faces=faces_list)
if feats is None:
return new_meshes
else:
new_feats = torch.cat(feats_list, dim=0)
return new_meshes, new_feats
def create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
"""
Helper function to group the vertex indices for each mesh. New vertices are
stacked at the end of the original verts tensor, so in order to have
sequential packing, the verts tensor needs to be reordered so that the
vertices corresponding to each mesh are grouped together.
Args:
verts_per_mesh: Tensor of shape (N,) giving the number of vertices
in each mesh in the batch where N is the batch size.
edges_per_mesh: Tensor of shape (N,) giving the number of edges
in each mesh in the batch
Returns:
verts_idx: A tensor with vert indices for each mesh ordered sequentially
by mesh index.
"""
# e.g. verts_per_mesh = (4, 5, 6)
# e.g. edges_per_mesh = (5, 7, 9)
V = verts_per_mesh.sum() # e.g. 15
E = edges_per_mesh.sum() # e.g. 21
verts_per_mesh_cumsum = verts_per_mesh.cumsum(dim=0) # (N,) e.g. (4, 9, 15)
edges_per_mesh_cumsum = edges_per_mesh.cumsum(
dim=0
) # (N,) e.g. (5, 12, 21)
v_to_e_idx = verts_per_mesh_cumsum.clone()
# vertex to edge index.
v_to_e_idx[1:] += edges_per_mesh_cumsum[
:-1
] # e.g. (4, 9, 15) + (0, 5, 12) = (4, 14, 27)
# vertex to edge offset.
v_to_e_offset = (
V - verts_per_mesh_cumsum
) # e.g. 15 - (4, 9, 15) = (11, 6, 0)
v_to_e_offset[1:] += edges_per_mesh_cumsum[
:-1
] # e.g. (11, 6, 0) + (0, 5, 12) = (11, 11, 12)
e_to_v_idx = (
verts_per_mesh_cumsum[:-1] + edges_per_mesh_cumsum[:-1]
) # (4, 9) + (5, 12) = (9, 21)
e_to_v_offset = (
verts_per_mesh_cumsum[:-1] - edges_per_mesh_cumsum[:-1] - V
) # (4, 9) - (5, 12) - 15 = (-16, -18)
# Add one new vertex per edge.
idx_diffs = torch.ones(V + E, device=device, dtype=torch.int64) # (36,)
idx_diffs[v_to_e_idx] += v_to_e_offset
idx_diffs[e_to_v_idx] += e_to_v_offset
# e.g.
# [
# 1, 1, 1, 1, 12, 1, 1, 1, 1,
# -15, 1, 1, 1, 1, 12, 1, 1, 1, 1, 1, 1,
# -17, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1, 1, 1
# ]
verts_idx = idx_diffs.cumsum(dim=0) - 1
# e.g.
# [
# 0, 1, 2, 3, 15, 16, 17, 18, 19, --> mesh 0
# 4, 5, 6, 7, 8, 20, 21, 22, 23, 24, 25, 26, --> mesh 1
# 9, 10, 11, 12, 13, 14, 27, 28, 29, 30, 31, 32, 33, 34, 35 --> mesh 2
# ]
# where for mesh 0, [0, 1, 2, 3] are the indices of the existing verts, and
# [15, 16, 17, 18, 19] are the indices of the new verts after subdivision.
return verts_idx
def create_faces_index(faces_per_mesh, device=None):
"""
Helper function to group the face indices for each mesh. New faces are
stacked at the end of the original faces tensor, so in order to have
sequential packing, the faces tensor needs to be reordered so that the
faces corresponding to each mesh are grouped together.
Args:
faces_per_mesh: Tensor of shape (N,) giving the number of faces
in each mesh in the batch where N is the batch size.
Returns:
faces_idx: A tensor with face indices for each mesh ordered sequentially
by mesh index.
"""
# e.g. faces_per_mesh = [2, 5, 3]
F = faces_per_mesh.sum() # e.g. 10
faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0) # (N,) e.g. (2, 7, 10)
switch1_idx = faces_per_mesh_cumsum.clone()
switch1_idx[1:] += (
3 * faces_per_mesh_cumsum[:-1]
) # e.g. (2, 7, 10) + (0, 6, 21) = (2, 13, 31)
switch2_idx = 2 * faces_per_mesh_cumsum # e.g. (4, 14, 20)
switch2_idx[1:] += (
2 * faces_per_mesh_cumsum[:-1]
) # e.g. (4, 14, 20) + (0, 4, 14) = (4, 18, 34)
switch3_idx = 3 * faces_per_mesh_cumsum # e.g. (6, 21, 30)
switch3_idx[1:] += faces_per_mesh_cumsum[
:-1
] # e.g. (6, 21, 30) + (0, 2, 7) = (6, 23, 37)
switch4_idx = 4 * faces_per_mesh_cumsum[:-1] # e.g. (8, 28)
switch123_offset = F - faces_per_mesh # e.g. (8, 5, 7)
idx_diffs = torch.ones(4 * F, device=device, dtype=torch.int64)
idx_diffs[switch1_idx] += switch123_offset
idx_diffs[switch2_idx] += switch123_offset
idx_diffs[switch3_idx] += switch123_offset
idx_diffs[switch4_idx] -= 3 * F
# e.g
# [
# 1, 1, 9, 1, 9, 1, 9, 1, -> mesh 0
# -29, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, -> mesh 1
# -29, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1 -> mesh 2
# ]
faces_idx = idx_diffs.cumsum(dim=0) - 1
# e.g.
# [
# 0, 1, 10, 11, 20, 21, 30, 31,
# 2, 3, 4, 5, 6, 12, 13, 14, 15, 16, 22, 23, 24, 25, 26, 32, 33, 34, 35, 36,
# 7, 8, 9, 17, 18, 19, 27, 28, 29, 37, 38, 39
# ]
# where for mesh 0, [0, 1] are the indices of the existing faces, and
# [10, 11, 20, 21, 30, 31] are the indices of the new faces after subdivision.
return faces_idx
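The interleaving in create_verts_index can be checked against the worked example in the comments above (a CPU sketch):
import torch
verts_per_mesh = torch.tensor([4, 5, 6])
edges_per_mesh = torch.tensor([5, 7, 9])
idx = create_verts_index(verts_per_mesh, edges_per_mesh)
# Mesh 0: original verts 0-3 followed by its five new midpoint verts 15-19.
print(idx[:9])  # expected: tensor([0, 1, 2, 3, 15, 16, 17, 18, 19])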

pytorch3d/ops/vert_align.py Normal file

@@ -0,0 +1,101 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
import torch.nn.functional as F
def vert_align(
feats,
verts,
return_packed: bool = False,
interp_mode: str = "bilinear",
padding_mode: str = "zeros",
align_corners: bool = True,
) -> torch.Tensor:
"""
Sample vertex features from a feature map. This operation is called
"perceptual feature pooling" in [1] or "vert align" in [2].
[1] Wang et al, "Pixel2Mesh: Generating 3D Mesh Models from Single
RGB Images", ECCV 2018.
[2] Gkioxari et al, "Mesh R-CNN", ICCV 2019
Args:
feats: FloatTensor of shape (N, C, H, W) representing image features
from which to sample or a list of features each with potentially
different C, H or W dimensions.
verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes) with
'verts_padded' as an attribute giving the (x, y, z) vertex positions
for which to sample. (x, y) verts should be normalized such that
(-1, -1) corresponds to top-left and (+1, +1) to bottom-right
location in the input feature map.
return_packed: (bool) Indicates whether to return packed features
interp_mode: (str) Specifies how to interpolate features.
('bilinear' or 'nearest')
padding_mode: (str) Specifies how to handle vertices outside of the
[-1, 1] range. ('zeros', 'reflection', or 'border')
align_corners (bool): Geometrically, we consider the pixels of the
input as squares rather than points.
If set to ``True``, the extrema (``-1`` and ``1``) are considered as
referring to the center points of the input's corner pixels. If set
to ``False``, they are instead considered as referring to the corner
points of the input's corner pixels, making the sampling more
resolution agnostic. Default: ``True``
Returns:
feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for
each vertex. If feats is a list, we return concatenated
features in axis=2 of shape (N, V, sum(C_n)) where
C_n = feats[n].shape[1]. If return_packed = True, the
features are transformed to a packed representation
of shape (sum(V), C)
"""
if torch.is_tensor(verts):
if verts.dim() != 3:
raise ValueError("verts tensor should be 3 dimensional")
grid = verts
elif hasattr(verts, "verts_padded"):
grid = verts.verts_padded()
else:
raise ValueError(
"verts must be a tensor or have a `verts_padded` attribute"
)
grid = grid[:, None, :, :2] # (N, 1, V, 2)
if torch.is_tensor(feats):
feats = [feats]
for feat in feats:
if feat.dim() != 4:
raise ValueError("feats must have shape (N, C, H, W)")
if grid.shape[0] != feat.shape[0]:
raise ValueError("inconsistent batch dimension")
feats_sampled = []
for feat in feats:
feat_sampled = F.grid_sample(
feat,
grid,
mode=interp_mode,
padding_mode=padding_mode,
align_corners=align_corners,
) # (N, C, 1, V)
feat_sampled = feat_sampled.squeeze(dim=2).transpose(1, 2) # (N, V, C)
feats_sampled.append(feat_sampled)
feats_sampled = torch.cat(feats_sampled, dim=2) # (N, V, sum(C))
if return_packed:
# flatten the first two dimensions: (N*V, C)
feats_sampled = feats_sampled.view(-1, feats_sampled.shape[-1])
if hasattr(verts, "verts_padded_to_packed_idx"):
idx = (
verts.verts_padded_to_packed_idx()
.view(-1, 1)
.expand(-1, feats_sampled.shape[-1])
)
feats_sampled = feats_sampled.gather(0, idx) # (sum(V), C)
return feats_sampled
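A minimal sketch of vert_align with random inputs (hypothetical sizes): each vertex's (x, y) is treated as a normalized location in the feature map.
import torch
from pytorch3d.ops import vert_align
feats = torch.randn(2, 64, 32, 32)         # (N, C, H, W) image features
verts = torch.rand(2, 100, 3) * 2.0 - 1.0  # (N, V, 3) with (x, y) in [-1, 1]
out = vert_align(feats, verts)             # (2, 100, 64) per-vertex features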