Support color in cubify

Summary: The diff support colors in cubify for align = "center"

Reviewed By: bottler

Differential Revision: D53777011

fbshipit-source-id: ccb2bd1e3d89be3d1ac943eff08f40e50b0540d9
This commit is contained in:
Cijo Jose
2024-02-16 08:19:12 -08:00
committed by Facebook GitHub Bot
parent 8772fe0de8
commit ae9d8787ce
2 changed files with 73 additions and 3 deletions

View File

@@ -8,6 +8,7 @@ import unittest
import torch
from pytorch3d.ops import cubify
from pytorch3d.renderer.mesh.textures import TexturesAtlas
from .common_testing import TestCaseMixin
@@ -313,3 +314,42 @@ class TestCubify(TestCaseMixin, unittest.TestCase):
torch.cuda.synchronize()
return convert
def test_cubify_with_feats(self):
N, V = 3, 2
device = torch.device("cuda:0")
voxels = torch.zeros((N, V, V, V), dtype=torch.float32, device=device)
feats = torch.zeros((N, 3, V, V, V), dtype=torch.float32, device=device)
# fill the feats with red color
feats[:, 0, :, :, :] = 255
# 1st example: (top left corner, znear) is on
voxels[0, 0, 0, 0] = 1.0
# the color is set to green
feats[0, :, 0, 0, 0] = torch.Tensor([0, 255, 0])
# 2nd example: all are on
voxels[1] = 1.0
# 3rd example
voxels[2, :, :, 1] = 1.0
voxels[2, 1, 1, 0] = 1.0
# the color is set to yellow and blue respectively
feats[2, 1, :, :, 1] = 255
feats[2, :, 1, 1, 0] = torch.Tensor([0, 0, 255])
meshes = cubify(voxels, 0.5, feats=feats, align="center")
textures = meshes.textures
self.assertTrue(textures is not None)
self.assertTrue(isinstance(textures, TexturesAtlas))
faces_textures = textures.faces_verts_textures_packed()
red = faces_textures.new_tensor([255.0, 0.0, 0.0])
green = faces_textures.new_tensor([0.0, 255.0, 0.0])
blue = faces_textures.new_tensor([0.0, 0.0, 255.0])
yellow = faces_textures.new_tensor([255.0, 255.0, 0.0])
self.assertEqual(faces_textures.shape, (100, 3, 3))
faces_textures_ = faces_textures.flatten(end_dim=1)
self.assertClose(faces_textures_[:36], green.expand(36, -1))
self.assertClose(faces_textures_[36:180], red.expand(144, -1))
self.assertClose(faces_textures_[180:228], yellow.expand(48, -1))
self.assertClose(faces_textures_[228:258], blue.expand(30, -1))
self.assertClose(faces_textures_[258:300], yellow.expand(42, -1))