mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-22 07:10:34 +08:00
Support color in cubify
Summary: The diff support colors in cubify for align = "center" Reviewed By: bottler Differential Revision: D53777011 fbshipit-source-id: ccb2bd1e3d89be3d1ac943eff08f40e50b0540d9
This commit is contained in:
committed by
Facebook GitHub Bot
parent
8772fe0de8
commit
ae9d8787ce
@@ -5,9 +5,13 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from pytorch3d.common.compat import meshgrid_ij
|
||||
|
||||
from pytorch3d.structures import Meshes
|
||||
|
||||
|
||||
@@ -50,7 +54,14 @@ def ravel_index(idx, dims) -> torch.Tensor:
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
|
||||
def cubify(
|
||||
voxels: torch.Tensor,
|
||||
thresh: float,
|
||||
*,
|
||||
feats: Optional[torch.Tensor] = None,
|
||||
device=None,
|
||||
align: str = "topleft"
|
||||
) -> Meshes:
|
||||
r"""
|
||||
Converts a voxel to a mesh by replacing each occupied voxel with a cube
|
||||
consisting of 12 faces and 8 vertices. Shared vertices are merged, and
|
||||
@@ -59,6 +70,9 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
|
||||
voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
|
||||
thresh: A scalar threshold. If a voxel occupancy is larger than
|
||||
thresh, the voxel is considered occupied.
|
||||
feats: A FloatTensor of shape (N, K, D, H, W) containing the color information
|
||||
of each voxel. K is the number of channels. This is supported only when
|
||||
align == "center"
|
||||
device: The device of the output meshes
|
||||
align: Defines the alignment of the mesh vertices and the grid locations.
|
||||
Has to be one of {"topleft", "corner", "center"}. See below for explanation.
|
||||
@@ -177,6 +191,7 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
|
||||
# boolean to linear index
|
||||
# NF x 2
|
||||
linind = torch.nonzero(faces_idx, as_tuple=False)
|
||||
|
||||
# NF x 4
|
||||
nyxz = unravel_index(linind[:, 0], (N, H, W, D))
|
||||
|
||||
@@ -238,6 +253,21 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
|
||||
grid_verts.index_select(0, (idleverts[n] == 0).nonzero(as_tuple=False)[:, 0])
|
||||
for n in range(N)
|
||||
]
|
||||
faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]
|
||||
|
||||
return Meshes(verts=verts_list, faces=faces_list)
|
||||
textures_list = None
|
||||
if feats is not None and align == "center":
|
||||
# We return a TexturesAtlas containing one color for each face
|
||||
# N x K x D x H x W -> N x H x W x D x K
|
||||
feats = feats.permute(0, 3, 4, 2, 1)
|
||||
|
||||
# (NHWD) x K
|
||||
feats = feats.reshape(-1, feats.size(4))
|
||||
feats = torch.index_select(feats, 0, linind[:, 0])
|
||||
feats = feats.reshape(-1, 1, 1, feats.size(1))
|
||||
feats_list = list(torch.split(feats, split_size.tolist(), 0))
|
||||
from pytorch3d.renderer.mesh.textures import TexturesAtlas
|
||||
|
||||
textures_list = TexturesAtlas(feats_list)
|
||||
|
||||
faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]
|
||||
return Meshes(verts=verts_list, faces=faces_list, textures=textures_list)
|
||||
|
||||
Reference in New Issue
Block a user