# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

"""This module implements utility functions for loading .mtl files and textures."""
import os
import warnings
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F


def make_mesh_texture_atlas(
    material_properties: Dict,
    texture_images: Dict,
    face_material_names,
    faces_verts_uvs: torch.Tensor,
    texture_size: int,
    texture_wrap: Optional[str],
) -> torch.Tensor:
    """
    Given properties for materials defined in the .mtl file, and the face texture uv
    coordinates, construct an (F, R, R, 3) texture atlas where R is the texture_size
    and F is the number of faces in the mesh.

    Args:
        material_properties: dict of properties for each material. If a material
            does not have any properties it will have an empty dict.
        texture_images: dict of material names and texture images of shape
            (H, W, 3) or (H, W, 4); only the first three channels are used.
        face_material_names: numpy array of the material name corresponding to each
            face. Faces which don't have an associated material will be an empty
            string. For these faces, a uniform white texture is assigned.
        faces_verts_uvs: FloatTensor of shape (F, 3, 2) giving the uv coordinates
            for each vertex in the face.
        texture_size: the resolution of the per face texture map returned by this
            function. Each face will have a texture map of shape
            (texture_size, texture_size, 3).
        texture_wrap: string, one of ["repeat", "clamp", None]
            If `texture_wrap="repeat"` for uv values outside the range [0, 1] the
            integer part is ignored and a repeating pattern is formed.
            If `texture_wrap="clamp"` the values are clamped to the range [0, 1].
            If None, do nothing.

    Returns:
        atlas: FloatTensor of shape (F, texture_size, texture_size, 3) giving the per
        face texture map. It has the dtype and device of `faces_verts_uvs`.
    """
    # Create an R x R texture map per face in the mesh.
    R = texture_size
    num_faces = faces_verts_uvs.shape[0]

    # Initialize the per face texture map to a white color.
    # TODO: allow customization of this base color?
    atlas = faces_verts_uvs.new_ones(size=(num_faces, R, R, 3))

    # Check for empty materials.
    if not material_properties and not texture_images:
        return atlas

    if texture_wrap == "repeat":
        # If texture uv coordinates are outside the range [0, 1] follow
        # the convention GL_REPEAT in OpenGL i.e the integer part of the coordinate
        # will be ignored and a repeating pattern is formed.
        # Shapenet data uses this format see:
        # https://shapenet.org/qaforum/index.php?qa=15&qa_1=why-is-the-texture-coordinate-in-the-obj-file-not-in-the-range
        if (faces_verts_uvs > 1).any() or (faces_verts_uvs < 0).any():
            # Built by implicit concatenation so the message does not carry
            # source indentation (a backslash continuation inside the literal
            # would embed it).
            msg = (
                "Texture UV coordinates outside the range [0, 1]. "
                "The integer part will be ignored to form a repeating pattern."
            )
            warnings.warn(msg)
            faces_verts_uvs = faces_verts_uvs % 1
    elif texture_wrap == "clamp":
        # Clamp uv coordinates to the [0, 1] range.
        faces_verts_uvs = faces_verts_uvs.clamp(0.0, 1.0)

    device = faces_verts_uvs.device

    # Iterate through the material properties - not
    # all materials have texture images so this has to be
    # done separately to the texture interpolation.
    for material_name, props in material_properties.items():
        if "diffuse_color" not in props:
            continue
        # Bool mask of the faces which use this material.
        mask = torch.from_numpy(face_material_names == material_name).to(device)
        if mask.any():
            # For these faces, update the base color to the diffuse material
            # color. `.to(atlas)` matches dtype and device so colors loaded on
            # another device can still be assigned.
            atlas[mask] = props["diffuse_color"].to(atlas)[None, :]

    # Iterate through the materials used in this mesh. Update the
    # texture atlas for the faces which use this material.
    # Faces without texture keep their base color.
    for material_name, image in texture_images.items():
        # Bool mask of the faces which use this texture map.
        mask = torch.from_numpy(face_material_names == material_name).to(device)
        if not mask.any():
            # No face references this texture; nothing to sample.
            continue

        # Only use the RGB colors.
        if image.shape[2] == 4:
            image = image[:, :, :3]

        # Reverse the image y direction and match the uv dtype/device.
        image = torch.flip(image, [0]).type_as(faces_verts_uvs)

        # TODO: should the texture map values be multiplied
        # by the diffuse material color (i.e. use *= as the atlas has
        # been initialized to the diffuse color)?. This is
        # not being done in SoftRas.
        atlas[mask] = make_material_atlas(image, faces_verts_uvs[mask], R)

    return atlas


