mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-07-31 10:52:50 +08:00
Adding join_mesh in pytorch3d.structures.meshes
Summary: Adding a function in pytorch3d.structures.meshes to join multiple meshes into a Meshes object representing a single mesh. The function currently ignores all textures. Reviewed By: nikhilaravi Differential Revision: D21876908 fbshipit-source-id: 448602857e9d3d3f774d18bb4e93076f78329823
This commit is contained in:
parent
4b78e95eeb
commit
e053d7c456
@ -1,6 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import List
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -1539,3 +1539,28 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
|
||||
|
||||
tex = Textures(**kwargs)
|
||||
return Meshes(verts=verts, faces=faces, textures=tex)
|
||||
|
||||
|
||||
def join_mesh(meshes: Union[Meshes, List[Meshes]]) -> Meshes:
|
||||
"""
|
||||
Joins a batch of meshes in the form of a Meshes object or a list of Meshes
|
||||
objects as a single mesh. If the input is a list, the Meshes objects in the list
|
||||
must all be on the same device. This version ignores all textures in the input mehses.
|
||||
|
||||
Args:
|
||||
meshes: Meshes object that contains a batch of meshes or a list of Meshes objects
|
||||
|
||||
Returns:
|
||||
new Meshes object containing a single mesh
|
||||
"""
|
||||
if isinstance(meshes, List):
|
||||
meshes = join_meshes_as_batch(meshes, include_textures=False)
|
||||
|
||||
if len(meshes) == 1:
|
||||
return meshes
|
||||
verts = meshes.verts_packed() # (sum(V_n), 3)
|
||||
# Offset automatically done by faces_packed
|
||||
faces = meshes.faces_packed() # (sum(F_n), 3)
|
||||
|
||||
mesh = Meshes(verts=verts.unsqueeze(0), faces=faces.unsqueeze(0))
|
||||
return mesh
|
||||
|
BIN
tests/data/test_joined_spheres_flat.png
Normal file
BIN
tests/data/test_joined_spheres_flat.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 26 KiB |
BIN
tests/data/test_joined_spheres_gouraud.png
Normal file
BIN
tests/data/test_joined_spheres_gouraud.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 21 KiB |
BIN
tests/data/test_joined_spheres_phong.png
Normal file
BIN
tests/data/test_joined_spheres_phong.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 21 KiB |
@ -14,6 +14,7 @@ from pytorch3d.io.mtl_io import (
|
||||
_bilinear_interpolation_vectorized,
|
||||
)
|
||||
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
|
||||
from pytorch3d.structures.meshes import join_mesh
|
||||
from pytorch3d.utils import torus
|
||||
|
||||
|
||||
@ -648,6 +649,42 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(cow3_tea.verts_list()[3], mesh_teapot.verts_list()[0])
|
||||
self.assertClose(cow3_tea.faces_list()[3], mesh_teapot.faces_list()[0])
|
||||
|
||||
def test_join_meshes(self):
|
||||
"""
|
||||
Test that join_mesh joins single meshes and the corresponding values are
|
||||
consistent with the single meshes.
|
||||
"""
|
||||
|
||||
# Load cow mesh.
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
||||
cow_obj = DATA_DIR / "cow_mesh/cow.obj"
|
||||
|
||||
cow_mesh = load_objs_as_meshes([cow_obj])
|
||||
cow_verts, cow_faces = cow_mesh.get_mesh_verts_faces(0)
|
||||
# Join a batch of three single meshes and check that the values are consistent
|
||||
# with the individual meshes.
|
||||
cow_mesh3 = join_mesh([cow_mesh, cow_mesh, cow_mesh])
|
||||
|
||||
def check_item(x, y, offset):
|
||||
self.assertClose(torch.cat([x, x + offset, x + 2 * offset], dim=1), y)
|
||||
|
||||
check_item(cow_mesh.verts_padded(), cow_mesh3.verts_padded(), 0)
|
||||
check_item(cow_mesh.faces_padded(), cow_mesh3.faces_padded(), cow_mesh._V)
|
||||
|
||||
# Test the joining of meshes of different sizes.
|
||||
teapot_obj = DATA_DIR / "teapot.obj"
|
||||
teapot_mesh = load_objs_as_meshes([teapot_obj])
|
||||
teapot_verts, teapot_faces = teapot_mesh.get_mesh_verts_faces(0)
|
||||
|
||||
mix_mesh = join_mesh([cow_mesh, teapot_mesh])
|
||||
mix_verts, mix_faces = mix_mesh.get_mesh_verts_faces(0)
|
||||
self.assertEqual(len(mix_mesh), 1)
|
||||
|
||||
self.assertClose(mix_verts[: cow_mesh._V], cow_verts)
|
||||
self.assertClose(mix_faces[: cow_mesh._F], cow_faces)
|
||||
self.assertClose(mix_verts[cow_mesh._V :], teapot_verts)
|
||||
self.assertClose(mix_faces[cow_mesh._F :], teapot_faces + cow_mesh._V)
|
||||
|
||||
@staticmethod
|
||||
def _bm_save_obj(verts: torch.Tensor, faces: torch.Tensor, decimal_places: int):
|
||||
return lambda: save_obj(StringIO(), verts, faces, decimal_places)
|
||||
|
@ -26,7 +26,7 @@ from pytorch3d.renderer.mesh.shader import (
|
||||
TexturedSoftPhongShader,
|
||||
)
|
||||
from pytorch3d.renderer.mesh.texturing import Textures
|
||||
from pytorch3d.structures.meshes import Meshes
|
||||
from pytorch3d.structures.meshes import Meshes, join_mesh
|
||||
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||
|
||||
|
||||
@ -176,7 +176,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
# Init renderer
|
||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
shaders = {
|
||||
"phong": HardGouraudShader,
|
||||
"phong": HardPhongShader,
|
||||
"gouraud": HardGouraudShader,
|
||||
"flat": HardFlatShader,
|
||||
}
|
||||
@ -369,3 +369,70 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
def test_joined_spheres(self):
|
||||
"""
|
||||
Test a list of Meshes can be joined as a single mesh and
|
||||
the single mesh is rendered correctly with Phong, Gouraud
|
||||
and Flat Shaders.
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
# Init mesh with vertex textures.
|
||||
# Initialize a list containing two ico spheres of different sizes.
|
||||
sphere_list = [ico_sphere(3, device), ico_sphere(4, device)]
|
||||
# [(42 verts, 80 faces), (162 verts, 320 faces)]
|
||||
# The scale the vertices need to be set at to resize the spheres
|
||||
scales = [0.25, 1]
|
||||
# The distance the spheres ought to be offset horizontally to prevent overlap.
|
||||
offsets = [1.2, -0.3]
|
||||
# Initialize a list containing the adjusted sphere meshes.
|
||||
sphere_mesh_list = []
|
||||
for i in range(len(sphere_list)):
|
||||
verts = sphere_list[i].verts_padded() * scales[i]
|
||||
verts[0, :, 0] += offsets[i]
|
||||
sphere_mesh_list.append(
|
||||
Meshes(verts=verts, faces=sphere_list[i].faces_padded())
|
||||
)
|
||||
joined_sphere_mesh = join_mesh(sphere_mesh_list)
|
||||
joined_sphere_mesh.textures = Textures(
|
||||
verts_rgb=torch.ones_like(joined_sphere_mesh.verts_padded())
|
||||
)
|
||||
|
||||
# Init rasterizer settings
|
||||
R, T = look_at_view_transform(2.7, 0.0, 0.0)
|
||||
cameras = OpenGLPerspectiveCameras(device=device, R=R, T=T)
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1
|
||||
)
|
||||
|
||||
# Init shader settings
|
||||
materials = Materials(device=device)
|
||||
lights = PointLights(device=device)
|
||||
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
||||
blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
|
||||
|
||||
# Init renderer
|
||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
shaders = {
|
||||
"phong": HardPhongShader,
|
||||
"gouraud": HardGouraudShader,
|
||||
"flat": HardFlatShader,
|
||||
}
|
||||
for (name, shader_init) in shaders.items():
|
||||
shader = shader_init(
|
||||
lights=lights,
|
||||
cameras=cameras,
|
||||
materials=materials,
|
||||
blend_params=blend_params,
|
||||
)
|
||||
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||
image = renderer(joined_sphere_mesh)
|
||||
rgb = image[..., :3].squeeze().cpu()
|
||||
if DEBUG:
|
||||
file_name = "DEBUG_joined_spheres_%s.png" % name
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / file_name
|
||||
)
|
||||
image_ref = load_rgb_image("test_joined_spheres_%s.png" % name, DATA_DIR)
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
Loading…
x
Reference in New Issue
Block a user