diff --git a/pytorch3d/ops/cubify.py b/pytorch3d/ops/cubify.py index 77bc924a..364e6226 100644 --- a/pytorch3d/ops/cubify.py +++ b/pytorch3d/ops/cubify.py @@ -5,9 +5,13 @@ # LICENSE file in the root directory of this source tree. +from typing import Optional + import torch import torch.nn.functional as F + from pytorch3d.common.compat import meshgrid_ij + from pytorch3d.structures import Meshes @@ -50,7 +54,14 @@ def ravel_index(idx, dims) -> torch.Tensor: @torch.no_grad() -def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes: +def cubify( + voxels: torch.Tensor, + thresh: float, + *, + feats: Optional[torch.Tensor] = None, + device=None, + align: str = "topleft" +) -> Meshes: r""" Converts a voxel to a mesh by replacing each occupied voxel with a cube consisting of 12 faces and 8 vertices. Shared vertices are merged, and @@ -59,6 +70,9 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes: voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities. thresh: A scalar threshold. If a voxel occupancy is larger than thresh, the voxel is considered occupied. + feats: A FloatTensor of shape (N, K, D, H, W) containing the color information + of each voxel. K is the number of channels. This is supported only when + align == "center" device: The device of the output meshes align: Defines the alignment of the mesh vertices and the grid locations. Has to be one of {"topleft", "corner", "center"}. See below for explanation. @@ -177,6 +191,7 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes: # boolean to linear index # NF x 2 linind = torch.nonzero(faces_idx, as_tuple=False) + # NF x 4 nyxz = unravel_index(linind[:, 0], (N, H, W, D)) @@ -238,6 +253,21 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes: grid_verts.index_select(0, (idleverts[n] == 0).nonzero(as_tuple=False)[:, 0]) for n in range(N) ] - faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)] - return Meshes(verts=verts_list, faces=faces_list) + textures_list = None + if feats is not None and align == "center": + # We return a TexturesAtlas containing one color for each face + # N x K x D x H x W -> N x H x W x D x K + feats = feats.permute(0, 3, 4, 2, 1) + + # (NHWD) x K + feats = feats.reshape(-1, feats.size(4)) + feats = torch.index_select(feats, 0, linind[:, 0]) + feats = feats.reshape(-1, 1, 1, feats.size(1)) + feats_list = list(torch.split(feats, split_size.tolist(), 0)) + from pytorch3d.renderer.mesh.textures import TexturesAtlas + + textures_list = TexturesAtlas(feats_list) + + faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)] + return Meshes(verts=verts_list, faces=faces_list, textures=textures_list) diff --git a/tests/test_cubify.py b/tests/test_cubify.py index 1c82919b..c38b0ead 100644 --- a/tests/test_cubify.py +++ b/tests/test_cubify.py @@ -8,6 +8,7 @@ import unittest import torch from pytorch3d.ops import cubify +from pytorch3d.renderer.mesh.textures import TexturesAtlas from .common_testing import TestCaseMixin @@ -313,3 +314,42 @@ class TestCubify(TestCaseMixin, unittest.TestCase): torch.cuda.synchronize() return convert + + def test_cubify_with_feats(self): + N, V = 3, 2 + device = torch.device("cuda:0") + voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device) + feats = torch.zeros((N, 3, V, V, V), dtype=torch.float32, device=device) + # fill the feats with red color + feats[:, 0, :, :, :] = 255 + + # 1st example: (top left corner, znear) is on + voxels[0, 0, 0, 0] = 1.0 + # the color is set to green + feats[0, :, 0, 0, 0] = torch.Tensor([0, 255, 0]) + # 2nd example: all are on + voxels[1] = 1.0 + + # 3rd example + voxels[2, :, :, 1] = 1.0 + voxels[2, 1, 1, 0] = 1.0 + # the color is set to yellow and blue respectively + feats[2, 1, :, :, 1] = 255 + feats[2, :, 1, 1, 0] = torch.Tensor([0, 0, 255]) + meshes = cubify(voxels, 0.5, feats=feats, align="center") + textures = meshes.textures + self.assertTrue(textures is not None) + self.assertTrue(isinstance(textures, TexturesAtlas)) + faces_textures = textures.faces_verts_textures_packed() + red = faces_textures.new_tensor([255.0, 0.0, 0.0]) + green = faces_textures.new_tensor([0.0, 255.0, 0.0]) + blue = faces_textures.new_tensor([0.0, 0.0, 255.0]) + yellow = faces_textures.new_tensor([255.0, 255.0, 0.0]) + + self.assertEqual(faces_textures.shape, (100, 3, 3)) + faces_textures_ = faces_textures.flatten(end_dim=1) + self.assertClose(faces_textures_[:36], green.expand(36, -1)) + self.assertClose(faces_textures_[36:180], red.expand(144, -1)) + self.assertClose(faces_textures_[180:228], yellow.expand(48, -1)) + self.assertClose(faces_textures_[228:258], blue.expand(30, -1)) + self.assertClose(faces_textures_[258:300], yellow.expand(42, -1))