def make_material_atlas(
    image: torch.Tensor, faces_verts_uvs: torch.Tensor, texture_size: int
) -> torch.Tensor:
    r"""
    Given a single texture image and the uv coordinates for all the
    face vertices, create a square texture map per face using
    the formulation from [1].

    For a triangle with vertices (v0, v1, v2) we can create a barycentric coordinate
    system with the x axis being the vector (v1 - v0) and the y axis being the vector
    (v2 - v0). The barycentric coordinates range from [0, 1] in the +x and +y
    direction so this creates a triangular texture space with vertices at
    (0, 1), (0, 0) and (1, 0).

    The per face texture map is of shape (texture_size, texture_size, 3)
    which is a square. To map a triangular texture to a square grid, each
    triangle is parametrized as follows (e.g. R = texture_size = 3):

    The triangle texture is first divided into RxR = 9 subtriangles which each
    map to one grid cell. The numbers in the grid cells and triangles show the
    mapping.

    .. code-block:: python

        Triangular Texture Space:

          1
            |\
            |6 \
            |____\
            |\  7 |\
            |3 \  |4 \
            |____\|____\
            |\  8 |\  5 |\
            |0 \  |1 \  |2 \
            |____\|____\|____\
           0                   1

        Square per face texture map:

           R ____________________
            |      |      |      |
            |  6   |  7   |  8   |
            |______|______|______|
            |      |      |      |
            |  3   |  4   |  5   |
            |______|______|______|
            |      |      |      |
            |  0   |  1   |  2   |
            |______|______|______|
           0                      R

    The barycentric coordinates of each grid cell are calculated using the
    xy coordinates:

    .. code-block:: python

        The cartesian coordinates are:

        Grid 1:

           R ____________________
            |      |      |      |
            |  20  |  21  |  22  |
            |______|______|______|
            |      |      |      |
            |  10  |  11  |  12  |
            |______|______|______|
            |      |      |      |
            |  00  |  01  |  02  |
            |______|______|______|
           0                      R

        where 02 means y = 0, x = 2

    Now consider this subset of the triangle which corresponds to
    grid cells 0 and 8:

    .. code-block:: python

        1/R  ________
            |\    8  |
            |  \     |
            | 0  \   |
            |______\_|
           0          1/R

    The centroids of the triangles are:
        0: (1/3, 1/3) * 1/R
        8: (2/3, 2/3) * 1/R

    For each grid cell we can now calculate the centroid `(c_y, c_x)`
    of the corresponding texture triangle:
        - if `(x + y) < R`, then offset the centroid of
          triangle 0 by `(y, x) * (1/R)`
        - if `(x + y) >= R`, then offset the centroid of
          triangle 8 by `((R-1-y), (R-1-x)) * (1/R)`.

    This is equivalent to updating the portion of Grid 1
    above the diagonal, replacing `(y, x)` with `((R-1-y), (R-1-x))`:

    .. code-block:: python

           R ____________________
            |      |      |      |
            |  20  |  01  |  00  |
            |______|______|______|
            |      |      |      |
            |  10  |  11  |  10  |
            |______|______|______|
            |      |      |      |
            |  00  |  01  |  02  |
            |______|______|______|
           0                      R

    The barycentric coordinates (w0, w1, w2) are then given by:

    .. code-block:: python

        w0 = c_x
        w1 = c_y
        w2 = 1 - w0 - w1

    Args:
        image: FloatTensor of shape (H, W, 3)
        faces_verts_uvs: uv coordinates for each vertex in each face (F, 3, 2)
        texture_size: int

    Returns:
        atlas: a FloatTensor of shape (F, texture_size, texture_size, 3) giving a
            per face texture map.

    [1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
        3D Reasoning', ICCV 2019
    """
    R = texture_size
    device = faces_verts_uvs.device
    rng = torch.arange(R, device=device)

    # Rows index y and columns index x; store each cell as (x, y).
    Y = rng[:, None].expand(R, R)
    X = rng[None, :].expand(R, R)
    grid = torch.stack([X, Y], dim=-1)  # (R, R, 2)

    # Grid cells below the diagonal: x + y < R.
    below_diag = grid.sum(-1) < R  # (R, R)

    # Map the [0, R] grid to the [0, 1] barycentric coordinates of
    # the texture triangle centroids.
    bary = torch.zeros((R, R, 3), device=device)  # (R, R, 3)
    cells = grid.to(bary.dtype)
    # (w0, w1) for cells below the diagonal ...
    lower = (cells + 1.0 / 3.0) / R
    # ... and for cells on or above the diagonal.
    upper = ((R - 1.0 - cells) + 2.0 / 3.0) / R
    bary[..., :2] = torch.where(below_diag[..., None], lower, upper)
    # w2 = 1. - w0 - w1
    bary[..., -1] = 1 - bary[..., :2].sum(dim=-1)

    # Calculate the uv position in the image for each pixel
    # in the per face texture map
    # (F, 1, 1, 3, 2) * (R, R, 3, 1) -> (F, R, R, 3, 2) -> (F, R, R, 2)
    uv_pos = (faces_verts_uvs[:, None, None] * bary[..., None]).sum(-2)

    # Bi-linearly interpolate the textures from the image
    # using the uv coordinates given by uv_pos.
    return _bilinear_interpolation_vectorized(image, uv_pos)


