mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-01-17 03:40:34 +08:00
Initial commit
fbshipit-source-id: ad58e416e3ceeca85fae0583308968d04e78fe0d
This commit is contained in:
11
pytorch3d/ops/__init__.py
Normal file
11
pytorch3d/ops/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from .cubify import cubify
|
||||
from .graph_conv import GraphConv
|
||||
from .nearest_neighbor_points import nn_points_idx
|
||||
from .sample_points_from_meshes import sample_points_from_meshes
|
||||
from .subdivide_meshes import SubdivideMeshes
|
||||
from .vert_align import vert_align
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
208
pytorch3d/ops/cubify.py
Normal file
208
pytorch3d/ops/cubify.py
Normal file
@@ -0,0 +1,208 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from pytorch3d.structures import Meshes
|
||||
|
||||
|
||||
def unravel_index(idx, dims) -> torch.Tensor:
|
||||
r"""
|
||||
Equivalent to np.unravel_index
|
||||
Args:
|
||||
idx: A LongTensor whose elements are indices into the
|
||||
flattened version of an array of dimensions dims.
|
||||
dims: The shape of the array to be indexed.
|
||||
Implemented only for dims=(N, H, W, D)
|
||||
"""
|
||||
if len(dims) != 4:
|
||||
raise ValueError("Expects a 4-element list.")
|
||||
N, H, W, D = dims
|
||||
n = torch.div(idx, H * W * D)
|
||||
h = torch.div(idx - n * H * W * D, W * D)
|
||||
w = torch.div(idx - n * H * W * D - h * W * D, D)
|
||||
d = idx - n * H * W * D - h * W * D - w * D
|
||||
return torch.stack((n, h, w, d), dim=1)
|
||||
|
||||
|
||||
def ravel_index(idx, dims) -> torch.Tensor:
|
||||
"""
|
||||
Computes the linear index in an array of shape dims.
|
||||
It performs the reverse functionality of unravel_index
|
||||
Args:
|
||||
idx: A LongTensor of shape (N, 3). Each row corresponds to indices into an
|
||||
array of dimensions dims.
|
||||
dims: The shape of the array to be indexed.
|
||||
Implemented only for dims=(H, W, D)
|
||||
"""
|
||||
if len(dims) != 3:
|
||||
raise ValueError("Expects a 3-element list")
|
||||
if idx.shape[1] != 3:
|
||||
raise ValueError("Expects an index tensor of shape Nx3")
|
||||
H, W, D = dims
|
||||
linind = idx[:, 0] * W * D + idx[:, 1] * D + idx[:, 2]
|
||||
return linind
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def cubify(voxels, thresh, device=None) -> Meshes:
|
||||
r"""
|
||||
Converts a voxel to a mesh by replacing each occupied voxel with a cube
|
||||
consisting of 12 faces and 8 vertices. Shared vertices are merged, and
|
||||
internal faces are removed.
|
||||
Args:
|
||||
voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
|
||||
thresh: A scalar threshold. If a voxel occupancy is larger than
|
||||
thresh, the voxel is considered occupied.
|
||||
Returns:
|
||||
meshes: A Meshes object of the corresponding meshes.
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = voxels.device
|
||||
|
||||
if len(voxels) == 0:
|
||||
return Meshes(verts=[], faces=[])
|
||||
|
||||
N, D, H, W = voxels.size()
|
||||
# vertices corresponding to a unit cube: 8x3
|
||||
cube_verts = torch.tensor(
|
||||
[
|
||||
[0, 0, 0],
|
||||
[0, 0, 1],
|
||||
[0, 1, 0],
|
||||
[0, 1, 1],
|
||||
[1, 0, 0],
|
||||
[1, 0, 1],
|
||||
[1, 1, 0],
|
||||
[1, 1, 1],
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# faces corresponding to a unit cube: 12x3
|
||||
cube_faces = torch.tensor(
|
||||
[
|
||||
[0, 1, 2],
|
||||
[1, 3, 2], # left face: 0, 1
|
||||
[2, 3, 6],
|
||||
[3, 7, 6], # bottom face: 2, 3
|
||||
[0, 2, 6],
|
||||
[0, 6, 4], # front face: 4, 5
|
||||
[0, 5, 1],
|
||||
[0, 4, 5], # up face: 6, 7
|
||||
[6, 7, 5],
|
||||
[6, 5, 4], # right face: 8, 9
|
||||
[1, 7, 3],
|
||||
[1, 5, 7], # back face: 10, 11
|
||||
],
|
||||
dtype=torch.int64,
|
||||
device=device,
|
||||
)
|
||||
|
||||
wx = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 1, 2)
|
||||
wy = torch.tensor([0.5, 0.5], device=device).view(1, 1, 1, 2, 1)
|
||||
wz = torch.tensor([0.5, 0.5], device=device).view(1, 1, 2, 1, 1)
|
||||
|
||||
voxelt = voxels.ge(thresh).float()
|
||||
# N x 1 x D x H x W
|
||||
voxelt = voxelt.view(N, 1, D, H, W)
|
||||
|
||||
# N x 1 x (D-1) x (H-1) x (W-1)
|
||||
voxelt_x = F.conv3d(voxelt, wx).gt(0.5).float()
|
||||
voxelt_y = F.conv3d(voxelt, wy).gt(0.5).float()
|
||||
voxelt_z = F.conv3d(voxelt, wz).gt(0.5).float()
|
||||
|
||||
# 12 x N x 1 x D x H x W
|
||||
faces_idx = torch.ones((cube_faces.size(0), N, 1, D, H, W), device=device)
|
||||
|
||||
# add left face
|
||||
faces_idx[0, :, :, :, :, 1:] = 1 - voxelt_x
|
||||
faces_idx[1, :, :, :, :, 1:] = 1 - voxelt_x
|
||||
# add bottom face
|
||||
faces_idx[2, :, :, :, :-1, :] = 1 - voxelt_y
|
||||
faces_idx[3, :, :, :, :-1, :] = 1 - voxelt_y
|
||||
# add front face
|
||||
faces_idx[4, :, :, 1:, :, :] = 1 - voxelt_z
|
||||
faces_idx[5, :, :, 1:, :, :] = 1 - voxelt_z
|
||||
# add up face
|
||||
faces_idx[6, :, :, :, 1:, :] = 1 - voxelt_y
|
||||
faces_idx[7, :, :, :, 1:, :] = 1 - voxelt_y
|
||||
# add right face
|
||||
faces_idx[8, :, :, :, :, :-1] = 1 - voxelt_x
|
||||
faces_idx[9, :, :, :, :, :-1] = 1 - voxelt_x
|
||||
# add back face
|
||||
faces_idx[10, :, :, :-1, :, :] = 1 - voxelt_z
|
||||
faces_idx[11, :, :, :-1, :, :] = 1 - voxelt_z
|
||||
|
||||
faces_idx *= voxelt
|
||||
|
||||
# N x H x W x D x 12
|
||||
faces_idx = faces_idx.permute(1, 2, 4, 5, 3, 0).squeeze(1)
|
||||
# (NHWD) x 12
|
||||
faces_idx = faces_idx.contiguous()
|
||||
faces_idx = faces_idx.view(-1, cube_faces.size(0))
|
||||
|
||||
# boolean to linear index
|
||||
# NF x 2
|
||||
linind = torch.nonzero(faces_idx)
|
||||
# NF x 4
|
||||
nyxz = unravel_index(linind[:, 0], (N, H, W, D))
|
||||
|
||||
# NF x 3: faces
|
||||
faces = torch.index_select(cube_faces, 0, linind[:, 1])
|
||||
|
||||
grid_faces = []
|
||||
for d in range(cube_faces.size(1)):
|
||||
# NF x 3
|
||||
xyz = torch.index_select(cube_verts, 0, faces[:, d])
|
||||
permute_idx = torch.tensor([1, 0, 2], device=device)
|
||||
yxz = torch.index_select(xyz, 1, permute_idx)
|
||||
yxz += nyxz[:, 1:]
|
||||
# NF x 1
|
||||
temp = ravel_index(yxz, (H + 1, W + 1, D + 1))
|
||||
grid_faces.append(temp)
|
||||
# NF x 3
|
||||
grid_faces = torch.stack(grid_faces, dim=1)
|
||||
|
||||
y, x, z = torch.meshgrid(
|
||||
torch.arange(H + 1), torch.arange(W + 1), torch.arange(D + 1)
|
||||
)
|
||||
y = y.to(device=device, dtype=torch.float32)
|
||||
y = y * 2.0 / (H - 1.0) - 1.0
|
||||
x = x.to(device=device, dtype=torch.float32)
|
||||
x = x * 2.0 / (W - 1.0) - 1.0
|
||||
z = z.to(device=device, dtype=torch.float32)
|
||||
z = z * 2.0 / (D - 1.0) - 1.0
|
||||
# ((H+1)(W+1)(D+1)) x 3
|
||||
grid_verts = torch.stack((x, y, z), dim=3).view(-1, 3)
|
||||
|
||||
if len(nyxz) == 0:
|
||||
verts_list = [torch.tensor([], dtype=torch.float32, device=device)] * N
|
||||
faces_list = [torch.tensor([], dtype=torch.int64, device=device)] * N
|
||||
return Meshes(verts=verts_list, faces=faces_list)
|
||||
|
||||
num_verts = grid_verts.size(0)
|
||||
grid_faces += nyxz[:, 0].view(-1, 1) * num_verts
|
||||
idleverts = torch.ones(num_verts * N, dtype=torch.uint8, device=device)
|
||||
|
||||
idleverts.scatter_(0, grid_faces.flatten(), 0)
|
||||
grid_faces -= nyxz[:, 0].view(-1, 1) * num_verts
|
||||
split_size = torch.bincount(nyxz[:, 0], minlength=N)
|
||||
faces_list = list(torch.split(grid_faces, split_size.tolist(), 0))
|
||||
|
||||
idleverts = idleverts.view(N, num_verts)
|
||||
idlenum = idleverts.cumsum(1)
|
||||
|
||||
verts_list = [
|
||||
grid_verts.index_select(0, (idleverts[n] == 0).nonzero()[:, 0])
|
||||
for n in range(N)
|
||||
]
|
||||
faces_list = [
|
||||
nface - idlenum[n][nface] for n, nface in enumerate(faces_list)
|
||||
]
|
||||
|
||||
return Meshes(verts=verts_list, faces=faces_list)
|
||||
174
pytorch3d/ops/graph_conv.py
Normal file
174
pytorch3d/ops/graph_conv.py
Normal file
@@ -0,0 +1,174 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Function
|
||||
from torch.autograd.function import once_differentiable
|
||||
|
||||
from pytorch3d import _C
|
||||
|
||||
|
||||
class GraphConv(nn.Module):
|
||||
"""A single graph convolution layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
init: str = "normal",
|
||||
directed: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_dim: Number of input features per vertex.
|
||||
output_dim: Number of output features per vertex.
|
||||
init: Weight initialization method. Can be one of ['zero', 'normal'].
|
||||
directed: Bool indicating if edges in the graph are directed.
|
||||
"""
|
||||
super().__init__()
|
||||
self.input_dim = input_dim
|
||||
self.output_dim = output_dim
|
||||
self.directed = directed
|
||||
self.w0 = nn.Linear(input_dim, output_dim)
|
||||
self.w1 = nn.Linear(input_dim, output_dim)
|
||||
|
||||
if init == "normal":
|
||||
nn.init.normal_(self.w0.weight, mean=0, std=0.01)
|
||||
nn.init.normal_(self.w1.weight, mean=0, std=0.01)
|
||||
self.w0.bias.data.zero_()
|
||||
self.w1.bias.data.zero_()
|
||||
elif init == "zero":
|
||||
self.w0.weight.data.zero_()
|
||||
self.w1.weight.data.zero_()
|
||||
else:
|
||||
raise ValueError('Invalid GraphConv initialization "%s"' % init)
|
||||
|
||||
def forward(self, verts, edges):
|
||||
"""
|
||||
Args:
|
||||
verts: FloatTensor of shape (V, input_dim) where V is the number of
|
||||
vertices and input_dim is the number of input features
|
||||
per vertex. input_dim has to match the input_dim specified
|
||||
in __init__.
|
||||
edges: LongTensor of shape (E, 2) where E is the number of edges
|
||||
where each edge has the indices of the two vertices which
|
||||
form the edge.
|
||||
|
||||
Returns:
|
||||
out: FloatTensor of shape (V, output_dim) where output_dim is the
|
||||
number of output features per vertex.
|
||||
"""
|
||||
if verts.is_cuda != edges.is_cuda:
|
||||
raise ValueError(
|
||||
"verts and edges tensors must be on the same device."
|
||||
)
|
||||
if verts.shape[0] == 0:
|
||||
# empty graph.
|
||||
return verts.sum() * 0.0
|
||||
|
||||
verts_w0 = self.w0(verts) # (V, output_dim)
|
||||
verts_w1 = self.w1(verts) # (V, output_dim)
|
||||
|
||||
if torch.cuda.is_available() and verts.is_cuda and edges.is_cuda:
|
||||
neighbor_sums = gather_scatter(verts_w1, edges, self.directed)
|
||||
else:
|
||||
neighbor_sums = gather_scatter_python(
|
||||
verts_w1, edges, self.directed
|
||||
) # (V, output_dim)
|
||||
|
||||
# Add neighbor features to each vertex's features.
|
||||
out = verts_w0 + neighbor_sums
|
||||
return out
|
||||
|
||||
def __repr__(self):
|
||||
Din, Dout, directed = self.input_dim, self.output_dim, self.directed
|
||||
return "GraphConv(%d -> %d, directed=%r)" % (Din, Dout, directed)
|
||||
|
||||
|
||||
def gather_scatter_python(input, edges, directed: bool = False):
|
||||
"""
|
||||
Python implementation of gather_scatter for aggregating features of
|
||||
neighbor nodes in a graph.
|
||||
|
||||
Given a directed graph: v0 -> v1 -> v2 the updated feature for v1 depends
|
||||
on v2 in order to be consistent with Morris et al. AAAI 2019
|
||||
(https://arxiv.org/abs/1810.02244). This only affects
|
||||
directed graphs; for undirected graphs v1 will depend on both v0 and v2,
|
||||
no matter which way the edges are physically stored.
|
||||
|
||||
Args:
|
||||
input: Tensor of shape (num_vertices, input_dim).
|
||||
edges: Tensor of edge indices of shape (num_edges, 2).
|
||||
directed: bool indicating if edges are directed.
|
||||
|
||||
Returns:
|
||||
output: Tensor of same shape as input.
|
||||
"""
|
||||
if not (input.dim() == 2):
|
||||
raise ValueError("input can only have 2 dimensions.")
|
||||
if not (edges.dim() == 2):
|
||||
raise ValueError("edges can only have 2 dimensions.")
|
||||
if not (edges.shape[1] == 2):
|
||||
raise ValueError("edges must be of shape (num_edges, 2).")
|
||||
|
||||
num_vertices, input_feature_dim = input.shape
|
||||
num_edges = edges.shape[0]
|
||||
output = torch.zeros_like(input)
|
||||
idx0 = edges[:, 0].view(num_edges, 1).expand(num_edges, input_feature_dim)
|
||||
idx1 = edges[:, 1].view(num_edges, 1).expand(num_edges, input_feature_dim)
|
||||
|
||||
output = output.scatter_add(0, idx0, input.gather(0, idx1))
|
||||
if not directed:
|
||||
output = output.scatter_add(0, idx1, input.gather(0, idx0))
|
||||
return output
|
||||
|
||||
|
||||
class GatherScatter(Function):
|
||||
"""
|
||||
Torch autograd Function wrapper for gather_scatter C++/CUDA implementations.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, edges, directed=False):
|
||||
"""
|
||||
Args:
|
||||
ctx: Context object used to calculate gradients.
|
||||
input: Tensor of shape (num_vertices, input_dim)
|
||||
edges: Tensor of edge indices of shape (num_edges, 2)
|
||||
directed: Bool indicating if edges are directed.
|
||||
|
||||
Returns:
|
||||
output: Tensor of same shape as input.
|
||||
"""
|
||||
if not (input.dim() == 2):
|
||||
raise ValueError("input can only have 2 dimensions.")
|
||||
if not (edges.dim() == 2):
|
||||
raise ValueError("edges can only have 2 dimensions.")
|
||||
if not (edges.shape[1] == 2):
|
||||
raise ValueError("edges must be of shape (num_edges, 2).")
|
||||
if not (input.dtype == torch.float32):
|
||||
raise ValueError("input has to be of type torch.float32.")
|
||||
|
||||
ctx.directed = directed
|
||||
input, edges = input.contiguous(), edges.contiguous()
|
||||
ctx.save_for_backward(edges)
|
||||
backward = False
|
||||
output = _C.gather_scatter(input, edges, directed, backward)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
grad_output = grad_output.contiguous()
|
||||
edges = ctx.saved_tensors[0]
|
||||
directed = ctx.directed
|
||||
backward = True
|
||||
grad_input = _C.gather_scatter(grad_output, edges, directed, backward)
|
||||
grad_edges = None
|
||||
grad_directed = None
|
||||
return grad_input, grad_edges, grad_directed
|
||||
|
||||
|
||||
gather_scatter = GatherScatter.apply
|
||||
47
pytorch3d/ops/nearest_neighbor_points.py
Normal file
47
pytorch3d/ops/nearest_neighbor_points.py
Normal file
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch3d import _C
|
||||
|
||||
|
||||
def nn_points_idx(p1, p2, p2_normals=None) -> torch.Tensor:
|
||||
"""
|
||||
Compute the coordinates of nearest neighbors in pointcloud p2 to points in p1.
|
||||
Args:
|
||||
p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
|
||||
containing P1 points of dimension D.
|
||||
p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
|
||||
containing P2 points of dimension D.
|
||||
p2_normals: [optional] FloatTensor of shape (N, P2, D) giving
|
||||
normals for p2. Default: None.
|
||||
|
||||
Returns:
|
||||
3-element tuple containing
|
||||
|
||||
- **p1_nn_points**: FloatTensor of shape (N, P1, D) where
|
||||
p1_neighbors[n, i] is the point in p2[n] which is
|
||||
the nearest neighbor to p1[n, i].
|
||||
- **p1_nn_idx**: LongTensor of shape (N, P1) giving the indices of
|
||||
the neighbors.
|
||||
- **p1_nn_normals**: Normal vectors for each point in p1_neighbors;
|
||||
only returned if p2_normals is passed
|
||||
else return [].
|
||||
"""
|
||||
N, P1, D = p1.shape
|
||||
with torch.no_grad():
|
||||
p1_nn_idx = _C.nn_points_idx(
|
||||
p1.contiguous(), p2.contiguous()
|
||||
) # (N, P1)
|
||||
p1_nn_idx_expanded = p1_nn_idx.view(N, P1, 1).expand(N, P1, D)
|
||||
p1_nn_points = p2.gather(1, p1_nn_idx_expanded)
|
||||
if p2_normals is None:
|
||||
p1_nn_normals = []
|
||||
else:
|
||||
if p2_normals.shape != p2.shape:
|
||||
raise ValueError("p2_normals has incorrect shape.")
|
||||
p1_nn_normals = p2_normals.gather(1, p1_nn_idx_expanded)
|
||||
|
||||
return p1_nn_points, p1_nn_idx, p1_nn_normals
|
||||
127
pytorch3d/ops/sample_points_from_meshes.py
Normal file
127
pytorch3d/ops/sample_points_from_meshes.py
Normal file
@@ -0,0 +1,127 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
"""
|
||||
This module implements utility functions for sampling points from
|
||||
batches of meshes.
|
||||
"""
|
||||
import sys
|
||||
from typing import Tuple, Union
|
||||
import torch
|
||||
|
||||
from pytorch3d import _C
|
||||
|
||||
|
||||
def sample_points_from_meshes(
|
||||
meshes, num_samples: int = 10000, return_normals: bool = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Convert a batch of meshes to a pointcloud by uniformly sampling points on
|
||||
the surface of the mesh with probability proportional to the face area.
|
||||
|
||||
Args:
|
||||
meshes: A Meshes object with a batch of N meshes.
|
||||
num_samples: Integer giving the number of point samples per mesh.
|
||||
return_normals: If True, return normals for the sampled points.
|
||||
eps: (float) used to clamp the norm of the normals to avoid dividing by 0.
|
||||
|
||||
Returns:
|
||||
2-element tuple containing
|
||||
|
||||
- **samples**: FloatTensor of shape (N, num_samples, 3) giving the
|
||||
coordinates of sampled points for each mesh in the batch. For empty
|
||||
meshes the corresponding row in the samples array will be filled with 0.
|
||||
- **normals**: FloatTensor of shape (N, num_samples, 3) giving a normal vector
|
||||
to each sampled point. Only returned if return_normals is True.
|
||||
For empty meshes the corresponding row in the normals array will
|
||||
be filled with 0.
|
||||
"""
|
||||
if meshes.isempty():
|
||||
raise ValueError("Meshes are empty.")
|
||||
|
||||
verts = meshes.verts_packed()
|
||||
faces = meshes.faces_packed()
|
||||
mesh_to_face = meshes.mesh_to_faces_packed_first_idx()
|
||||
num_meshes = len(meshes)
|
||||
num_valid_meshes = torch.sum(meshes.valid) # Non empty meshes.
|
||||
|
||||
# Intialize samples tensor with fill value 0 for empty meshes.
|
||||
samples = torch.zeros((num_meshes, num_samples, 3), device=meshes.device)
|
||||
|
||||
# Only compute samples for non empty meshes
|
||||
with torch.no_grad():
|
||||
areas, _ = _C.face_areas_normals(
|
||||
verts, faces
|
||||
) # Face areas can be zero.
|
||||
max_faces = meshes.num_faces_per_mesh().max().item()
|
||||
areas_padded = _C.packed_to_padded_tensor(
|
||||
areas, mesh_to_face[meshes.valid], max_faces
|
||||
) # (N, F)
|
||||
|
||||
# TODO (gkioxari) Confirm multinomial bug is not present with real data.
|
||||
sample_face_idxs = areas_padded.multinomial(
|
||||
num_samples, replacement=True
|
||||
) # (N, num_samples)
|
||||
sample_face_idxs += mesh_to_face[meshes.valid].view(num_valid_meshes, 1)
|
||||
|
||||
# Get the vertex coordinates of the sampled faces.
|
||||
face_verts = verts[faces.long()]
|
||||
v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
|
||||
|
||||
# Randomly generate barycentric coords.
|
||||
w0, w1, w2 = _rand_barycentric_coords(
|
||||
num_valid_meshes, num_samples, verts.dtype, verts.device
|
||||
)
|
||||
|
||||
# Use the barycentric coords to get a point on each sampled face.
|
||||
a = v0[sample_face_idxs] # (N, num_samples, 3)
|
||||
b = v1[sample_face_idxs]
|
||||
c = v2[sample_face_idxs]
|
||||
samples[meshes.valid] = (
|
||||
w0[:, :, None] * a + w1[:, :, None] * b + w2[:, :, None] * c
|
||||
)
|
||||
|
||||
if return_normals:
|
||||
# Intialize normals tensor with fill value 0 for empty meshes.
|
||||
# Normals for the sampled points are face normals computed from
|
||||
# the vertices of the face in which the sampled point lies.
|
||||
normals = torch.zeros(
|
||||
(num_meshes, num_samples, 3), device=meshes.device
|
||||
)
|
||||
vert_normals = (v1 - v0).cross(v2 - v1, dim=1)
|
||||
vert_normals = vert_normals / vert_normals.norm(
|
||||
dim=1, p=2, keepdim=True
|
||||
).clamp(min=sys.float_info.epsilon)
|
||||
vert_normals = vert_normals[sample_face_idxs]
|
||||
normals[meshes.valid] = vert_normals
|
||||
|
||||
return samples, normals
|
||||
else:
|
||||
return samples
|
||||
|
||||
|
||||
def _rand_barycentric_coords(
|
||||
size1, size2, dtype, device
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Helper function to generate random barycentric coordinates which are uniformly
|
||||
distributed over a triangle.
|
||||
|
||||
Args:
|
||||
size1, size2: The number of coordinates generated will be size1*size2.
|
||||
Output tensors will each be of shape (size1, size2).
|
||||
dtype: Datatype to generate.
|
||||
device: A torch.device object on which the outputs will be allocated.
|
||||
|
||||
Returns:
|
||||
w0, w1, w2: Tensors of shape (size1, size2) giving random barycentric
|
||||
coordinates
|
||||
"""
|
||||
uv = torch.rand(2, size1, size2, dtype=dtype, device=device)
|
||||
u, v = uv[0], uv[1]
|
||||
u_sqrt = u.sqrt()
|
||||
w0 = 1.0 - u_sqrt
|
||||
w1 = u_sqrt * (1.0 - v)
|
||||
w2 = u_sqrt * v
|
||||
return w0, w1, w2
|
||||
479
pytorch3d/ops/subdivide_meshes.py
Normal file
479
pytorch3d/ops/subdivide_meshes.py
Normal file
@@ -0,0 +1,479 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from pytorch3d.structures import Meshes
|
||||
|
||||
|
||||
class SubdivideMeshes(nn.Module):
|
||||
"""
|
||||
Subdivide a triangle mesh by adding a new vertex at the center of each edge
|
||||
and dividing each face into four new faces. Vectors of vertex
|
||||
attributes can also be subdivided by averaging the values of the attributes
|
||||
at the two vertices which form each edge. This implementation
|
||||
preserves face orientation - if the vertices of a face are all ordered
|
||||
counter-clockwise, then the faces in the subdivided meshes will also have
|
||||
their vertices ordered counter-clockwise.
|
||||
|
||||
If meshes is provided as an input, the initializer performs the relatively
|
||||
expensive computation of determining the new face indices. This one-time
|
||||
computation can be reused for all meshes with the same face topology
|
||||
but different vertex positions.
|
||||
"""
|
||||
|
||||
def __init__(self, meshes=None):
|
||||
"""
|
||||
Args:
|
||||
meshes: Meshes object or None. If a meshes object is provided,
|
||||
the first mesh is used to compute the new faces of the
|
||||
subdivided topology which can be reused for meshes with
|
||||
the same input topology.
|
||||
"""
|
||||
super(SubdivideMeshes, self).__init__()
|
||||
|
||||
self.precomputed = False
|
||||
self._N = -1
|
||||
if meshes is not None:
|
||||
# This computation is on indices, so gradients do not need to be
|
||||
# tracked.
|
||||
mesh = meshes[0]
|
||||
with torch.no_grad():
|
||||
subdivided_faces = self.subdivide_faces(mesh)
|
||||
if subdivided_faces.shape[1] != 3:
|
||||
raise ValueError("faces can only have three vertices")
|
||||
self.register_buffer("_subdivided_faces", subdivided_faces)
|
||||
self.precomputed = True
|
||||
|
||||
def subdivide_faces(self, meshes):
|
||||
r"""
|
||||
Args:
|
||||
meshes: a Meshes object.
|
||||
|
||||
Returns:
|
||||
subdivided_faces_packed: (4*sum(F_n), 3) shape LongTensor of
|
||||
original and new faces.
|
||||
|
||||
Refer to pytorch3d.structures.meshes.py for more details on packed
|
||||
representations of faces.
|
||||
|
||||
Each face is split into 4 faces e.g. Input face
|
||||
::
|
||||
v0
|
||||
/\
|
||||
/ \
|
||||
/ \
|
||||
e1 / \ e0
|
||||
/ \
|
||||
/ \
|
||||
/ \
|
||||
/______________\
|
||||
v2 e2 v1
|
||||
|
||||
faces_packed = [[0, 1, 2]]
|
||||
faces_packed_to_edges_packed = [[2, 1, 0]]
|
||||
|
||||
`faces_packed_to_edges_packed` is used to represent all the new
|
||||
vertex indices corresponding to the mid-points of edges in the mesh.
|
||||
The actual vertex coordinates will be computed in the forward function.
|
||||
To get the indices of the new vertices, offset
|
||||
`faces_packed_to_edges_packed` by the total number of vertices.
|
||||
::
|
||||
faces_packed_to_edges_packed = [[2, 1, 0]] + 3 = [[5, 4, 3]]
|
||||
|
||||
e.g. subdivided face
|
||||
::
|
||||
v0
|
||||
/\
|
||||
/ \
|
||||
/ f0 \
|
||||
v4 /______\ v3
|
||||
/\ /\
|
||||
/ \ f3 / \
|
||||
/ f2 \ / f1 \
|
||||
/______\/______\
|
||||
v2 v5 v1
|
||||
|
||||
f0 = [0, 3, 4]
|
||||
f1 = [1, 5, 3]
|
||||
f2 = [2, 4, 5]
|
||||
f3 = [5, 4, 3]
|
||||
|
||||
"""
|
||||
verts_packed = meshes.verts_packed()
|
||||
with torch.no_grad():
|
||||
faces_packed = meshes.faces_packed()
|
||||
faces_packed_to_edges_packed = meshes.faces_packed_to_edges_packed()
|
||||
faces_packed_to_edges_packed += verts_packed.shape[0]
|
||||
|
||||
f0 = torch.stack(
|
||||
[
|
||||
faces_packed[:, 0],
|
||||
faces_packed_to_edges_packed[:, 2],
|
||||
faces_packed_to_edges_packed[:, 1],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
f1 = torch.stack(
|
||||
[
|
||||
faces_packed[:, 1],
|
||||
faces_packed_to_edges_packed[:, 0],
|
||||
faces_packed_to_edges_packed[:, 2],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
f2 = torch.stack(
|
||||
[
|
||||
faces_packed[:, 2],
|
||||
faces_packed_to_edges_packed[:, 1],
|
||||
faces_packed_to_edges_packed[:, 0],
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
f3 = faces_packed_to_edges_packed
|
||||
subdivided_faces_packed = torch.cat(
|
||||
[f0, f1, f2, f3], dim=0
|
||||
) # (4*sum(F_n), 3)
|
||||
|
||||
return subdivided_faces_packed
|
||||
|
||||
def forward(self, meshes, feats=None):
|
||||
"""
|
||||
Subdivide a batch of meshes by adding a new vertex on each edge, and
|
||||
dividing each face into four new faces. New meshes contains two types
|
||||
of vertices:
|
||||
1) Vertices that appear in the input meshes.
|
||||
Data for these vertices are copied from the input meshes.
|
||||
2) New vertices at the midpoint of each edge.
|
||||
Data for these vertices is the average of the data for the two
|
||||
vertices that make up the edge.
|
||||
|
||||
Args:
|
||||
meshes: Meshes object representing a batch of meshes.
|
||||
feats: Per-vertex features to be subdivided along with the verts.
|
||||
Should be parallel to the packed vert representation of the
|
||||
input meshes; so it should have shape (V, D) where V is the
|
||||
total number of verts in the input meshes. Default: None.
|
||||
|
||||
Returns:
|
||||
2-element tuple containing
|
||||
|
||||
- **new_meshes**: Meshes object of a batch of subdivided meshes.
|
||||
- **new_feats**: (optional) Tensor of subdivided feats, parallel to the
|
||||
(packed) vertices of the subdivided meshes. Only returned
|
||||
if feats is not None.
|
||||
|
||||
"""
|
||||
self._N = len(meshes)
|
||||
if self.precomputed:
|
||||
return self.subdivide_homogeneous(meshes, feats)
|
||||
else:
|
||||
return self.subdivide_heterogenerous(meshes, feats)
|
||||
|
||||
def subdivide_homogeneous(self, meshes, feats=None):
|
||||
"""
|
||||
Subdivide verts (and optionally features) of a batch of meshes
|
||||
where each mesh has the same topology of faces. The subdivided faces
|
||||
are precomputed in the initializer.
|
||||
|
||||
Args:
|
||||
meshes: Meshes object representing a batch of meshes.
|
||||
feats: Per-vertex features to be subdivided along with the verts.
|
||||
|
||||
Returns:
|
||||
2-element tuple containing
|
||||
|
||||
- **new_meshes**: Meshes object of a batch of subdivided meshes.
|
||||
- **new_feats**: (optional) Tensor of subdivided feats, parallel to the
|
||||
(packed) vertices of the subdivided meshes. Only returned
|
||||
if feats is not None.
|
||||
"""
|
||||
verts = meshes.verts_padded() # (N, V, D)
|
||||
edges = meshes[0].edges_packed()
|
||||
|
||||
# The set of faces is the same across the different meshes.
|
||||
new_faces = self._subdivided_faces.view(1, -1, 3).expand(
|
||||
self._N, -1, -1
|
||||
)
|
||||
|
||||
# Add one new vertex at the midpoint of each edge by taking the average
|
||||
# of the vertices that form each edge.
|
||||
new_verts = verts[:, edges].mean(dim=2)
|
||||
new_verts = torch.cat(
|
||||
[verts, new_verts], dim=1
|
||||
) # (sum(V_n)+sum(E_n), 3)
|
||||
new_feats = None
|
||||
|
||||
# Calculate features for new vertices.
|
||||
if feats is not None:
|
||||
if feats.dim() == 2:
|
||||
# feats is in packed format, transform it from packed to
|
||||
# padded, i.e. (N*V, D) to (N, V, D).
|
||||
feats = feats.view(verts.size(0), verts.size(1), feats.size(1))
|
||||
if feats.dim() != 3:
|
||||
raise ValueError(
|
||||
"features need to be of shape (N, V, D) or (N*V, D)"
|
||||
)
|
||||
|
||||
# Take average of the features at the vertices that form each edge.
|
||||
new_feats = feats[:, edges].mean(dim=2)
|
||||
new_feats = torch.cat(
|
||||
[feats, new_feats], dim=1
|
||||
) # (sum(V_n)+sum(E_n), 3)
|
||||
|
||||
new_meshes = Meshes(verts=new_verts, faces=new_faces)
|
||||
|
||||
if feats is None:
|
||||
return new_meshes
|
||||
else:
|
||||
return new_meshes, new_feats
|
||||
|
||||
def subdivide_heterogenerous(self, meshes, feats=None):
|
||||
"""
|
||||
Subdivide faces, verts (and optionally features) of a batch of meshes
|
||||
where each mesh can have different face topologies.
|
||||
|
||||
Args:
|
||||
meshes: Meshes object representing a batch of meshes.
|
||||
feats: Per-vertex features to be subdivided along with the verts.
|
||||
|
||||
Returns:
|
||||
2-element tuple containing
|
||||
|
||||
- **new_meshes**: Meshes object of a batch of subdivided meshes.
|
||||
- **new_feats**: (optional) Tensor of subdivided feats, parallel to the
|
||||
(packed) vertices of the subdivided meshes. Only returned
|
||||
if feats is not None.
|
||||
"""
|
||||
|
||||
# The computation of new faces is on face indices, so gradients do not
|
||||
# need to be tracked.
|
||||
verts = meshes.verts_packed()
|
||||
with torch.no_grad():
|
||||
new_faces = self.subdivide_faces(meshes)
|
||||
edges = meshes.edges_packed()
|
||||
face_to_mesh_idx = meshes.faces_packed_to_mesh_idx()
|
||||
edge_to_mesh_idx = meshes.edges_packed_to_mesh_idx()
|
||||
num_edges_per_mesh = edge_to_mesh_idx.bincount(minlength=self._N)
|
||||
num_verts_per_mesh = meshes.num_verts_per_mesh()
|
||||
num_faces_per_mesh = meshes.num_faces_per_mesh()
|
||||
|
||||
# Add one new vertex at the midpoint of each edge.
|
||||
new_verts_per_mesh = num_verts_per_mesh + num_edges_per_mesh # (N,)
|
||||
new_face_to_mesh_idx = torch.cat([face_to_mesh_idx] * 4, dim=0)
|
||||
|
||||
# Calculate the indices needed to group the new and existing verts
|
||||
# for each mesh.
|
||||
verts_sort_idx = create_verts_index(
|
||||
num_verts_per_mesh, num_edges_per_mesh, meshes.device
|
||||
) # (sum(V_n)+sum(E_n),)
|
||||
|
||||
verts_ordered_idx_init = torch.zeros(
|
||||
new_verts_per_mesh.sum(),
|
||||
dtype=torch.int64,
|
||||
device=meshes.device,
|
||||
) # (sum(V_n)+sum(E_n),)
|
||||
|
||||
# Reassign vertex indices so that existing and new vertices for each
|
||||
# mesh are sequential.
|
||||
verts_ordered_idx = verts_ordered_idx_init.scatter_add(
|
||||
0,
|
||||
verts_sort_idx,
|
||||
torch.arange(new_verts_per_mesh.sum(), device=meshes.device),
|
||||
)
|
||||
|
||||
# Retrieve vertex indices for each face.
|
||||
new_faces = verts_ordered_idx[new_faces]
|
||||
|
||||
# Calculate the indices needed to group the existing and new faces
|
||||
# for each mesh.
|
||||
face_sort_idx = create_faces_index(
|
||||
num_faces_per_mesh, device=meshes.device
|
||||
)
|
||||
|
||||
# Reorder the faces to sequentially group existing and new faces
|
||||
# for each mesh.
|
||||
new_faces = new_faces[face_sort_idx]
|
||||
new_face_to_mesh_idx = new_face_to_mesh_idx[face_sort_idx]
|
||||
new_faces_per_mesh = new_face_to_mesh_idx.bincount(
|
||||
minlength=self._N
|
||||
) # (sum(F_n)*4)
|
||||
|
||||
# Add one new vertex at the midpoint of each edge by taking the average
|
||||
# of the verts that form each edge.
|
||||
new_verts = verts[edges].mean(dim=1)
|
||||
new_verts = torch.cat([verts, new_verts], dim=0)
|
||||
|
||||
# Reorder the verts to sequentially group existing and new verts for
|
||||
# each mesh.
|
||||
new_verts = new_verts[verts_sort_idx]
|
||||
|
||||
if feats is not None:
|
||||
new_feats = feats[edges].mean(dim=1)
|
||||
new_feats = torch.cat([feats, new_feats], dim=0)
|
||||
new_feats = new_feats[verts_sort_idx]
|
||||
|
||||
verts_list = list(new_verts.split(new_verts_per_mesh.tolist(), 0))
|
||||
faces_list = list(new_faces.split(new_faces_per_mesh.tolist(), 0))
|
||||
new_verts_per_mesh_cumsum = torch.cat(
|
||||
[
|
||||
new_verts_per_mesh.new_full(size=(1,), fill_value=0.0),
|
||||
new_verts_per_mesh.cumsum(0)[:-1],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
faces_list = [
|
||||
faces_list[n] - new_verts_per_mesh_cumsum[n] for n in range(self._N)
|
||||
]
|
||||
if feats is not None:
|
||||
feats_list = new_feats.split(new_verts_per_mesh.tolist(), 0)
|
||||
new_meshes = Meshes(verts=verts_list, faces=faces_list)
|
||||
|
||||
if feats is None:
|
||||
return new_meshes
|
||||
else:
|
||||
new_feats = torch.cat(feats_list, dim=0)
|
||||
return new_meshes, new_feats
|
||||
|
||||
|
||||
def create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
|
||||
"""
|
||||
Helper function to group the vertex indices for each mesh. New vertices are
|
||||
stacked at the end of the original verts tensor, so in order to have
|
||||
sequential packing, the verts tensor needs to be reordered so that the
|
||||
vertices corresponding to each mesh are grouped together.
|
||||
|
||||
Args:
|
||||
verts_per_mesh: Tensor of shape (N,) giving the number of vertices
|
||||
in each mesh in the batch where N is the batch size.
|
||||
edges_per_mesh: Tensor of shape (N,) giving the number of edges
|
||||
in each mesh in the batch
|
||||
|
||||
Returns:
|
||||
verts_idx: A tensor with vert indices for each mesh ordered sequentially
|
||||
by mesh index.
|
||||
"""
|
||||
# e.g. verts_per_mesh = (4, 5, 6)
|
||||
# e.g. edges_per_mesh = (5, 7, 9)
|
||||
|
||||
V = verts_per_mesh.sum() # e.g. 15
|
||||
E = edges_per_mesh.sum() # e.g. 21
|
||||
|
||||
verts_per_mesh_cumsum = verts_per_mesh.cumsum(dim=0) # (N,) e.g. (4, 9, 15)
|
||||
edges_per_mesh_cumsum = edges_per_mesh.cumsum(
|
||||
dim=0
|
||||
) # (N,) e.g. (5, 12, 21)
|
||||
|
||||
v_to_e_idx = verts_per_mesh_cumsum.clone()
|
||||
|
||||
# vertex to edge index.
|
||||
v_to_e_idx[1:] += edges_per_mesh_cumsum[
|
||||
:-1
|
||||
] # e.g. (4, 9, 15) + (0, 5, 12) = (4, 14, 27)
|
||||
|
||||
# vertex to edge offset.
|
||||
v_to_e_offset = (
|
||||
V - verts_per_mesh_cumsum
|
||||
) # e.g. 15 - (4, 9, 15) = (11, 6, 0)
|
||||
v_to_e_offset[1:] += edges_per_mesh_cumsum[
|
||||
:-1
|
||||
] # e.g. (11, 6, 0) + (0, 5, 12) = (11, 11, 12)
|
||||
e_to_v_idx = (
|
||||
verts_per_mesh_cumsum[:-1] + edges_per_mesh_cumsum[:-1]
|
||||
) # (4, 9) + (5, 12) = (9, 21)
|
||||
e_to_v_offset = (
|
||||
verts_per_mesh_cumsum[:-1] - edges_per_mesh_cumsum[:-1] - V
|
||||
) # (4, 9) - (5, 12) - 15 = (-16, -18)
|
||||
|
||||
# Add one new vertex per edge.
|
||||
idx_diffs = torch.ones(V + E, device=device, dtype=torch.int64) # (36,)
|
||||
idx_diffs[v_to_e_idx] += v_to_e_offset
|
||||
idx_diffs[e_to_v_idx] += e_to_v_offset
|
||||
|
||||
# e.g.
|
||||
# [
|
||||
# 1, 1, 1, 1, 12, 1, 1, 1, 1,
|
||||
# -15, 1, 1, 1, 1, 12, 1, 1, 1, 1, 1, 1,
|
||||
# -17, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1, 1, 1
|
||||
# ]
|
||||
|
||||
verts_idx = idx_diffs.cumsum(dim=0) - 1
|
||||
|
||||
# e.g.
|
||||
# [
|
||||
# 0, 1, 2, 3, 15, 16, 17, 18, 19, --> mesh 0
|
||||
# 4, 5, 6, 7, 8, 20, 21, 22, 23, 24, 25, 26, --> mesh 1
|
||||
# 9, 10, 11, 12, 13, 14, 27, 28, 29, 30, 31, 32, 33, 34, 35 --> mesh 2
|
||||
# ]
|
||||
# where for mesh 0, [0, 1, 2, 3] are the indices of the existing verts, and
|
||||
# [15, 16, 17, 18, 19] are the indices of the new verts after subdivision.
|
||||
|
||||
return verts_idx
|
||||
|
||||
|
||||
def create_faces_index(faces_per_mesh, device=None):
|
||||
"""
|
||||
Helper function to group the faces indices for each mesh. New faces are
|
||||
stacked at the end of the original faces tensor, so in order to have
|
||||
sequential packing, the faces tensor needs to be reordered to that faces
|
||||
corresponding to each mesh are grouped together.
|
||||
|
||||
Args:
|
||||
faces_per_mesh: Tensor of shape (N,) giving the number of faces
|
||||
in each mesh in the batch where N is the batch size.
|
||||
|
||||
Returns:
|
||||
faces_idx: A tensor with face indices for each mesh ordered sequentially
|
||||
by mesh index.
|
||||
"""
|
||||
# e.g. faces_per_mesh = [2, 5, 3]
|
||||
|
||||
F = faces_per_mesh.sum() # e.g. 10
|
||||
faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0) # (N,) e.g. (2, 7, 10)
|
||||
|
||||
switch1_idx = faces_per_mesh_cumsum.clone()
|
||||
switch1_idx[1:] += (
|
||||
3 * faces_per_mesh_cumsum[:-1]
|
||||
) # e.g. (2, 7, 10) + (0, 6, 21) = (2, 13, 31)
|
||||
|
||||
switch2_idx = 2 * faces_per_mesh_cumsum # e.g. (4, 14, 20)
|
||||
switch2_idx[1:] += (
|
||||
2 * faces_per_mesh_cumsum[:-1]
|
||||
) # e.g. (4, 14, 20) + (0, 4, 14) = (4, 18, 34)
|
||||
|
||||
switch3_idx = 3 * faces_per_mesh_cumsum # e.g. (6, 21, 30)
|
||||
switch3_idx[1:] += faces_per_mesh_cumsum[
|
||||
:-1
|
||||
] # e.g. (6, 21, 30) + (0, 2, 7) = (6, 23, 37)
|
||||
|
||||
switch4_idx = 4 * faces_per_mesh_cumsum[:-1] # e.g. (8, 28)
|
||||
|
||||
switch123_offset = F - faces_per_mesh # e.g. (8, 5, 7)
|
||||
|
||||
idx_diffs = torch.ones(4 * F, device=device, dtype=torch.int64)
|
||||
idx_diffs[switch1_idx] += switch123_offset
|
||||
idx_diffs[switch2_idx] += switch123_offset
|
||||
idx_diffs[switch3_idx] += switch123_offset
|
||||
idx_diffs[switch4_idx] -= 3 * F
|
||||
|
||||
# e.g
|
||||
# [
|
||||
# 1, 1, 9, 1, 9, 1, 9, 1, -> mesh 0
|
||||
# -29, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, -> mesh 1
|
||||
# -29, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1 -> mesh 2
|
||||
# ]
|
||||
|
||||
faces_idx = idx_diffs.cumsum(dim=0) - 1
|
||||
|
||||
# e.g.
|
||||
# [
|
||||
# 0, 1, 10, 11, 20, 21, 30, 31,
|
||||
# 2, 3, 4, 5, 6, 12, 13, 14, 15, 16, 22, 23, 24, 25, 26, 32, 33, 34, 35, 36,
|
||||
# 7, 8, 9, 17, 18, 19, 27, 28, 29, 37, 38, 39
|
||||
# ]
|
||||
# where for mesh 0, [0, 1] are the indices of the existing faces, and
|
||||
# [10, 11, 20, 21, 30, 31] are the indices of the new faces after subdivision.
|
||||
|
||||
return faces_idx
|
||||
101
pytorch3d/ops/vert_align.py
Normal file
101
pytorch3d/ops/vert_align.py
Normal file
@@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def vert_align(
|
||||
feats,
|
||||
verts,
|
||||
return_packed: bool = False,
|
||||
interp_mode: str = "bilinear",
|
||||
padding_mode: str = "zeros",
|
||||
align_corners: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Sample vertex features from a feature map. This operation is called
|
||||
"perceptual feaure pooling" in [1] or "vert align" in [2].
|
||||
|
||||
[1] Wang et al, "Pixel2Mesh: Generating 3D Mesh Models from Single
|
||||
RGB Images", ECCV 2018.
|
||||
[2] Gkioxari et al, "Mesh R-CNN", ICCV 2019
|
||||
|
||||
Args:
|
||||
feats: FloatTensor of shape (N, C, H, W) representing image features
|
||||
from which to sample or a list of features each with potentially
|
||||
different C, H or W dimensions.
|
||||
verts: FloatTensor of shape (N, V, 3) or an object (e.g. Meshes) with
|
||||
'verts_padded' as an attribute giving the (x, y, z) vertex positions
|
||||
for which to sample. (x, y) verts should be normalized such that
|
||||
(-1, -1) corresponds to top-left and (+1, +1) to bottom-right
|
||||
location in the input feature map.
|
||||
return_packed: (bool) Indicates whether to return packed features
|
||||
interp_mode: (str) Specifies how to interpolate features.
|
||||
('bilinear' or 'nearest')
|
||||
padding_mode: (str) Specifies how to handle vertices outside of the
|
||||
[-1, 1] range. ('zeros', 'reflection', or 'border')
|
||||
align_corners (bool): Geometrically, we consider the pixels of the
|
||||
input as squares rather than points.
|
||||
If set to ``True``, the extrema (``-1`` and ``1``) are considered as
|
||||
referring to the center points of the input's corner pixels. If set
|
||||
to ``False``, they are instead considered as referring to the corner
|
||||
points of the input's corner pixels, making the sampling more
|
||||
resolution agnostic. Default: ``True``
|
||||
|
||||
Returns:
|
||||
feats_sampled: FloatTensor of shape (N, V, C) giving sampled features for
|
||||
each vertex. If feats is a list, we return concatentated
|
||||
features in axis=2 of shape (N, V, sum(C_n)) where
|
||||
C_n = feats[n].shape[1]. If return_packed = True, the
|
||||
features are transformed to a packed representation
|
||||
of shape (sum(V), C)
|
||||
|
||||
"""
|
||||
if torch.is_tensor(verts):
|
||||
if verts.dim() != 3:
|
||||
raise ValueError("verts tensor should be 3 dimensional")
|
||||
grid = verts
|
||||
elif hasattr(verts, "verts_padded"):
|
||||
grid = verts.verts_padded()
|
||||
else:
|
||||
raise ValueError(
|
||||
"verts must be a tensor or have a `verts_padded` attribute"
|
||||
)
|
||||
|
||||
grid = grid[:, None, :, :2] # (N, 1, V, 2)
|
||||
|
||||
if torch.is_tensor(feats):
|
||||
feats = [feats]
|
||||
for feat in feats:
|
||||
if feat.dim() != 4:
|
||||
raise ValueError("feats must have shape (N, C, H, W)")
|
||||
if grid.shape[0] != feat.shape[0]:
|
||||
raise ValueError("inconsistent batch dimension")
|
||||
|
||||
feats_sampled = []
|
||||
for feat in feats:
|
||||
feat_sampled = F.grid_sample(
|
||||
feat,
|
||||
grid,
|
||||
mode=interp_mode,
|
||||
padding_mode=padding_mode,
|
||||
align_corners=align_corners,
|
||||
) # (N, C, 1, V)
|
||||
feat_sampled = feat_sampled.squeeze(dim=2).transpose(1, 2) # (N, V, C)
|
||||
feats_sampled.append(feat_sampled)
|
||||
feats_sampled = torch.cat(feats_sampled, dim=2) # (N, V, sum(C))
|
||||
|
||||
if return_packed:
|
||||
# flatten the first two dimensions: (N*V, C)
|
||||
feats_sampled = feats_sampled.view(-1, feats_sampled.shape[-1])
|
||||
if hasattr(verts, "verts_padded_to_packed_idx"):
|
||||
idx = (
|
||||
verts.verts_padded_to_packed_idx()
|
||||
.view(-1, 1)
|
||||
.expand(-1, feats_sampled.shape[-1])
|
||||
)
|
||||
feats_sampled = feats_sampled.gather(0, idx) # (sum(V), C)
|
||||
|
||||
return feats_sampled
|
||||
Reference in New Issue
Block a user