mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 09:52:11 +08:00 
			
		
		
		
	Support color in cubify
Summary: The diff support colors in cubify for align = "center" Reviewed By: bottler Differential Revision: D53777011 fbshipit-source-id: ccb2bd1e3d89be3d1ac943eff08f40e50b0540d9
This commit is contained in:
		
							parent
							
								
									8772fe0de8
								
							
						
					
					
						commit
						ae9d8787ce
					
				@ -5,9 +5,13 @@
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
import torch.nn.functional as F
 | 
			
		||||
 | 
			
		||||
from pytorch3d.common.compat import meshgrid_ij
 | 
			
		||||
 | 
			
		||||
from pytorch3d.structures import Meshes
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -50,7 +54,14 @@ def ravel_index(idx, dims) -> torch.Tensor:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@torch.no_grad()
 | 
			
		||||
def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
 | 
			
		||||
def cubify(
 | 
			
		||||
    voxels: torch.Tensor,
 | 
			
		||||
    thresh: float,
 | 
			
		||||
    *,
 | 
			
		||||
    feats: Optional[torch.Tensor] = None,
 | 
			
		||||
    device=None,
 | 
			
		||||
    align: str = "topleft"
 | 
			
		||||
) -> Meshes:
 | 
			
		||||
    r"""
 | 
			
		||||
    Converts a voxel to a mesh by replacing each occupied voxel with a cube
 | 
			
		||||
    consisting of 12 faces and 8 vertices. Shared vertices are merged, and
 | 
			
		||||
@ -59,6 +70,9 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
 | 
			
		||||
      voxels: A FloatTensor of shape (N, D, H, W) containing occupancy probabilities.
 | 
			
		||||
      thresh: A scalar threshold. If a voxel occupancy is larger than
 | 
			
		||||
          thresh, the voxel is considered occupied.
 | 
			
		||||
      feats: A FloatTensor of shape (N, K, D, H, W) containing the color information
 | 
			
		||||
          of each voxel. K is the number of channels. This is supported only when
 | 
			
		||||
          align == "center"
 | 
			
		||||
      device: The device of the output meshes
 | 
			
		||||
      align: Defines the alignment of the mesh vertices and the grid locations.
 | 
			
		||||
          Has to be one of {"topleft", "corner", "center"}. See below for explanation.
 | 
			
		||||
@ -177,6 +191,7 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
 | 
			
		||||
    # boolean to linear index
 | 
			
		||||
    # NF x 2
 | 
			
		||||
    linind = torch.nonzero(faces_idx, as_tuple=False)
 | 
			
		||||
 | 
			
		||||
    # NF x 4
 | 
			
		||||
    nyxz = unravel_index(linind[:, 0], (N, H, W, D))
 | 
			
		||||
 | 
			
		||||
@ -238,6 +253,21 @@ def cubify(voxels, thresh, device=None, align: str = "topleft") -> Meshes:
 | 
			
		||||
        grid_verts.index_select(0, (idleverts[n] == 0).nonzero(as_tuple=False)[:, 0])
 | 
			
		||||
        for n in range(N)
 | 
			
		||||
    ]
 | 
			
		||||
    faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]
 | 
			
		||||
 | 
			
		||||
    return Meshes(verts=verts_list, faces=faces_list)
 | 
			
		||||
    textures_list = None
 | 
			
		||||
    if feats is not None and align == "center":
 | 
			
		||||
        # We return a TexturesAtlas containing one color for each face
 | 
			
		||||
        # N x K x D x H x W  -> N x H x W x D x K
 | 
			
		||||
        feats = feats.permute(0, 3, 4, 2, 1)
 | 
			
		||||
 | 
			
		||||
        # (NHWD) x K
 | 
			
		||||
        feats = feats.reshape(-1, feats.size(4))
 | 
			
		||||
        feats = torch.index_select(feats, 0, linind[:, 0])
 | 
			
		||||
        feats = feats.reshape(-1, 1, 1, feats.size(1))
 | 
			
		||||
        feats_list = list(torch.split(feats, split_size.tolist(), 0))
 | 
			
		||||
        from pytorch3d.renderer.mesh.textures import TexturesAtlas
 | 
			
		||||
 | 
			
		||||
        textures_list = TexturesAtlas(feats_list)
 | 
			
		||||
 | 
			
		||||
    faces_list = [nface - idlenum[n][nface] for n, nface in enumerate(faces_list)]
 | 
			
		||||
    return Meshes(verts=verts_list, faces=faces_list, textures=textures_list)
 | 
			
		||||
 | 
			
		||||
@ -8,6 +8,7 @@ import unittest
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.ops import cubify
 | 
			
		||||
from pytorch3d.renderer.mesh.textures import TexturesAtlas
 | 
			
		||||
 | 
			
		||||
from .common_testing import TestCaseMixin
 | 
			
		||||
 | 
			
		||||
@ -313,3 +314,42 @@ class TestCubify(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            torch.cuda.synchronize()
 | 
			
		||||
 | 
			
		||||
        return convert
 | 
			
		||||
 | 
			
		||||
    def test_cubify_with_feats(self):
 | 
			
		||||
        N, V = 3, 2
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device)
 | 
			
		||||
        feats = torch.zeros((N, 3, V, V, V), dtype=torch.float32, device=device)
 | 
			
		||||
        # fill the feats with red color
 | 
			
		||||
        feats[:, 0, :, :, :] = 255
 | 
			
		||||
 | 
			
		||||
        # 1st example: (top left corner, znear) is on
 | 
			
		||||
        voxels[0, 0, 0, 0] = 1.0
 | 
			
		||||
        # the color is set to green
 | 
			
		||||
        feats[0, :, 0, 0, 0] = torch.Tensor([0, 255, 0])
 | 
			
		||||
        # 2nd example: all are on
 | 
			
		||||
        voxels[1] = 1.0
 | 
			
		||||
 | 
			
		||||
        # 3rd example
 | 
			
		||||
        voxels[2, :, :, 1] = 1.0
 | 
			
		||||
        voxels[2, 1, 1, 0] = 1.0
 | 
			
		||||
        # the color is set to yellow and blue respectively
 | 
			
		||||
        feats[2, 1, :, :, 1] = 255
 | 
			
		||||
        feats[2, :, 1, 1, 0] = torch.Tensor([0, 0, 255])
 | 
			
		||||
        meshes = cubify(voxels, 0.5, feats=feats, align="center")
 | 
			
		||||
        textures = meshes.textures
 | 
			
		||||
        self.assertTrue(textures is not None)
 | 
			
		||||
        self.assertTrue(isinstance(textures, TexturesAtlas))
 | 
			
		||||
        faces_textures = textures.faces_verts_textures_packed()
 | 
			
		||||
        red = faces_textures.new_tensor([255.0, 0.0, 0.0])
 | 
			
		||||
        green = faces_textures.new_tensor([0.0, 255.0, 0.0])
 | 
			
		||||
        blue = faces_textures.new_tensor([0.0, 0.0, 255.0])
 | 
			
		||||
        yellow = faces_textures.new_tensor([255.0, 255.0, 0.0])
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(faces_textures.shape, (100, 3, 3))
 | 
			
		||||
        faces_textures_ = faces_textures.flatten(end_dim=1)
 | 
			
		||||
        self.assertClose(faces_textures_[:36], green.expand(36, -1))
 | 
			
		||||
        self.assertClose(faces_textures_[36:180], red.expand(144, -1))
 | 
			
		||||
        self.assertClose(faces_textures_[180:228], yellow.expand(48, -1))
 | 
			
		||||
        self.assertClose(faces_textures_[228:258], blue.expand(30, -1))
 | 
			
		||||
        self.assertClose(faces_textures_[258:300], yellow.expand(42, -1))
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user