def _bilinear_interpolation_vectorized(
    image: torch.Tensor, grid: torch.Tensor
) -> torch.Tensor:
    """
    Bi linearly interpolate the image using the uv positions in the flow-field
    grid (following the naming conventions for torch.nn.functional.grid_sample).

    This implementation uses the same steps as in the SoftRas cuda kernel
    to make it easy to compare. It is intended as the lower-memory alternative to
    _bilinear_interpolation_grid_sample, which is the one to try when speed
    matters and the number of faces and the texture image are small.

    Args:
        image: FloatTensor of shape (H, W, D) a single image/input tensor with D
            channels.
        grid: FloatTensor of shape (N, R, R, 2) giving the pixel locations of the
            points at which to sample a value in the image. The grid values must
            be in the range [0, 1]. u is the x direction and v is the y direction.

    Returns:
        out: FloatTensor of shape (N, R, R, D) giving the interpolated
            D dimensional value from image at each of the pixel locations in grid.
    """
    H, W, _ = image.shape
    # Convert [0, 1] to the range [0, W-1] and [0, H-1].
    grid = grid * grid.new_tensor([W - 1, H - 1])

    # Integer pixel to the lower-left of each sample (truncation; the grid is
    # expected to be non-negative) and the fractional offset from it.
    base = grid.to(torch.int64)
    weight_1 = grid - base
    weight_0 = 1.0 - weight_1

    x0, y0 = base.unbind(-1)
    # A sample that lands on the last row/column has no neighbour at +1.
    # Its weight is zero there, so clamping the index avoids an
    # out-of-bounds lookup without changing the interpolated value.
    x1 = (x0 + 1).clamp(max=W - 1)
    y1 = (y0 + 1).clamp(max=H - 1)

    weight_x0, weight_y0 = weight_0.unbind(-1)
    weight_x1, weight_y1 = weight_1.unbind(-1)

    # Bi-linear interpolation
    # positions = [[y, x], [(y+1), x]
    #              [y, (x+1)], [(y+1), (x+1)]]
    # weights = [[wx0*wy0, wx0*wy1],
    #            [wx1*wy0, wx1*wy1]]
    out = (
        image[y0, x0] * (weight_x0 * weight_y0)[..., None]
        + image[y1, x0] * (weight_x0 * weight_y1)[..., None]
        + image[y0, x1] * (weight_x1 * weight_y0)[..., None]
        + image[y1, x1] * (weight_x1 * weight_y1)[..., None]
    )

    return out


def _bilinear_interpolation_grid_sample(
    image: torch.Tensor, grid: torch.Tensor
) -> torch.Tensor:
    """
    Bi linearly interpolate the image using the uv positions in the flow-field
    grid (following the conventions for torch.nn.functional.grid_sample).

    This implementation expands the image once per grid entry before calling
    grid_sample, so it needs more memory than _bilinear_interpolation_vectorized
    and can cause OOMs. If speed is an issue try this function instead.

    Args:
        image: FloatTensor of shape (H, W, D) a single image/input tensor with D
            channels.
        grid: FloatTensor of shape (N, R, R, 2) giving the pixel locations of the
            points at which to sample a value in the image. The grid values must
            be in the range [0, 1]. u is the x direction and v is the y direction.

    Returns:
        out: FloatTensor of shape (N, R, R, D) giving the interpolated
            D dimensional value from image at each of the pixel locations in grid.
    """
    N = grid.shape[0]
    # Convert [0, 1] to the range [-1, 1] expected by grid_sample.
    grid = grid * 2.0 - 1.0
    image = image.permute(2, 0, 1)[None, ...].expand(N, -1, -1, -1)  # (N, D, H, W)
    # Align_corners has to be set to True to match the output of the SoftRas
    # cuda kernel for bilinear sampling.
    out = F.grid_sample(image, grid, mode="bilinear", align_corners=True)
    return out.permute(0, 2, 3, 1)


def load_mtl(f_mtl, material_names: List, data_dir: str, device="cpu"):
    """
    Load texture images and material reflectivity values for ambient, diffuse
    and specular light (Ka, Kd, Ks, Ns).

    Args:
        f_mtl: a file like object of the material information, or a path to it.
        material_names: a list of the material names found in the .obj file.
        data_dir: the directory where the material texture files are located.
        device: string or torch.device on which the material property tensors
            are created. Texture images are returned as loaded and are not
            moved to this device.

    Returns:
        material_colors: dict of properties for each material. If a material
            does not have any properties it will have an empty dict.
            {
                material_name_1:  {
                    "ambient_color": tensor of shape (3,),
                    "diffuse_color": tensor of shape (3,),
                    "specular_color": tensor of shape (3,),
                    "shininess": tensor of shape (1,)
                },
                material_name_2: {},
                ...
            }
        texture_images: dict of material names and texture images
            {
                material_name_1: (H, W, 3) image,
                ...
            }
    """
    # Imported here rather than at module level: pytorch3d.io.utils imports
    # fvcore and Pillow, which the texture atlas helpers above do not need.
    from pytorch3d.io.utils import _open_file, _read_image

    # .mtl keyword -> key used in the returned per-material dict.
    property_keys = {
        "Kd": "diffuse_color",  # RGB diffuse reflectivity
        "Ka": "ambient_color",  # RGB ambient reflectivity
        "Ks": "specular_color",  # RGB specular reflectivity
        "Ns": "shininess",  # Specular exponent
    }

    texture_files = {}
    material_colors = {}
    material_properties = {}
    texture_images = {}
    material_name = ""

    f_mtl, new_f = _open_file(f_mtl)
    try:
        for line in f_mtl:
            tokens = line.split()
            if len(tokens) < 2:
                # Blank line, or a keyword without any value.
                continue
            keyword = tokens[0]
            if keyword == "newmtl":
                material_name = tokens[1]
                material_colors[material_name] = {}
            elif keyword == "map_Kd":
                # Texture map.
                texture_files[material_name] = tokens[1]
            elif keyword in property_keys:
                values = np.array(tokens[1:4]).astype(np.float32)
                # setdefault: tolerate a property that appears before any newmtl.
                props = material_colors.setdefault(material_name, {})
                props[property_keys[keyword]] = torch.from_numpy(values).to(device)
    finally:
        # Close the file if it was opened here, even when parsing fails.
        if new_f:
            f_mtl.close()

    # Only keep the materials referenced in the obj.
    for name in material_names:
        if name in texture_files:
            # Load the texture image.
            filename_texture = os.path.join(data_dir, texture_files[name])
            if os.path.isfile(filename_texture):
                image = _read_image(filename_texture, format="RGB") / 255.0
                texture_images[name] = torch.from_numpy(image)
            else:
                msg = f"Texture file does not exist: {filename_texture}"
                warnings.warn(msg)

        if name in material_colors:
            material_properties[name] = material_colors[name]

    return material_properties, texture_images
