mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
Single function to load meshes from OBJs. join_meshes.
Summary: Create the textures and the Meshes object from OBJ files in a single call. There is functionality in OBJ files (like normals) which is ignored by this function. Reviewed By: gkioxari Differential Revision: D19691699 fbshipit-source-id: e26442ed80ff231b65b17d6c54c9d41e22b4e4a3
This commit is contained in:
committed by
Facebook Github Bot
parent
23bb27956a
commit
8fe65d5f56
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
|
||||
from .obj_io import load_obj, save_obj
|
||||
from .obj_io import load_obj, load_objs_as_meshes, save_obj
|
||||
from .ply_io import load_ply, save_ply
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
||||
@@ -13,6 +13,8 @@ import torch
|
||||
from fvcore.common.file_io import PathManager
|
||||
from PIL import Image
|
||||
|
||||
from pytorch3d.structures import Meshes, Textures, join_meshes
|
||||
|
||||
|
||||
def _read_image(file_name: str, format=None):
|
||||
"""
|
||||
@@ -90,7 +92,7 @@ def _open_file(f):
|
||||
|
||||
def load_obj(f_obj, load_textures=True):
|
||||
"""
|
||||
Load a mesh and textures from a .obj and .mtl file.
|
||||
Load a mesh from a .obj file and optionally textures from a .mtl file.
|
||||
Currently this handles verts, faces, vertex texture uv coordinates, normals,
|
||||
texture images and material reflectivity values.
|
||||
|
||||
@@ -208,6 +210,44 @@ def load_obj(f_obj, load_textures=True):
|
||||
f_obj.close()
|
||||
|
||||
|
||||
def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
|
||||
"""
|
||||
Load meshes from a list of .obj files using the load_obj function, and
|
||||
return them as a Meshes object. This only works for meshes which have a
|
||||
single texture image for the whole mesh. See the load_obj function for more
|
||||
details. material_colors and normals are not stored.
|
||||
|
||||
Args:
|
||||
f: A list of file-like objects (with methods read, readline, tell,
|
||||
and seek), pathlib paths or strings containing file names.
|
||||
device: Desired device of returned Meshes. Default:
|
||||
uses the current device for the default tensor type.
|
||||
load_textures: Boolean indicating whether material files are loaded
|
||||
|
||||
Returns:
|
||||
New Meshes object.
|
||||
"""
|
||||
mesh_list = []
|
||||
for f_obj in files:
|
||||
verts, faces, aux = load_obj(f_obj, load_textures=load_textures)
|
||||
verts = verts.to(device)
|
||||
tex = None
|
||||
tex_maps = aux.texture_images
|
||||
if tex_maps is not None and len(tex_maps) > 0:
|
||||
verts_uvs = aux.verts_uvs[None, ...].to(device) # (1, V, 2)
|
||||
faces_uvs = faces.textures_idx[None, ...].to(device) # (1, F, 3)
|
||||
image = list(tex_maps.values())[0].to(device)[None]
|
||||
tex = Textures(verts_uvs=verts_uvs, faces_uvs=faces_uvs, maps=image)
|
||||
|
||||
mesh = Meshes(
|
||||
verts=[verts], faces=[faces.verts_idx.to(device)], textures=tex
|
||||
)
|
||||
mesh_list.append(mesh)
|
||||
if len(mesh_list) == 1:
|
||||
return mesh_list[0]
|
||||
return join_meshes(mesh_list)
|
||||
|
||||
|
||||
def _parse_face(
|
||||
line,
|
||||
material_idx,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from .meshes import Meshes
|
||||
from .meshes import Meshes, join_meshes
|
||||
from .textures import Textures
|
||||
from .utils import (
|
||||
list_to_packed,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from typing import List
|
||||
import torch
|
||||
|
||||
from pytorch3d import _C
|
||||
@@ -1365,3 +1366,77 @@ class Meshes(object):
|
||||
if self.textures is not None:
|
||||
tex = self.textures.extend(N)
|
||||
return Meshes(verts=new_verts_list, faces=new_faces_list, textures=tex)
|
||||
|
||||
|
||||
def join_meshes(meshes: List[Meshes], include_textures: bool = True):
|
||||
"""
|
||||
Merge multiple Meshes objects, i.e. concatenate the meshes objects. They
|
||||
must all be on the same device. If include_textures is true, they must all
|
||||
be compatible, either all or none having textures, and all the Textures
|
||||
objects having the same members. If include_textures is False, textures are
|
||||
ignored.
|
||||
|
||||
Args:
|
||||
meshes: list of meshes.
|
||||
include_textures: (bool) whether to try to join the textures.
|
||||
|
||||
Returns:
|
||||
new Meshes object containing all the meshes from all the inputs.
|
||||
"""
|
||||
if isinstance(meshes, Meshes):
|
||||
# Meshes objects can be iterated and produce single Meshes. We avoid
|
||||
# letting join_meshes(mesh1, mesh2) silently do the wrong thing.
|
||||
raise ValueError("Wrong first argument to join_meshes.")
|
||||
verts = [v for mesh in meshes for v in mesh.verts_list()]
|
||||
faces = [f for mesh in meshes for f in mesh.faces_list()]
|
||||
if len(meshes) == 0 or not include_textures:
|
||||
return Meshes(verts=verts, faces=faces)
|
||||
|
||||
if meshes[0].textures is None:
|
||||
if any(mesh.textures is not None for mesh in meshes):
|
||||
raise ValueError("Inconsistent textures in join_meshes.")
|
||||
return Meshes(verts=verts, faces=faces)
|
||||
|
||||
if any(mesh.textures is None for mesh in meshes):
|
||||
raise ValueError("Inconsistent textures in join_meshes.")
|
||||
|
||||
# Now we know there are multiple meshes and they have textures to merge.
|
||||
first = meshes[0].textures
|
||||
kwargs = {}
|
||||
if first.maps_padded() is not None:
|
||||
if any(mesh.textures.maps_padded() is None for mesh in meshes):
|
||||
raise ValueError("Inconsistent maps_padded in join_meshes.")
|
||||
maps = [m for mesh in meshes for m in mesh.textures.maps_padded()]
|
||||
kwargs["maps"] = maps
|
||||
elif any(mesh.textures.maps_padded() is not None for mesh in meshes):
|
||||
raise ValueError("Inconsistent maps_padded in join_meshes.")
|
||||
|
||||
if first.verts_uvs_padded() is not None:
|
||||
if any(mesh.textures.verts_uvs_padded() is None for mesh in meshes):
|
||||
raise ValueError("Inconsistent verts_uvs_padded in join_meshes.")
|
||||
uvs = [uv for mesh in meshes for uv in mesh.textures.verts_uvs_list()]
|
||||
V = max(uv.shape[0] for uv in uvs)
|
||||
kwargs["verts_uvs"] = struct_utils.list_to_padded(uvs, (V, 2), -1)
|
||||
elif any(mesh.textures.verts_uvs_padded() is not None for mesh in meshes):
|
||||
raise ValueError("Inconsistent verts_uvs_padded in join_meshes.")
|
||||
|
||||
if first.faces_uvs_padded() is not None:
|
||||
if any(mesh.textures.faces_uvs_padded() is None for mesh in meshes):
|
||||
raise ValueError("Inconsistent faces_uvs_padded in join_meshes.")
|
||||
uvs = [uv for mesh in meshes for uv in mesh.textures.faces_uvs_list()]
|
||||
F = max(uv.shape[0] for uv in uvs)
|
||||
kwargs["faces_uvs"] = struct_utils.list_to_padded(uvs, (F, 3), -1)
|
||||
elif any(mesh.textures.faces_uvs_padded() is not None for mesh in meshes):
|
||||
raise ValueError("Inconsistent faces_uvs_padded in join_meshes.")
|
||||
|
||||
if first.verts_rgb_padded() is not None:
|
||||
if any(mesh.textures.verts_rgb_padded() is None for mesh in meshes):
|
||||
raise ValueError("Inconsistent verts_rgb_padded in join_meshes.")
|
||||
rgb = [i for mesh in meshes for i in mesh.textures.verts_rgb_list()]
|
||||
V = max(i.shape[0] for i in rgb)
|
||||
kwargs["verts_rgb"] = struct_utils.list_to_padded(rgb, (V, 3))
|
||||
elif any(mesh.textures.verts_rgb_padded() is not None for mesh in meshes):
|
||||
raise ValueError("Inconsistent verts_rgb_padded in join_meshes.")
|
||||
|
||||
tex = Textures(**kwargs)
|
||||
return Meshes(verts=verts, faces=faces, textures=tex)
|
||||
|
||||
Reference in New Issue
Block a user