From 93d3d8feda296c1c9a78a1d20b37f83cd4299e5b Mon Sep 17 00:00:00 2001 From: Nikhila Ravi Date: Wed, 23 Sep 2020 12:10:53 -0700 Subject: [PATCH] Tidy OBJ / MTL parsing Summary: Tidy OBJ / MTL parsing: remove redundant calls to tokenize, factor out parsing and texture loading Reviewed By: gkioxari Differential Revision: D20720768 fbshipit-source-id: fb1713106d4ff99a4a9147afcc3da74ae013d8dc --- pytorch3d/io/mtl_io.py | 144 ++++++++++++++++++++++++----------------- pytorch3d/io/obj_io.py | 30 ++++----- 2 files changed, 100 insertions(+), 74 deletions(-) diff --git a/pytorch3d/io/mtl_io.py b/pytorch3d/io/mtl_io.py index 885a636f..32bcc788 100644 --- a/pytorch3d/io/mtl_io.py +++ b/pytorch3d/io/mtl_io.py @@ -3,7 +3,7 @@ """This module implements utility functions for loading .mtl files and textures.""" import os import warnings -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import numpy as np import torch @@ -379,7 +379,84 @@ def _bilinear_interpolation_grid_sample( return out.permute(0, 2, 3, 1) -def load_mtl(f, material_names: List, data_dir: str, device="cpu"): +MaterialProperties = Dict[str, Dict[str, torch.Tensor]] +TextureFiles = Dict[str, str] +TextureImages = Dict[str, torch.Tensor] + + +def _parse_mtl(f, device="cpu") -> Tuple[MaterialProperties, TextureFiles]: + material_properties = {} + texture_files = {} + material_name = "" + + with _open_file(f, "r") as f: + for line in f: + tokens = line.strip().split() + if not tokens: + continue + if tokens[0] == "newmtl": + material_name = tokens[1] + material_properties[material_name] = {} + elif tokens[0] == "map_Kd": + # Diffuse texture map + texture_files[material_name] = tokens[1] + elif tokens[0] == "Kd": + # RGB diffuse reflectivity + kd = np.array(tokens[1:4]).astype(np.float32) + kd = torch.from_numpy(kd).to(device) + material_properties[material_name]["diffuse_color"] = kd + elif tokens[0] == "Ka": + # RGB ambient reflectivity + ka = np.array(tokens[1:4]).astype(np.float32) + ka = torch.from_numpy(ka).to(device) + material_properties[material_name]["ambient_color"] = ka + elif tokens[0] == "Ks": + # RGB specular reflectivity + ks = np.array(tokens[1:4]).astype(np.float32) + ks = torch.from_numpy(ks).to(device) + material_properties[material_name]["specular_color"] = ks + elif tokens[0] == "Ns": + # Specular exponent + ns = np.array(tokens[1:4]).astype(np.float32) + ns = torch.from_numpy(ns).to(device) + material_properties[material_name]["shininess"] = ns + + return material_properties, texture_files + + +def _load_texture_images( + material_names: List[str], + data_dir: str, + material_properties: MaterialProperties, + texture_files: TextureFiles, +) -> Tuple[MaterialProperties, TextureImages]: + final_material_properties = {} + texture_images = {} + + # Only keep the materials referenced in the obj. + for material_name in material_names: + if material_name in texture_files: + # Load the texture image. + path = os.path.join(data_dir, texture_files[material_name]) + if os.path.isfile(path): + image = _read_image(path, format="RGB") / 255.0 + image = torch.from_numpy(image) + texture_images[material_name] = image + else: + msg = f"Texture file does not exist: {path}" + warnings.warn(msg) + + if material_name in material_properties: + final_material_properties[material_name] = material_properties[ + material_name + ] + + return final_material_properties, texture_images + + +def load_mtl( + f, material_names: List[str], data_dir: str, device="cpu" +) -> Tuple[MaterialProperties, TextureImages]: """ Load texture images and material reflectivity values for ambient, diffuse and specular light (Ka, Kd, Ks, Ns). @@ -390,8 +467,8 @@ def load_mtl(f, material_names: List, data_dir: str, device="cpu"): data_dir: the directory where the material texture files are located. Returns: - material_colors: dict of properties for each material. If a material - does not have any properties it will have an emtpy dict. + material_properties: dict of properties for each material. If a material + does not have any properties it will have an empty dict. { material_name_1: { "ambient_color": tensor of shape (1, 3), @@ -408,58 +485,7 @@ def load_mtl(f, material_names: List, data_dir: str, device="cpu"): ... } """ - texture_files = {} - material_colors = {} - material_properties = {} - texture_images = {} - material_name = "" - - with _open_file(f) as f: - lines = [line.strip() for line in f] - for line in lines: - if len(line.split()) != 0: - if line.split()[0] == "newmtl": - material_name = line.split()[1] - material_colors[material_name] = {} - if line.split()[0] == "map_Kd": - # Texture map. - texture_files[material_name] = line.split()[1] - if line.split()[0] == "Kd": - # RGB diffuse reflectivity - kd = np.array(list(line.split()[1:4])).astype(np.float32) - kd = torch.from_numpy(kd).to(device) - material_colors[material_name]["diffuse_color"] = kd - if line.split()[0] == "Ka": - # RGB ambient reflectivity - ka = np.array(list(line.split()[1:4])).astype(np.float32) - ka = torch.from_numpy(ka).to(device) - material_colors[material_name]["ambient_color"] = ka - if line.split()[0] == "Ks": - # RGB specular reflectivity - ks = np.array(list(line.split()[1:4])).astype(np.float32) - ks = torch.from_numpy(ks).to(device) - material_colors[material_name]["specular_color"] = ks - if line.split()[0] == "Ns": - # Specular exponent - ns = np.array(list(line.split()[1:4])).astype(np.float32) - ns = torch.from_numpy(ns).to(device) - material_colors[material_name]["shininess"] = ns - - # Only keep the materials referenced in the obj. - for name in material_names: - if name in texture_files: - # Load the texture image. - filename = texture_files[name] - filename_texture = os.path.join(data_dir, filename) - if os.path.isfile(filename_texture): - image = _read_image(filename_texture, format="RGB") / 255.0 - image = torch.from_numpy(image) - texture_images[name] = image - else: - msg = f"Texture file does not exist: {filename_texture}" - warnings.warn(msg) - - if name in material_colors: - material_properties[name] = material_colors[name] - - return material_properties, texture_images + material_properties, texture_files = _parse_mtl(f, device) + return _load_texture_images( + material_names, data_dir, material_properties, texture_files + ) diff --git a/pytorch3d/io/obj_io.py b/pytorch3d/io/obj_io.py index e87255c3..a21f63c4 100644 --- a/pytorch3d/io/obj_io.py +++ b/pytorch3d/io/obj_io.py @@ -276,13 +276,14 @@ def load_objs_as_meshes( def _parse_face( line, + tokens, material_idx, faces_verts_idx, faces_normals_idx, faces_textures_idx, faces_materials_idx, ): - face = line.split()[1:] + face = tokens[1:] face_list = [f.split("/") for f in face] face_verts = [] face_normals = [] @@ -381,15 +382,16 @@ def _load_obj( lines = [el.decode("utf-8") for el in lines] for line in lines: + tokens = line.strip().split() if line.startswith("mtllib"): - if len(line.split()) < 2: + if len(tokens) < 2: raise ValueError("material file name is not specified") # NOTE: only allow one .mtl file per .obj. # Definitions for multiple materials can be included # in this one .mtl file. f_mtl = os.path.join(data_dir, line.split()[1]) - elif len(line.split()) != 0 and line.split()[0] == "usemtl": - material_name = line.split()[1] + elif len(tokens) and tokens[0] == "usemtl": + material_name = tokens[1] # materials are often repeated for different parts # of a mesh. if material_name not in material_names: @@ -397,32 +399,30 @@ def _load_obj( materials_idx = len(material_names) - 1 else: materials_idx = material_names.index(material_name) - elif line.startswith("v "): - # Line is a vertex. - vert = [float(x) for x in line.split()[1:4]] + elif line.startswith("v "): # Line is a vertex. + vert = [float(x) for x in tokens[1:4]] if len(vert) != 3: msg = "Vertex %s does not have 3 values. Line: %s" raise ValueError(msg % (str(vert), str(line))) verts.append(vert) - elif line.startswith("vt "): - # Line is a texture. - tx = [float(x) for x in line.split()[1:3]] + elif line.startswith("vt "): # Line is a texture. + tx = [float(x) for x in tokens[1:3]] if len(tx) != 2: raise ValueError( "Texture %s does not have 2 values. Line: %s" % (str(tx), str(line)) ) verts_uvs.append(tx) - elif line.startswith("vn "): - # Line is a normal. - norm = [float(x) for x in line.split()[1:4]] + elif line.startswith("vn "): # Line is a normal. + norm = [float(x) for x in tokens[1:4]] if len(norm) != 3: msg = "Normal %s does not have 3 values. Line: %s" raise ValueError(msg % (str(norm), str(line))) normals.append(norm) - elif line.startswith("f "): - # Line is a face update face properties info. + elif line.startswith("f "): # Line is a face. + # Update face properties info. _parse_face( line, + tokens, materials_idx, faces_verts_idx, faces_normals_idx,