First convert to RGB and flip channels - # below for BGR. - image = image.convert("RGB") - image = np.asarray(image).astype(np.float32) - if format == "BGR": - image = image[:, :, ::-1] - return image + return torch.tensor(data, dtype=dtype, device=device) # Faces & Aux type returned from load_obj function. _Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx") -_Aux = namedtuple("Properties", "normals verts_uvs material_colors texture_images") +_Aux = namedtuple( + "Properties", "normals verts_uvs material_colors texture_images texture_atlas" +) -def _format_faces_indices(faces_indices, max_index): +def _format_faces_indices(faces_indices, max_index, device, pad_value=None): """ Format indices and check for invalid values. Indices can refer to values in one of the face properties: vertices, textures or normals. @@ -70,7 +49,12 @@ def _format_faces_indices(faces_indices, max_index): Raises: ValueError if indices are not in a valid range. """ - faces_indices = _make_tensor(faces_indices, cols=3, dtype=torch.int64) + faces_indices = _make_tensor( + faces_indices, cols=3, dtype=torch.int64, device=device + ) + + if pad_value: + mask = faces_indices.eq(pad_value).all(-1) # Change to 0 based indexing. faces_indices[(faces_indices > 0)] -= 1 @@ -78,6 +62,9 @@ def _format_faces_indices(faces_indices, max_index): # Negative indexing counts from the end. faces_indices[(faces_indices < 0)] += max_index + if pad_value: + faces_indices[mask] = pad_value + # Check indices are valid. 
if torch.any(faces_indices >= max_index) or torch.any(faces_indices < 0): warnings.warn("Faces have invalid indices") @@ -85,18 +72,14 @@ def _format_faces_indices(faces_indices, max_index): return faces_indices -def _open_file(f): - new_f = False - if isinstance(f, str): - new_f = True - f = open(f, "r") - elif isinstance(f, pathlib.Path): - new_f = True - f = f.open("r") - return f, new_f - - -def load_obj(f_obj, load_textures=True): +def load_obj( + f_obj, + load_textures=True, + create_texture_atlas: bool = False, + texture_atlas_size: int = 4, + texture_wrap: Optional[str] = "repeat", + device="cpu", +): """ Load a mesh from a .obj file and optionally textures from a .mtl file. Currently this handles verts, faces, vertex texture uv coordinates, normals, @@ -155,6 +138,18 @@ def load_obj(f_obj, load_textures=True): f: A file-like object (with methods read, readline, tell, and seek), a pathlib path or a string containing a file name. load_textures: Boolean indicating whether material files are loaded + create_texture_atlas: Bool, If True a per face texture map is created and + a tensor `texture_atlas` is also returned in `aux`. + texture_atlas_size: Int specifying the resolution of the texture map per face + when `create_texture_atlas=True`. A (texture_size, texture_size, 3) + map is created per face. + texture_wrap: string, one of ["repeat", "clamp"]. This applies when computing + the texture atlas. + If `texture_mode="repeat"`, for uv values outside the range [0, 1] the integer part + is ignored and a repeating pattern is formed. + If `texture_mode="clamp"` the values are clamped to the range [0, 1]. + If None, then there is no transformation of the texture values. + device: string or torch.device on which to return the new tensors. Returns: 6-element tuple containing @@ -181,9 +176,8 @@ def load_obj(f_obj, load_textures=True): possible that the number of verts_uvs is greater than num verts i.e. T > V. vertex. 
- - material_colors: dict of material names and associated properties. - If a material does not have any properties it will have an - empty dict. + - material_colors: if `load_textures=True` and the material has associated + properties this will be a dict of material names and properties of the form: .. code-block:: python @@ -197,20 +191,40 @@ def load_obj(f_obj, load_textures=True): material_name_2: {}, ... } - - texture_images: dict of material names and texture images. + + If a material does not have any properties it will have an + empty dict. If `load_textures=False`, `material_colors` will None. + + - texture_images: if `load_textures=True` and the material has a texture map, + this will be a dict of the form: + .. code-block:: python { material_name_1: (H, W, 3) image, ... } + If `load_textures=False`, `texture_images` will None. + - texture_atlas: if `load_textures=True` and `create_texture_atlas=True`, + this will be a FloatTensor of the form: (F, texture_size, textures_size, 3) + If the material does not have a texture map, then all faces + will have a uniform white texture. Otherwise `texture_atlas` will be + None. """ data_dir = "./" if isinstance(f_obj, (str, bytes, os.PathLike)): data_dir = os.path.dirname(f_obj) f_obj, new_f = _open_file(f_obj) try: - return _load(f_obj, data_dir, load_textures=load_textures) + return _load( + f_obj, + data_dir, + load_textures=load_textures, + create_texture_atlas=create_texture_atlas, + texture_atlas_size=texture_atlas_size, + texture_wrap=texture_wrap, + device=device, + ) finally: if new_f: f_obj.close() @@ -235,6 +249,7 @@ def load_objs_as_meshes(files: list, device=None, load_textures: bool = True): """ mesh_list = [] for f_obj in files: + # TODO: update this function to support the two texturing options. verts, faces, aux = load_obj(f_obj, load_textures=load_textures) verts = verts.to(device) tex = None @@ -286,6 +301,10 @@ def _parse_face( # Triplets must be consistent for all vertices in a face e.g. 
# legal statement: f 4/1/1 3/2/1 2/1/1. # illegal statement: f 4/1/1 3//1 2//1. + # If the face does not have normals or textures indices + # fill with pad value = -1. This will ensure that + # all the face index tensors will have F values where + # F is the number of faces. if len(face_normals) > 0: if not (len(face_verts) == len(face_normals)): raise ValueError( @@ -293,6 +312,8 @@ def _parse_face( Vertex properties are inconsistent. Line: %s" % (str(face), str(line)) ) + else: + face_normals = [-1] * len(face_verts) # Fill with -1 if len(face_textures) > 0: if not (len(face_verts) == len(face_textures)): raise ValueError( @@ -300,28 +321,41 @@ def _parse_face( Vertex properties are inconsistent. Line: %s" % (str(face), str(line)) ) + else: + face_textures = [-1] * len(face_verts) # Fill with -1 - # Subdivide faces with more than 3 vertices. See comments of the - # load_obj function for more details. + # Subdivide faces with more than 3 vertices. + # See comments of the load_obj function for more details. for i in range(len(face_verts) - 2): faces_verts_idx.append((face_verts[0], face_verts[i + 1], face_verts[i + 2])) - if len(face_normals) > 0: - faces_normals_idx.append( - (face_normals[0], face_normals[i + 1], face_normals[i + 2]) - ) - if len(face_textures) > 0: - faces_textures_idx.append( - (face_textures[0], face_textures[i + 1], face_textures[i + 2]) - ) + faces_normals_idx.append( + (face_normals[0], face_normals[i + 1], face_normals[i + 2]) + ) + faces_textures_idx.append( + (face_textures[0], face_textures[i + 1], face_textures[i + 2]) + ) faces_materials_idx.append(material_idx) -def _load(f_obj, data_dir, load_textures=True): +def _load( + f_obj, + data_dir, + load_textures: bool = True, + create_texture_atlas: bool = False, + texture_atlas_size: int = 4, + texture_wrap: Optional[str] = "repeat", + device="cpu", +): """ Load a mesh from a file-like object. See load_obj function more details. 
Any material files associated with the obj are expected to be in the directory given by data_dir. """ + + if texture_wrap is not None and texture_wrap not in ["repeat", "clamp"]: + msg = "texture_wrap must be one of ['repeat', 'clamp'] or None, got %s" + raise ValueError(msg % texture_wrap) + lines = [line.strip() for line in f_obj] verts = [] normals = [] @@ -343,12 +377,19 @@ def _load(f_obj, data_dir, load_textures=True): if line.startswith("mtllib"): if len(line.split()) < 2: raise ValueError("material file name is not specified") - # NOTE: this assumes only one mtl file per .obj. + # NOTE: only allow one .mtl file per .obj. + # Definitions for multiple materials can be included + # in this one .mtl file. f_mtl = os.path.join(data_dir, line.split()[1]) elif len(line.split()) != 0 and line.split()[0] == "usemtl": material_name = line.split()[1] - material_names.append(material_name) - materials_idx = len(material_names) - 1 + # materials are often repeated for different parts + # of a mesh. + if material_name not in material_names: + material_names.append(material_name) + materials_idx = len(material_names) - 1 + else: + materials_idx = material_names.index(material_name) elif line.startswith("v "): # Line is a vertex. vert = [float(x) for x in line.split()[1:4]] @@ -372,7 +413,7 @@ def _load(f_obj, data_dir, load_textures=True): raise ValueError(msg % (str(norm), str(line))) normals.append(norm) elif line.startswith("f "): - # Line is a face. + # Line is a face update face properties info. 
_parse_face( line, materials_idx, @@ -382,30 +423,63 @@ def _load(f_obj, data_dir, load_textures=True): faces_materials_idx, ) - verts = _make_tensor(verts, cols=3, dtype=torch.float32) # (V, 3) - normals = _make_tensor(normals, cols=3, dtype=torch.float32) # (N, 3) - verts_uvs = _make_tensor(verts_uvs, cols=2, dtype=torch.float32) # (T, 2) + verts = _make_tensor(verts, cols=3, dtype=torch.float32, device=device) # (V, 3) + normals = _make_tensor( + normals, cols=3, dtype=torch.float32, device=device + ) # (N, 3) + verts_uvs = _make_tensor( + verts_uvs, cols=2, dtype=torch.float32, device=device + ) # (T, 2) - faces_verts_idx = _format_faces_indices(faces_verts_idx, verts.shape[0]) + faces_verts_idx = _format_faces_indices( + faces_verts_idx, verts.shape[0], device=device + ) # Repeat for normals and textures if present. if len(faces_normals_idx) > 0: - faces_normals_idx = _format_faces_indices(faces_normals_idx, normals.shape[0]) + faces_normals_idx = _format_faces_indices( + faces_normals_idx, normals.shape[0], device=device, pad_value=-1 + ) if len(faces_textures_idx) > 0: faces_textures_idx = _format_faces_indices( - faces_textures_idx, verts_uvs.shape[0] + faces_textures_idx, verts_uvs.shape[0], device=device, pad_value=-1 ) if len(faces_materials_idx) > 0: - faces_materials_idx = torch.tensor(faces_materials_idx, dtype=torch.int64) + faces_materials_idx = torch.tensor( + faces_materials_idx, dtype=torch.int64, device=device + ) # Load materials - material_colors, texture_images = None, None + material_colors, texture_images, texture_atlas = None, None, None if load_textures: if (len(material_names) > 0) and (f_mtl is not None): if os.path.isfile(f_mtl): + # Texture mode uv wrap material_colors, texture_images = load_mtl( - f_mtl, material_names, data_dir + f_mtl, material_names, data_dir, device=device ) + if create_texture_atlas: + # Using the images and properties from the + # material file make a per face texture map. 
+ + # Create an array of strings of material names for each face. + # If faces_materials_idx == -1 then that face doesn't have a material. + idx = faces_materials_idx.cpu().numpy() + face_material_names = np.array(material_names)[idx] # (F,) + face_material_names[idx == -1] = "" + + # Get the uv coords for each vert in each face + faces_verts_uvs = verts_uvs[faces_textures_idx] # (F, 3, 2) + + # Construct the atlas. + texture_atlas = make_mesh_texture_atlas( + material_colors, + texture_images, + face_material_names, + faces_verts_uvs, + texture_atlas_size, + texture_wrap, + ) else: warnings.warn(f"Mtl file does not exist: {f_mtl}") elif len(material_names) > 0: @@ -423,99 +497,11 @@ def _load(f_obj, data_dir, load_textures=True): verts_uvs=verts_uvs if len(verts_uvs) > 0 else None, material_colors=material_colors, texture_images=texture_images, + texture_atlas=texture_atlas, ) return verts, faces, aux -def load_mtl(f_mtl, material_names: List, data_dir: str): - """ - Load texture images and material reflectivity values for ambient, diffuse - and specular light (Ka, Kd, Ks, Ns). - - Args: - f_mtl: a file like object of the material information. - material_names: a list of the material names found in the .obj file. - data_dir: the directory where the material texture files are located. - - Returns: - material_colors: dict of properties for each material. If a material - does not have any properties it will have an emtpy dict. - { - material_name_1: { - "ambient_color": tensor of shape (1, 3), - "diffuse_color": tensor of shape (1, 3), - "specular_color": tensor of shape (1, 3), - "shininess": tensor of shape (1) - }, - material_name_2: {}, - ... - } - texture_images: dict of material names and texture images - { - material_name_1: (H, W, 3) image, - ... 
- } - """ - texture_files = {} - material_colors = {} - material_properties = {} - texture_images = {} - material_name = "" - - f_mtl, new_f = _open_file(f_mtl) - lines = [line.strip() for line in f_mtl] - for line in lines: - if len(line.split()) != 0: - if line.split()[0] == "newmtl": - material_name = line.split()[1] - material_colors[material_name] = {} - if line.split()[0] == "map_Kd": - # Texture map. - texture_files[material_name] = line.split()[1] - if line.split()[0] == "Kd": - # RGB diffuse reflectivity - kd = np.array(list(line.split()[1:4])).astype(np.float32) - kd = torch.from_numpy(kd) - material_colors[material_name]["diffuse_color"] = kd - if line.split()[0] == "Ka": - # RGB ambient reflectivity - ka = np.array(list(line.split()[1:4])).astype(np.float32) - ka = torch.from_numpy(ka) - material_colors[material_name]["ambient_color"] = ka - if line.split()[0] == "Ks": - # RGB specular reflectivity - ks = np.array(list(line.split()[1:4])).astype(np.float32) - ks = torch.from_numpy(ks) - material_colors[material_name]["specular_color"] = ks - if line.split()[0] == "Ns": - # Specular exponent - ns = np.array(list(line.split()[1:4])).astype(np.float32) - ns = torch.from_numpy(ns) - material_colors[material_name]["shininess"] = ns - - if new_f: - f_mtl.close() - - # Only keep the materials referenced in the obj. - for name in material_names: - if name in texture_files: - # Load the texture image. 
- filename = texture_files[name] - filename_texture = os.path.join(data_dir, filename) - if os.path.isfile(filename_texture): - image = _read_image(filename_texture, format="RGB") / 255.0 - image = torch.from_numpy(image) - texture_images[name] = image - else: - msg = f"Texture file does not exist: {filename_texture}" - warnings.warn(msg) - - if name in material_colors: - material_properties[name] = material_colors[name] - - return material_properties, texture_images - - def save_obj(f, verts, faces, decimal_places: Optional[int] = None): """ Save a mesh to an .obj file. diff --git a/pytorch3d/io/utils.py b/pytorch3d/io/utils.py new file mode 100644 index 00000000..969f3005 --- /dev/null +++ b/pytorch3d/io/utils.py @@ -0,0 +1,41 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +import pathlib + +import numpy as np +from fvcore.common.file_io import PathManager +from PIL import Image + + +def _open_file(f): + new_f = False + if isinstance(f, str): + new_f = True + f = open(f, "r") + elif isinstance(f, pathlib.Path): + new_f = True + f = f.open("r") + return f, new_f + + +def _read_image(file_name: str, format=None): + """ + Read an image from a file using Pillow. + Args: + file_name: image file path. + format: one of ["RGB", "BGR"] + Returns: + image: an image of shape (H, W, C). + """ + if format not in ["RGB", "BGR"]: + raise ValueError("format can only be one of [RGB, BGR]; got %s", format) + with PathManager.open(file_name, "rb") as f: + image = Image.open(f) + if format is not None: + # PIL only supports RGB. First convert to RGB and flip channels + # below for BGR. + image = image.convert("RGB") + image = np.asarray(image).astype(np.float32) + if format == "BGR": + image = image[:, :, ::-1] + return image diff --git a/tests/bm_mesh_io.py b/tests/bm_mesh_io.py index a4f9b5ab..15719813 100644 --- a/tests/bm_mesh_io.py +++ b/tests/bm_mesh_io.py @@ -1,5 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+from itertools import product + from fvcore.common.benchmark import benchmark from test_obj_io import TestMeshObjIO from test_ply_io import TestMeshPlyIO @@ -61,3 +63,35 @@ def bm_save_load() -> None: complex_kwargs_list, warmup_iters=1, ) + + # Texture loading benchmarks + kwargs_list = [{"R": 2}, {"R": 4}, {"R": 10}, {"R": 15}, {"R": 20}] + benchmark( + TestMeshObjIO.bm_load_texture_atlas, + "PYTORCH3D_TEXTURE_ATLAS", + kwargs_list, + warmup_iters=1, + ) + + kwargs_list = [] + S = [64, 256, 1024] + F = [100, 1000, 10000] + R = [5, 10, 20] + test_cases = product(S, F, R) + + for case in test_cases: + s, f, r = case + kwargs_list.append({"S": s, "F": f, "R": r}) + + benchmark( + TestMeshObjIO.bm_bilinear_sampling_vectorized, + "BILINEAR_VECTORIZED", + kwargs_list, + warmup_iters=1, + ) + benchmark( + TestMeshObjIO.bm_bilinear_sampling_grid_sample, + "BILINEAR_GRID_SAMPLE", + kwargs_list, + warmup_iters=1, + ) diff --git a/tests/test_obj_io.py b/tests/test_obj_io.py index c66286f6..1084c7eb 100644 --- a/tests/test_obj_io.py +++ b/tests/test_obj_io.py @@ -2,12 +2,17 @@ import os import unittest +import warnings from io import StringIO from pathlib import Path import torch from common_testing import TestCaseMixin from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj +from pytorch3d.io.mtl_io import ( + _bilinear_interpolation_grid_sample, + _bilinear_interpolation_vectorized, +) from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch from pytorch3d.utils import torus @@ -47,8 +52,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): ) self.assertTrue(torch.all(verts == expected_verts)) self.assertTrue(torch.all(faces.verts_idx == expected_faces)) - self.assertTrue(faces.normals_idx == []) - self.assertTrue(faces.textures_idx == []) + padded_vals = -torch.ones_like(faces.verts_idx) + self.assertTrue(torch.all(faces.normals_idx == padded_vals)) + self.assertTrue(torch.all(faces.textures_idx == padded_vals)) self.assertTrue( 
torch.all(faces.materials_idx == -torch.ones(len(expected_faces))) ) @@ -118,8 +124,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): [[0.749279, 0.501284], [0.999110, 0.501077], [0.999455, 0.750380]], dtype=torch.float32, ) - expected_faces_normals_idx = torch.tensor([[1, 1, 1]], dtype=torch.int64) - expected_faces_textures_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64) + expected_faces_normals_idx = -torch.ones_like(expected_faces, dtype=torch.int64) + expected_faces_normals_idx[4, :] = torch.tensor([1, 1, 1], dtype=torch.int64) + expected_faces_textures_idx = -torch.ones_like( + expected_faces, dtype=torch.int64 + ) + expected_faces_textures_idx[4, :] = torch.tensor([0, 0, 1], dtype=torch.int64) self.assertTrue(torch.all(verts == expected_verts)) self.assertTrue(torch.all(faces.verts_idx == expected_faces)) @@ -160,7 +170,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase): self.assertClose(faces.normals_idx, expected_faces_normals_idx) self.assertClose(normals, expected_normals) self.assertClose(verts, expected_verts) - self.assertTrue(faces.textures_idx == []) + # Textures idx padded with -1. 
def test_load_mtl_texture_atlas_compare_softras(self):
    """
    Check the texture atlas returned by `load_obj` for the cow mesh against
    a reference atlas created with SoftRas.

    The reference file is not checked in; when it is missing the test warns
    and is reported as skipped.
    """
    # Load saved texture atlas created with SoftRas.
    device = torch.device("cuda:0")
    DATA_DIR = Path(__file__).resolve().parent.parent
    obj_filename = DATA_DIR / "docs/tutorials/data/cow_mesh/cow.obj"
    expected_atlas_fname = DATA_DIR / "tests/data/cow_texture_atlas_softras.pt"

    # Note, the reference texture atlas generated using SoftRas load_obj function
    # is too large to check in to the repo. Download the file to run the test locally.
    if not os.path.exists(expected_atlas_fname):
        url = "https://dl.fbaipublicfiles.com/pytorch3d/data/tests/cow_texture_atlas_softras.pt"
        msg = (
            "cow_texture_atlas_softras.pt not found, download from %s, save it at the path %s, and rerun"
            % (url, expected_atlas_fname)
        )
        warnings.warn(msg)
        # Report "skipped" instead of returning a value from the test, which
        # would be counted as a pass even though nothing was compared.
        self.skipTest(msg)

    # NOTE: torch.load unpickles the file; only use the reference file
    # obtained from the URL above. map_location keeps the reference on the
    # same device as the atlas it is compared with.
    expected_atlas = torch.load(expected_atlas_fname, map_location=device)
    _, _, aux = load_obj(
        obj_filename,
        load_textures=True,
        device=device,
        create_texture_atlas=True,
        texture_atlas_size=15,
        texture_wrap="repeat",
    )

    self.assertClose(expected_atlas, aux.texture_atlas, atol=5e-5)


