mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Fix texture atlas for objs which only have material properties
Summary: Fix for GitHub issue #381. The example mesh provided in the issue only had material properties but no texture image. The current implementation of texture atlassing generated an atlas using both the material properties and the texture image but only worked if there was a texture image and associated vertex uv coordinates. I have now modified the texture atlas creation so that it doesn't require an image and can work with materials which only have material properties. Reviewed By: gkioxari Differential Revision: D24153068 fbshipit-source-id: 63e9d325db09a84b336b83369d5342ce588a9932
This commit is contained in:
parent
5d65a0cf8c
commit
f5383a7e5a
@ -15,7 +15,8 @@ def make_mesh_texture_atlas(
|
||||
material_properties: Dict,
|
||||
texture_images: Dict,
|
||||
face_material_names,
|
||||
faces_verts_uvs: torch.Tensor,
|
||||
faces_uvs: torch.Tensor,
|
||||
verts_uvs: torch.Tensor,
|
||||
texture_size: int,
|
||||
texture_wrap: Optional[str],
|
||||
) -> torch.Tensor:
|
||||
@ -31,8 +32,9 @@ def make_mesh_texture_atlas(
|
||||
face_material_names: numpy array of the material name corresponding to each
|
||||
face. Faces which don't have an associated material will be an empty string.
|
||||
For these faces, a uniform white texture is assigned.
|
||||
faces_verts_uvs: LongTensor of shape (F, 3, 2) giving the uv coordinates for each
|
||||
vertex in the face.
|
||||
faces_uvs: LongTensor of shape (F, 3,) giving the index into the verts_uvs for
|
||||
each face in the mesh.
|
||||
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinates for each vertex.
|
||||
texture_size: the resolution of the per face texture map returned by this function.
|
||||
Each face will have a texture map of shape (texture_size, texture_size, 3).
|
||||
texture_wrap: string, one of ["repeat", "clamp", None]
|
||||
@ -47,50 +49,55 @@ def make_mesh_texture_atlas(
|
||||
"""
|
||||
# Create an R x R texture map per face in the mesh
|
||||
R = texture_size
|
||||
F = faces_verts_uvs.shape[0]
|
||||
F = faces_uvs.shape[0]
|
||||
|
||||
# Initialize the per face texture map to a white color.
|
||||
# TODO: allow customization of this base color?
|
||||
# pyre-fixme[16]: `Tensor` has no attribute `new_ones`.
|
||||
atlas = faces_verts_uvs.new_ones(size=(F, R, R, 3))
|
||||
atlas = torch.ones(size=(F, R, R, 3), dtype=torch.float32, device=faces_uvs.device)
|
||||
|
||||
# Check for empty materials.
|
||||
if not material_properties and not texture_images:
|
||||
return atlas
|
||||
|
||||
# Iterate through the material properties - not
|
||||
# all materials have texture images so this is
|
||||
# done first separately to the texture interpolation.
|
||||
for material_name, props in material_properties.items():
|
||||
# Bool to indicate which faces use this texture map.
|
||||
faces_material_ind = torch.from_numpy(face_material_names == material_name).to(
|
||||
faces_uvs.device
|
||||
)
|
||||
if faces_material_ind.sum() > 0:
|
||||
# For these faces, update the base color to the
|
||||
# diffuse material color.
|
||||
if "diffuse_color" not in props:
|
||||
continue
|
||||
atlas[faces_material_ind, ...] = props["diffuse_color"][None, :]
|
||||
|
||||
# If there are vertex texture coordinates, create an (F, 3, 2)
|
||||
# tensor of the vertex textures per face.
|
||||
faces_verts_uvs = verts_uvs[faces_uvs] if len(verts_uvs) > 0 else None
|
||||
|
||||
# Some meshes only have material properties and no texture image.
|
||||
# In this case, return the atlas here.
|
||||
if faces_verts_uvs is None:
|
||||
return atlas
|
||||
|
||||
if texture_wrap == "repeat":
|
||||
# If texture uv coordinates are outside the range [0, 1] follow
|
||||
# the convention GL_REPEAT in OpenGL i.e the integer part of the coordinate
|
||||
# will be ignored and a repeating pattern is formed.
|
||||
# Shapenet data uses this format see:
|
||||
# https://shapenet.org/qaforum/index.php?qa=15&qa_1=why-is-the-texture-coordinate-in-the-obj-file-not-in-the-range # noqa: B950
|
||||
# pyre-fixme[16]: `ByteTensor` has no attribute `any`.
|
||||
if (faces_verts_uvs > 1).any() or (faces_verts_uvs < 0).any():
|
||||
msg = "Texture UV coordinates outside the range [0, 1]. \
|
||||
The integer part will be ignored to form a repeating pattern."
|
||||
warnings.warn(msg)
|
||||
# pyre-fixme[9]: faces_verts_uvs has type `Tensor`; used as `int`.
|
||||
# pyre-fixme[58]: `%` is not supported for operand types `Tensor` and `int`.
|
||||
faces_verts_uvs = faces_verts_uvs % 1
|
||||
elif texture_wrap == "clamp":
|
||||
# Clamp uv coordinates to the [0, 1] range.
|
||||
faces_verts_uvs = faces_verts_uvs.clamp(0.0, 1.0)
|
||||
|
||||
# Iterate through the material properties - not
|
||||
# all materials have texture images so this has to be
|
||||
# done separately to the texture interpolation.
|
||||
for material_name, props in material_properties.items():
|
||||
# Bool to indicate which faces use this texture map.
|
||||
faces_material_ind = torch.from_numpy(face_material_names == material_name).to(
|
||||
faces_verts_uvs.device
|
||||
)
|
||||
if faces_material_ind.sum() > 0:
|
||||
# For these faces, update the base color to the
|
||||
# diffuse material color.
|
||||
if "diffuse_color" not in props:
|
||||
continue
|
||||
atlas[faces_material_ind, ...] = props["diffuse_color"][None, :]
|
||||
|
||||
# Iterate through the materials used in this mesh. Update the
|
||||
# texture atlas for the faces which use this material.
|
||||
# Faces without texture are white.
|
||||
|
@ -533,19 +533,16 @@ def _load_obj(
|
||||
face_material_names = np.array(material_names)[idx] # (F,)
|
||||
face_material_names[idx == -1] = ""
|
||||
|
||||
if len(verts_uvs) > 0:
|
||||
# Get the uv coords for each vert in each face
|
||||
faces_verts_uvs = verts_uvs[faces_textures_idx] # (F, 3, 2)
|
||||
|
||||
# Construct the atlas.
|
||||
texture_atlas = make_mesh_texture_atlas(
|
||||
material_colors,
|
||||
texture_images,
|
||||
face_material_names,
|
||||
faces_verts_uvs,
|
||||
texture_atlas_size,
|
||||
texture_wrap,
|
||||
)
|
||||
# Construct the atlas.
|
||||
texture_atlas = make_mesh_texture_atlas(
|
||||
material_colors,
|
||||
texture_images,
|
||||
face_material_names,
|
||||
faces_textures_idx,
|
||||
verts_uvs,
|
||||
texture_atlas_size,
|
||||
texture_wrap,
|
||||
)
|
||||
|
||||
faces = _Faces(
|
||||
verts_idx=faces_verts_idx,
|
||||
|
7
tests/data/obj_mtl_no_image/model.mtl
Normal file
7
tests/data/obj_mtl_no_image/model.mtl
Normal file
@ -0,0 +1,7 @@
|
||||
# Material Count: 1
|
||||
|
||||
newmtl material_1
|
||||
Ns 96.078431
|
||||
Ka 0.000000 0.000000 0.000000
|
||||
Kd 0.500000 0.000000 0.000000
|
||||
Ks 0.500000 0.500000 0.500000
|
10
tests/data/obj_mtl_no_image/model.obj
Normal file
10
tests/data/obj_mtl_no_image/model.obj
Normal file
@ -0,0 +1,10 @@
|
||||
|
||||
mtllib model.mtl
|
||||
|
||||
v 0.1 0.2 0.3
|
||||
v 0.2 0.3 0.4
|
||||
v 0.3 0.4 0.5
|
||||
v 0.4 0.5 0.6
|
||||
usemtl material_1
|
||||
f 1 2 3
|
||||
f 1 2 4
|
@ -559,6 +559,35 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertTrue(aux.normals is None)
|
||||
self.assertTrue(aux.verts_uvs is None)
|
||||
|
||||
def test_load_obj_mlt_no_image(self):
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
obj_filename = "obj_mtl_no_image/model.obj"
|
||||
filename = os.path.join(DATA_DIR, obj_filename)
|
||||
R = 8
|
||||
verts, faces, aux = load_obj(
|
||||
filename,
|
||||
load_textures=True,
|
||||
create_texture_atlas=True,
|
||||
texture_atlas_size=R,
|
||||
texture_wrap=None,
|
||||
)
|
||||
|
||||
expected_verts = torch.tensor(
|
||||
[[0.1, 0.2, 0.3], [0.2, 0.3, 0.4], [0.3, 0.4, 0.5], [0.4, 0.5, 0.6]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
expected_faces = torch.tensor([[0, 1, 2], [0, 1, 3]], dtype=torch.int64)
|
||||
self.assertTrue(torch.allclose(verts, expected_verts))
|
||||
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))
|
||||
|
||||
# Check that the material diffuse color has been assigned to all the
|
||||
# values in the texture atlas.
|
||||
expected_atlas = torch.tensor([0.5, 0.0, 0.0], dtype=torch.float32)
|
||||
expected_atlas = expected_atlas[None, None, None, :].expand(2, R, R, -1)
|
||||
self.assertTrue(torch.allclose(aux.texture_atlas, expected_atlas))
|
||||
self.assertEquals(len(aux.material_colors.keys()), 1)
|
||||
self.assertEquals(list(aux.material_colors.keys()), ["material_1"])
|
||||
|
||||
def test_load_obj_missing_texture(self):
|
||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
obj_filename = "missing_files_obj/model.obj"
|
||||
|
Loading…
x
Reference in New Issue
Block a user