@staticmethod
def bm_load_texture_atlas(R: int):
    """
    Build a benchmark closure which loads the cow mesh with a texture atlas.

    Args:
        R: value passed to `load_obj` as `texture_atlas_size`.

    Returns:
        A zero-argument callable which runs `load_obj` and then synchronizes
        CUDA.
    """
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    # Locate the data relative to this file (as the atlas test above does)
    # rather than through an absolute path into one user's checkout.
    DATA_DIR = Path(__file__).resolve().parent.parent / "docs"
    obj_filename = DATA_DIR / "tutorials/data/cow_mesh/cow.obj"
    torch.cuda.synchronize()

    def load():
        load_obj(
            obj_filename,
            load_textures=True,
            device=device,
            create_texture_atlas=True,
            texture_atlas_size=R,
        )
        torch.cuda.synchronize()

    return load


@staticmethod
def bm_bilinear_sampling_vectorized(S: int, F: int, R: int):
    """
    Build a benchmark closure for `_bilinear_interpolation_vectorized`.

    Args:
        S: side length of the random (S, S, 3) image.
        F, R: sizes of the random (F, R, R, 2) sampling grid.

    Returns:
        A zero-argument callable which runs the interpolation and then
        synchronizes CUDA.
    """
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    # NOTE(review): both tensors are created without a device argument, so
    # they live on the CPU even though a CUDA device is selected - confirm
    # whether the benchmark is meant to run on the GPU.
    image = torch.rand((S, S, 3))
    grid = torch.rand((F, R, R, 2))
    torch.cuda.synchronize()

    def load():
        _bilinear_interpolation_vectorized(image, grid)
        torch.cuda.synchronize()

    return load


@staticmethod
def bm_bilinear_sampling_grid_sample(S: int, F: int, R: int):
    """
    Build a benchmark closure for `_bilinear_interpolation_grid_sample`.

    Args:
        S: side length of the random (S, S, 3) image.
        F, R: sizes of the random (F, R, R, 2) sampling grid.

    Returns:
        A zero-argument callable which runs the interpolation and then
        synchronizes CUDA.
    """
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
    # NOTE(review): both tensors are created without a device argument, so
    # they live on the CPU even though a CUDA device is selected - confirm
    # whether the benchmark is meant to run on the GPU.
    image = torch.rand((S, S, 3))
    grid = torch.rand((F, R, R, 2))
    torch.cuda.synchronize()

    def load():
        _bilinear_interpolation_grid_sample(image, grid)
        torch.cuda.synchronize()

    return load