mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Update load obj and compare with SoftRas
Summary: Updated the load obj function to support creating of a per face texture map using the information in an .mtl file. Uses the approach from in SoftRasterizer. Currently I have ported in the SoftRasterizer code but this is only to help with comparison and will be deleted before landing. The ShapeNet Test data will also be deleted. Here is the [Design doc](https://docs.google.com/document/d/1AUcLP4QwVSqlfLAUfbjM9ic5vYn9P54Ha8QbcVXW2eI/edit?usp=sharing). ## Added - texture atlas creation functions in PyTorch based on the SoftRas cuda implementation - tests to compare SoftRas vs PyTorch3D implementation to verify it matches (using real shapenet data with meshes consisting of multiple textures) - benchmarks tests ## Remaining todo: - add more tests for obj io to test the new functions and the two texturing options - replace the shapenet data with the output from SoftRas saved as a file. # MAIN FILES TO REVIEW - `obj_io.py` - `test_obj_io.py` [still some tests to be added but have comparisons with SoftRas for now] The reference SoftRas implementations are in `softras_load_obj.py` and `load_textures.cu`. Reviewed By: gkioxari Differential Revision: D20754859 fbshipit-source-id: 42ace9dfb73f26e29d800c763f56d5b66c60c5e2
This commit is contained in:
parent
85c396f822
commit
c9267ab7af
462
pytorch3d/io/mtl_io.py
Normal file
462
pytorch3d/io/mtl_io.py
Normal file
@ -0,0 +1,462 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
"""This module implements utility functions for loading .mtl files and textures."""
|
||||
import os
|
||||
import warnings
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from pytorch3d.io.utils import _open_file, _read_image
|
||||
|
||||
|
||||
def make_mesh_texture_atlas(
|
||||
material_properties: Dict,
|
||||
texture_images: Dict,
|
||||
face_material_names,
|
||||
faces_verts_uvs: torch.Tensor,
|
||||
texture_size: int,
|
||||
texture_wrap: Optional[str],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Given properties for materials defined in the .mtl file, and the face texture uv
|
||||
coordinates, construct an (F, R, R, 3) texture atlas where R is the texture_size
|
||||
and F is the number of faces in the mesh.
|
||||
|
||||
Args:
|
||||
material_properties: dict of properties for each material. If a material
|
||||
does not have any properties it will have an emtpy dict.
|
||||
texture_images: dict of material names and texture images
|
||||
face_material_names: numpy array of the material name corresponding to each
|
||||
face. Faces which don't have an associated material will be an empty string.
|
||||
For these faces, a uniform white texture is assigned.
|
||||
faces_verts_uvs: LongTensor of shape (F, 3, 2) giving the uv coordinates for each
|
||||
vertex in the face.
|
||||
texture_size: the resolution of the per face texture map returned by this function.
|
||||
Each face will have a texture map of shape (texture_size, texture_size, 3).
|
||||
texture_wrap: string, one of ["repeat", "clamp", None]
|
||||
If `texture_wrap="repeat"` for uv values outside the range [0, 1] the integer part
|
||||
is ignored and a repeating pattern is formed.
|
||||
If `texture_wrap="clamp"` the values are clamped to the range [0, 1].
|
||||
If None, do nothing.
|
||||
|
||||
Returns:
|
||||
atlas: FloatTensor of shape (F, texture_size, texture_size, 3) giving the per
|
||||
face texture map.
|
||||
"""
|
||||
# Create an R x R texture map per face in the mesh
|
||||
R = texture_size
|
||||
F = faces_verts_uvs.shape[0]
|
||||
|
||||
# Initialize the per face texture map to a white color.
|
||||
# TODO: allow customization of this base color?
|
||||
atlas = faces_verts_uvs.new_ones(size=(F, R, R, 3))
|
||||
|
||||
# Check for empty materials.
|
||||
if not material_properties and not texture_images:
|
||||
return atlas
|
||||
|
||||
if texture_wrap == "repeat":
|
||||
# If texture uv coordinates are outside the range [0, 1] follow
|
||||
# the convention GL_REPEAT in OpenGL i.e the integer part of the coordinate
|
||||
# will be ignored and a repeating pattern is formed.
|
||||
# Shapenet data uses this format see:
|
||||
# https://shapenet.org/qaforum/index.php?qa=15&qa_1=why-is-the-texture-coordinate-in-the-obj-file-not-in-the-range
|
||||
if (faces_verts_uvs > 1).any() or (faces_verts_uvs < 0).any():
|
||||
msg = "Texture UV coordinates outside the range [0, 1]. \
|
||||
The integer part will be ignored to form a repeating pattern."
|
||||
warnings.warn(msg)
|
||||
faces_verts_uvs = faces_verts_uvs % 1
|
||||
elif texture_wrap == "clamp":
|
||||
# Clamp uv coordinates to the [0, 1] range.
|
||||
faces_verts_uvs = faces_verts_uvs.clamp(0.0, 1.0)
|
||||
|
||||
# Iterate through the material properties - not
|
||||
# all materials have texture images so this has to be
|
||||
# done separately to the texture interpolation.
|
||||
for material_name, props in material_properties.items():
|
||||
# Bool to indicate which faces use this texture map.
|
||||
faces_material_ind = torch.from_numpy(face_material_names == material_name).to(
|
||||
faces_verts_uvs.device
|
||||
)
|
||||
if (faces_material_ind).sum() > 0:
|
||||
# For these faces, update the base color to the
|
||||
# diffuse material color.
|
||||
if "diffuse_color" not in props:
|
||||
continue
|
||||
atlas[faces_material_ind, ...] = props["diffuse_color"][None, :]
|
||||
|
||||
# Iterate through the materials used in this mesh. Update the
|
||||
# texture atlas for the faces which use this material.
|
||||
# Faces without texture are white.
|
||||
for material_name, image in list(texture_images.items()):
|
||||
# Only use the RGB colors
|
||||
if image.shape[2] == 4:
|
||||
image = image[:, :, :3]
|
||||
|
||||
# Reverse the image y direction
|
||||
image = torch.flip(image, [0]).type_as(faces_verts_uvs)
|
||||
|
||||
# Bool to indicate which faces use this texture map.
|
||||
faces_material_ind = torch.from_numpy(face_material_names == material_name).to(
|
||||
faces_verts_uvs.device
|
||||
)
|
||||
|
||||
# Find the subset of faces which use this texture with this texture image
|
||||
uvs_subset = faces_verts_uvs[faces_material_ind, :, :]
|
||||
|
||||
# Update the texture atlas for the faces which use this texture.
|
||||
# TODO: should the texture map values be multiplied
|
||||
# by the diffuse material color (i.e. use *= as the atlas has
|
||||
# been initialized to the diffuse color)?. This is
|
||||
# not being done in SoftRas.
|
||||
atlas[faces_material_ind, :, :] = make_material_atlas(image, uvs_subset, R)
|
||||
|
||||
return atlas
|
||||
|
||||
|
||||
def make_material_atlas(
|
||||
image: torch.Tensor, faces_verts_uvs: torch.Tensor, texture_size: int
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Given a single texture image and the uv coordinates for all the
|
||||
face vertices, create a square texture map per face using
|
||||
the formulation from [1].
|
||||
|
||||
For a triangle with vertices (v0, v1, v2) we can create a barycentric coordinate system
|
||||
with the x axis being the vector (v1 - v0) and the y axis being the vector (v2 - v0).
|
||||
The barycentric coordinates range from [0, 1] in the +x and +y direction so this creates
|
||||
a triangular texture space with vertices at (0, 1), (0, 0) and (1, 0).
|
||||
|
||||
The per face texture map is of shape (texture_size, texture_size, 3)
|
||||
which is a square. To map a triangular texture to a square grid, each
|
||||
triangle is parametrized as follows (e.g. R = texture_size = 3):
|
||||
|
||||
The triangle texture is first divided into RxR = 9 subtriangles which each
|
||||
map to one grid cell. The numbers in the grid cells and triangles show the mapping.
|
||||
|
||||
..code-block::python
|
||||
|
||||
Triangular Texture Space:
|
||||
|
||||
1
|
||||
|\
|
||||
|6 \
|
||||
|____\
|
||||
|\ 7 |\
|
||||
|3 \ |4 \
|
||||
|____\|____\
|
||||
|\ 8 |\ 5 |\
|
||||
|0 \ |1 \ |2 \
|
||||
|____\|____\|____\
|
||||
0 1
|
||||
|
||||
Square per face texture map:
|
||||
|
||||
R ____________________
|
||||
| | | |
|
||||
| 6 | 7 | 8 |
|
||||
|______|______|______|
|
||||
| | | |
|
||||
| 3 | 4 | 5 |
|
||||
|______|______|______|
|
||||
| | | |
|
||||
| 0 | 1 | 2 |
|
||||
|______|______|______|
|
||||
0 R
|
||||
|
||||
|
||||
The barycentric coordinates of each grid cell are calculated using the
|
||||
xy coordinates:
|
||||
|
||||
..code-block::python
|
||||
|
||||
The cartesian coordinates are:
|
||||
|
||||
Grid 1:
|
||||
|
||||
R ____________________
|
||||
| | | |
|
||||
| 20 | 21 | 22 |
|
||||
|______|______|______|
|
||||
| | | |
|
||||
| 10 | 11 | 12 |
|
||||
|______|______|______|
|
||||
| | | |
|
||||
| 00 | 01 | 02 |
|
||||
|______|______|______|
|
||||
0 R
|
||||
|
||||
where 02 means y = 0, x = 2
|
||||
|
||||
Now consider this subset of the triangle which corresponds to
|
||||
grid cells 0 and 8:
|
||||
|
||||
..code-block::python
|
||||
|
||||
1/R ________
|
||||
|\ 8 |
|
||||
| \ |
|
||||
| 0 \ |
|
||||
|_______\|
|
||||
0 1/R
|
||||
|
||||
The centroids of the triangles are:
|
||||
0: (1/3, 1/3) * 1/R
|
||||
8: (2/3, 2/3) * 1/R
|
||||
|
||||
For each grid cell we can now calculate the centroid `(c_y, c_x)`
|
||||
of the corresponding texture triangle:
|
||||
- if `(x + y) < R`, then offsett the centroid of
|
||||
triangle 0 by `(y, x) * (1/R)`
|
||||
- if `(x + y) > R`, then offset the centroid of
|
||||
triangle 8 by `((R-1-y), (R-1-x)) * (1/R)`.
|
||||
|
||||
This is equivalent to updating the portion of Grid 1
|
||||
above the diagnonal, replacing `(y, x)` with `((R-1-y), (R-1-x))`:
|
||||
|
||||
..code-block::python
|
||||
|
||||
R _____________________
|
||||
| | | |
|
||||
| 20 | 01 | 00 |
|
||||
|______|______|______|
|
||||
| | | |
|
||||
| 10 | 11 | 10 |
|
||||
|______|______|______|
|
||||
| | | |
|
||||
| 00 | 01 | 02 |
|
||||
|______|______|______|
|
||||
0 R
|
||||
|
||||
The barycentric coordinates (w0, w1, w2) are then given by:
|
||||
|
||||
..code-block::python
|
||||
|
||||
w0 = c_x
|
||||
w1 = c_y
|
||||
w2 = 1- w0 - w1
|
||||
|
||||
Args:
|
||||
image: FloatTensor of shape (H, W, 3)
|
||||
faces_verts_uvs: uv coordinates for each vertex in each face (F, 3, 2)
|
||||
texture_size: int
|
||||
|
||||
Returns:
|
||||
atlas: a FloatTensor of shape (F, texture_size, texture_size, 3) giving a
|
||||
per face texture map.
|
||||
|
||||
[1] Liu et al, 'Soft Rasterizer: A Differentiable Renderer for Image-based
|
||||
3D Reasoning', ICCV 2019
|
||||
"""
|
||||
R = texture_size
|
||||
device = faces_verts_uvs.device
|
||||
rng = torch.arange(R, device=device)
|
||||
|
||||
# Meshgrid returns (row, column) i.e (Y, X)
|
||||
# Change order to (X, Y) to make the grid.
|
||||
Y, X = torch.meshgrid(rng, rng)
|
||||
grid = torch.stack([X, Y], axis=-1) # (R, R, 2)
|
||||
|
||||
# Grid cells below the diagonal: x + y < R.
|
||||
below_diag = grid.sum(-1) < R
|
||||
|
||||
# map a [0, R] grid -> to a [0, 1] barycentric coordinates of
|
||||
# the texture triangle centroids.
|
||||
bary = torch.zeros((R, R, 3), device=device) # (R, R, 3)
|
||||
slc = torch.arange(2, device=device)[:, None]
|
||||
# w0, w1
|
||||
bary[below_diag, slc] = ((grid[below_diag] + 1.0 / 3.0) / R).T
|
||||
# w0, w1 for above diagonal grid cells.
|
||||
bary[~below_diag, slc] = (((R - 1.0 - grid[~below_diag]) + 2.0 / 3.0) / R).T
|
||||
# w2 = 1. - w0 - w1
|
||||
bary[..., -1] = 1 - bary[..., :2].sum(dim=-1)
|
||||
|
||||
# Calculate the uv position in the image for each pixel
|
||||
# in the per face texture map
|
||||
# (F, 1, 1, 3, 2) * (R, R, 3, 1) -> (F, R, R, 3, 2) -> (F, R, R, 2)
|
||||
uv_pos = (faces_verts_uvs[:, None, None] * bary[..., None]).sum(-2)
|
||||
|
||||
# bi-linearly interpolate the textures from the images
|
||||
# using the uv coordinates given by uv_pos.
|
||||
textures = _bilinear_interpolation_vectorized(image, uv_pos)
|
||||
|
||||
return textures
|
||||
|
||||
|
||||
def _bilinear_interpolation_vectorized(
|
||||
image: torch.Tensor, grid: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Bi linearly interpolate the image using the uv positions in the flow-field
|
||||
grid (following the naming conventions for torch.nn.functional.grid_sample).
|
||||
|
||||
This implementation uses the same steps as in the SoftRas cuda kernel
|
||||
to make it easy to compare. This vectorized version requires less memory than
|
||||
_bilinear_interpolation_grid_sample but is slightly slower.
|
||||
If speed is an issue and the number of faces in the mesh and texture image sizes
|
||||
are small, consider using _bilinear_interpolation_grid_sample instead.
|
||||
|
||||
Args:
|
||||
image: FloatTensor of shape (H, W, D) a single image/input tensor with D
|
||||
channels.
|
||||
grid: FloatTensor of shape (N, R, R, 2) giving the pixel locations of the
|
||||
points at which to sample a value in the image. The grid values must
|
||||
be in the range [0, 1]. u is the x direction and v is the y direction.
|
||||
|
||||
Returns:
|
||||
out: FloatTensor of shape (N, H, W, D) giving the interpolated
|
||||
D dimensional value from image at each of the pixel locations in grid.
|
||||
|
||||
"""
|
||||
H, W, _ = image.shape
|
||||
# Convert [0, 1] to the range [0, W-1] and [0, H-1]
|
||||
grid = grid * torch.tensor([W - 1, H - 1]).type_as(grid)
|
||||
weight_1 = grid - grid.int()
|
||||
weight_0 = 1.0 - weight_1
|
||||
|
||||
grid_x, grid_y = grid.unbind(-1)
|
||||
y0 = grid_y.to(torch.int64)
|
||||
y1 = (grid_y + 1).to(torch.int64)
|
||||
x0 = grid_x.to(torch.int64)
|
||||
x1 = x0 + 1
|
||||
|
||||
weight_x0, weight_y0 = weight_0.unbind(-1)
|
||||
weight_x1, weight_y1 = weight_1.unbind(-1)
|
||||
|
||||
# Bi-linear interpolation
|
||||
# griditions = [[y, x], [(y+1), x]
|
||||
# [y, (x+1)], [(y+1), (x+1)]]
|
||||
# weights = [[wx0*wy0, wx0*wy1],
|
||||
# [wx1*wy0, wx1*wy1]]
|
||||
out = (
|
||||
image[y0, x0] * (weight_x0 * weight_y0)[..., None]
|
||||
+ image[y1, x0] * (weight_x0 * weight_y1)[..., None]
|
||||
+ image[y0, x1] * (weight_x1 * weight_y0)[..., None]
|
||||
+ image[y1, x1] * (weight_x1 * weight_y1)[..., None]
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def _bilinear_interpolation_grid_sample(
|
||||
image: torch.Tensor, grid: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Bi linearly interpolate the image using the uv positions in the flow-field
|
||||
grid (following the conventions for torch.nn.functional.grid_sample).
|
||||
|
||||
This implementation is faster than _bilinear_interpolation_vectorized but
|
||||
requires more memory so can cause OOMs. If speed is an issue try this function
|
||||
instead.
|
||||
|
||||
Args:
|
||||
image: FloatTensor of shape (H, W, D) a single image/input tensor with D
|
||||
channels.
|
||||
grid: FloatTensor of shape (N, R, R, 2) giving the pixel locations of the
|
||||
points at which to sample a value in the image. The grid values must
|
||||
be in the range [0, 1]. u is the x direction and v is the y direction.
|
||||
|
||||
Returns:
|
||||
out: FloatTensor of shape (N, H, W, D) giving the interpolated
|
||||
D dimensional value from image at each of the pixel locations in grid.
|
||||
"""
|
||||
|
||||
N = grid.shape[0]
|
||||
# convert [0, 1] to the range [-1, 1] expected by grid_sample.
|
||||
grid = grid * 2.0 - 1.0
|
||||
image = image.permute(2, 0, 1)[None, ...].expand(N, -1, -1, -1) # (N, 3, H, W)
|
||||
# Align_corners has to be set to True to match the output of the SoftRas
|
||||
# cuda kernel for bilinear sampling.
|
||||
out = F.grid_sample(image, grid, mode="bilinear", align_corners=True)
|
||||
return out.permute(0, 2, 3, 1)
|
||||
|
||||
|
||||
def load_mtl(f_mtl, material_names: List, data_dir: str, device="cpu"):
|
||||
"""
|
||||
Load texture images and material reflectivity values for ambient, diffuse
|
||||
and specular light (Ka, Kd, Ks, Ns).
|
||||
|
||||
Args:
|
||||
f_mtl: a file like object of the material information.
|
||||
material_names: a list of the material names found in the .obj file.
|
||||
data_dir: the directory where the material texture files are located.
|
||||
|
||||
Returns:
|
||||
material_colors: dict of properties for each material. If a material
|
||||
does not have any properties it will have an emtpy dict.
|
||||
{
|
||||
material_name_1: {
|
||||
"ambient_color": tensor of shape (1, 3),
|
||||
"diffuse_color": tensor of shape (1, 3),
|
||||
"specular_color": tensor of shape (1, 3),
|
||||
"shininess": tensor of shape (1)
|
||||
},
|
||||
material_name_2: {},
|
||||
...
|
||||
}
|
||||
texture_images: dict of material names and texture images
|
||||
{
|
||||
material_name_1: (H, W, 3) image,
|
||||
...
|
||||
}
|
||||
"""
|
||||
texture_files = {}
|
||||
material_colors = {}
|
||||
material_properties = {}
|
||||
texture_images = {}
|
||||
material_name = ""
|
||||
|
||||
f_mtl, new_f = _open_file(f_mtl)
|
||||
lines = [line.strip() for line in f_mtl]
|
||||
for line in lines:
|
||||
if len(line.split()) != 0:
|
||||
if line.split()[0] == "newmtl":
|
||||
material_name = line.split()[1]
|
||||
material_colors[material_name] = {}
|
||||
if line.split()[0] == "map_Kd":
|
||||
# Texture map.
|
||||
texture_files[material_name] = line.split()[1]
|
||||
if line.split()[0] == "Kd":
|
||||
# RGB diffuse reflectivity
|
||||
kd = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
kd = torch.from_numpy(kd).to(device)
|
||||
material_colors[material_name]["diffuse_color"] = kd
|
||||
if line.split()[0] == "Ka":
|
||||
# RGB ambient reflectivity
|
||||
ka = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
ka = torch.from_numpy(ka).to(device)
|
||||
material_colors[material_name]["ambient_color"] = ka
|
||||
if line.split()[0] == "Ks":
|
||||
# RGB specular reflectivity
|
||||
ks = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
ks = torch.from_numpy(ks).to(device)
|
||||
material_colors[material_name]["specular_color"] = ks
|
||||
if line.split()[0] == "Ns":
|
||||
# Specular exponent
|
||||
ns = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
ns = torch.from_numpy(ns).to(device)
|
||||
material_colors[material_name]["shininess"] = ns
|
||||
|
||||
if new_f:
|
||||
f_mtl.close()
|
||||
|
||||
# Only keep the materials referenced in the obj.
|
||||
for name in material_names:
|
||||
if name in texture_files:
|
||||
# Load the texture image.
|
||||
filename = texture_files[name]
|
||||
filename_texture = os.path.join(data_dir, filename)
|
||||
if os.path.isfile(filename_texture):
|
||||
image = _read_image(filename_texture, format="RGB") / 255.0
|
||||
image = torch.from_numpy(image)
|
||||
texture_images[name] = image
|
||||
else:
|
||||
msg = f"Texture file does not exist: {filename_texture}"
|
||||
warnings.warn(msg)
|
||||
|
||||
if name in material_colors:
|
||||
material_properties[name] = material_colors[name]
|
||||
|
||||
return material_properties, texture_images
|
@ -6,55 +6,34 @@ import os
|
||||
import pathlib
|
||||
import warnings
|
||||
from collections import namedtuple
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fvcore.common.file_io import PathManager
|
||||
from PIL import Image
|
||||
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
|
||||
from pytorch3d.io.utils import _open_file
|
||||
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
|
||||
|
||||
|
||||
def _make_tensor(data, cols: int, dtype: torch.dtype) -> torch.Tensor:
|
||||
def _make_tensor(data, cols: int, dtype: torch.dtype, device="cpu") -> torch.Tensor:
|
||||
"""
|
||||
Return a 2D tensor with the specified cols and dtype filled with data,
|
||||
even when data is empty.
|
||||
"""
|
||||
if not data:
|
||||
return torch.zeros((0, cols), dtype=dtype)
|
||||
return torch.zeros((0, cols), dtype=dtype, device=device)
|
||||
|
||||
return torch.tensor(data, dtype=dtype)
|
||||
|
||||
|
||||
def _read_image(file_name: str, format=None):
|
||||
"""
|
||||
Read an image from a file using Pillow.
|
||||
Args:
|
||||
file_name: image file path.
|
||||
format: one of ["RGB", "BGR"]
|
||||
Returns:
|
||||
image: an image of shape (H, W, C).
|
||||
"""
|
||||
if format not in ["RGB", "BGR"]:
|
||||
raise ValueError("format can only be one of [RGB, BGR]; got %s", format)
|
||||
with PathManager.open(file_name, "rb") as f:
|
||||
image = Image.open(f)
|
||||
if format is not None:
|
||||
# PIL only supports RGB. First convert to RGB and flip channels
|
||||
# below for BGR.
|
||||
image = image.convert("RGB")
|
||||
image = np.asarray(image).astype(np.float32)
|
||||
if format == "BGR":
|
||||
image = image[:, :, ::-1]
|
||||
return image
|
||||
return torch.tensor(data, dtype=dtype, device=device)
|
||||
|
||||
|
||||
# Faces & Aux type returned from load_obj function.
|
||||
_Faces = namedtuple("Faces", "verts_idx normals_idx textures_idx materials_idx")
|
||||
_Aux = namedtuple("Properties", "normals verts_uvs material_colors texture_images")
|
||||
_Aux = namedtuple(
|
||||
"Properties", "normals verts_uvs material_colors texture_images texture_atlas"
|
||||
)
|
||||
|
||||
|
||||
def _format_faces_indices(faces_indices, max_index):
|
||||
def _format_faces_indices(faces_indices, max_index, device, pad_value=None):
|
||||
"""
|
||||
Format indices and check for invalid values. Indices can refer to
|
||||
values in one of the face properties: vertices, textures or normals.
|
||||
@ -70,7 +49,12 @@ def _format_faces_indices(faces_indices, max_index):
|
||||
Raises:
|
||||
ValueError if indices are not in a valid range.
|
||||
"""
|
||||
faces_indices = _make_tensor(faces_indices, cols=3, dtype=torch.int64)
|
||||
faces_indices = _make_tensor(
|
||||
faces_indices, cols=3, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
if pad_value:
|
||||
mask = faces_indices.eq(pad_value).all(-1)
|
||||
|
||||
# Change to 0 based indexing.
|
||||
faces_indices[(faces_indices > 0)] -= 1
|
||||
@ -78,6 +62,9 @@ def _format_faces_indices(faces_indices, max_index):
|
||||
# Negative indexing counts from the end.
|
||||
faces_indices[(faces_indices < 0)] += max_index
|
||||
|
||||
if pad_value:
|
||||
faces_indices[mask] = pad_value
|
||||
|
||||
# Check indices are valid.
|
||||
if torch.any(faces_indices >= max_index) or torch.any(faces_indices < 0):
|
||||
warnings.warn("Faces have invalid indices")
|
||||
@ -85,18 +72,14 @@ def _format_faces_indices(faces_indices, max_index):
|
||||
return faces_indices
|
||||
|
||||
|
||||
def _open_file(f):
|
||||
new_f = False
|
||||
if isinstance(f, str):
|
||||
new_f = True
|
||||
f = open(f, "r")
|
||||
elif isinstance(f, pathlib.Path):
|
||||
new_f = True
|
||||
f = f.open("r")
|
||||
return f, new_f
|
||||
|
||||
|
||||
def load_obj(f_obj, load_textures=True):
|
||||
def load_obj(
|
||||
f_obj,
|
||||
load_textures=True,
|
||||
create_texture_atlas: bool = False,
|
||||
texture_atlas_size: int = 4,
|
||||
texture_wrap: Optional[str] = "repeat",
|
||||
device="cpu",
|
||||
):
|
||||
"""
|
||||
Load a mesh from a .obj file and optionally textures from a .mtl file.
|
||||
Currently this handles verts, faces, vertex texture uv coordinates, normals,
|
||||
@ -155,6 +138,18 @@ def load_obj(f_obj, load_textures=True):
|
||||
f: A file-like object (with methods read, readline, tell, and seek),
|
||||
a pathlib path or a string containing a file name.
|
||||
load_textures: Boolean indicating whether material files are loaded
|
||||
create_texture_atlas: Bool, If True a per face texture map is created and
|
||||
a tensor `texture_atlas` is also returned in `aux`.
|
||||
texture_atlas_size: Int specifying the resolution of the texture map per face
|
||||
when `create_texture_atlas=True`. A (texture_size, texture_size, 3)
|
||||
map is created per face.
|
||||
texture_wrap: string, one of ["repeat", "clamp"]. This applies when computing
|
||||
the texture atlas.
|
||||
If `texture_mode="repeat"`, for uv values outside the range [0, 1] the integer part
|
||||
is ignored and a repeating pattern is formed.
|
||||
If `texture_mode="clamp"` the values are clamped to the range [0, 1].
|
||||
If None, then there is no transformation of the texture values.
|
||||
device: string or torch.device on which to return the new tensors.
|
||||
|
||||
Returns:
|
||||
6-element tuple containing
|
||||
@ -181,9 +176,8 @@ def load_obj(f_obj, load_textures=True):
|
||||
possible that the number of verts_uvs is greater than
|
||||
num verts i.e. T > V.
|
||||
vertex.
|
||||
- material_colors: dict of material names and associated properties.
|
||||
If a material does not have any properties it will have an
|
||||
empty dict.
|
||||
- material_colors: if `load_textures=True` and the material has associated
|
||||
properties this will be a dict of material names and properties of the form:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@ -197,20 +191,40 @@ def load_obj(f_obj, load_textures=True):
|
||||
material_name_2: {},
|
||||
...
|
||||
}
|
||||
- texture_images: dict of material names and texture images.
|
||||
|
||||
If a material does not have any properties it will have an
|
||||
empty dict. If `load_textures=False`, `material_colors` will None.
|
||||
|
||||
- texture_images: if `load_textures=True` and the material has a texture map,
|
||||
this will be a dict of the form:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
material_name_1: (H, W, 3) image,
|
||||
...
|
||||
}
|
||||
If `load_textures=False`, `texture_images` will None.
|
||||
- texture_atlas: if `load_textures=True` and `create_texture_atlas=True`,
|
||||
this will be a FloatTensor of the form: (F, texture_size, textures_size, 3)
|
||||
If the material does not have a texture map, then all faces
|
||||
will have a uniform white texture. Otherwise `texture_atlas` will be
|
||||
None.
|
||||
"""
|
||||
data_dir = "./"
|
||||
if isinstance(f_obj, (str, bytes, os.PathLike)):
|
||||
data_dir = os.path.dirname(f_obj)
|
||||
f_obj, new_f = _open_file(f_obj)
|
||||
try:
|
||||
return _load(f_obj, data_dir, load_textures=load_textures)
|
||||
return _load(
|
||||
f_obj,
|
||||
data_dir,
|
||||
load_textures=load_textures,
|
||||
create_texture_atlas=create_texture_atlas,
|
||||
texture_atlas_size=texture_atlas_size,
|
||||
texture_wrap=texture_wrap,
|
||||
device=device,
|
||||
)
|
||||
finally:
|
||||
if new_f:
|
||||
f_obj.close()
|
||||
@ -235,6 +249,7 @@ def load_objs_as_meshes(files: list, device=None, load_textures: bool = True):
|
||||
"""
|
||||
mesh_list = []
|
||||
for f_obj in files:
|
||||
# TODO: update this function to support the two texturing options.
|
||||
verts, faces, aux = load_obj(f_obj, load_textures=load_textures)
|
||||
verts = verts.to(device)
|
||||
tex = None
|
||||
@ -286,6 +301,10 @@ def _parse_face(
|
||||
# Triplets must be consistent for all vertices in a face e.g.
|
||||
# legal statement: f 4/1/1 3/2/1 2/1/1.
|
||||
# illegal statement: f 4/1/1 3//1 2//1.
|
||||
# If the face does not have normals or textures indices
|
||||
# fill with pad value = -1. This will ensure that
|
||||
# all the face index tensors will have F values where
|
||||
# F is the number of faces.
|
||||
if len(face_normals) > 0:
|
||||
if not (len(face_verts) == len(face_normals)):
|
||||
raise ValueError(
|
||||
@ -293,6 +312,8 @@ def _parse_face(
|
||||
Vertex properties are inconsistent. Line: %s"
|
||||
% (str(face), str(line))
|
||||
)
|
||||
else:
|
||||
face_normals = [-1] * len(face_verts) # Fill with -1
|
||||
if len(face_textures) > 0:
|
||||
if not (len(face_verts) == len(face_textures)):
|
||||
raise ValueError(
|
||||
@ -300,28 +321,41 @@ def _parse_face(
|
||||
Vertex properties are inconsistent. Line: %s"
|
||||
% (str(face), str(line))
|
||||
)
|
||||
else:
|
||||
face_textures = [-1] * len(face_verts) # Fill with -1
|
||||
|
||||
# Subdivide faces with more than 3 vertices. See comments of the
|
||||
# load_obj function for more details.
|
||||
# Subdivide faces with more than 3 vertices.
|
||||
# See comments of the load_obj function for more details.
|
||||
for i in range(len(face_verts) - 2):
|
||||
faces_verts_idx.append((face_verts[0], face_verts[i + 1], face_verts[i + 2]))
|
||||
if len(face_normals) > 0:
|
||||
faces_normals_idx.append(
|
||||
(face_normals[0], face_normals[i + 1], face_normals[i + 2])
|
||||
)
|
||||
if len(face_textures) > 0:
|
||||
faces_textures_idx.append(
|
||||
(face_textures[0], face_textures[i + 1], face_textures[i + 2])
|
||||
)
|
||||
faces_normals_idx.append(
|
||||
(face_normals[0], face_normals[i + 1], face_normals[i + 2])
|
||||
)
|
||||
faces_textures_idx.append(
|
||||
(face_textures[0], face_textures[i + 1], face_textures[i + 2])
|
||||
)
|
||||
faces_materials_idx.append(material_idx)
|
||||
|
||||
|
||||
def _load(f_obj, data_dir, load_textures=True):
|
||||
def _load(
|
||||
f_obj,
|
||||
data_dir,
|
||||
load_textures: bool = True,
|
||||
create_texture_atlas: bool = False,
|
||||
texture_atlas_size: int = 4,
|
||||
texture_wrap: Optional[str] = "repeat",
|
||||
device="cpu",
|
||||
):
|
||||
"""
|
||||
Load a mesh from a file-like object. See load_obj function more details.
|
||||
Any material files associated with the obj are expected to be in the
|
||||
directory given by data_dir.
|
||||
"""
|
||||
|
||||
if texture_wrap is not None and texture_wrap not in ["repeat", "clamp"]:
|
||||
msg = "texture_wrap must be one of ['repeat', 'clamp'] or None, got %s"
|
||||
raise ValueError(msg % texture_wrap)
|
||||
|
||||
lines = [line.strip() for line in f_obj]
|
||||
verts = []
|
||||
normals = []
|
||||
@ -343,12 +377,19 @@ def _load(f_obj, data_dir, load_textures=True):
|
||||
if line.startswith("mtllib"):
|
||||
if len(line.split()) < 2:
|
||||
raise ValueError("material file name is not specified")
|
||||
# NOTE: this assumes only one mtl file per .obj.
|
||||
# NOTE: only allow one .mtl file per .obj.
|
||||
# Definitions for multiple materials can be included
|
||||
# in this one .mtl file.
|
||||
f_mtl = os.path.join(data_dir, line.split()[1])
|
||||
elif len(line.split()) != 0 and line.split()[0] == "usemtl":
|
||||
material_name = line.split()[1]
|
||||
material_names.append(material_name)
|
||||
materials_idx = len(material_names) - 1
|
||||
# materials are often repeated for different parts
|
||||
# of a mesh.
|
||||
if material_name not in material_names:
|
||||
material_names.append(material_name)
|
||||
materials_idx = len(material_names) - 1
|
||||
else:
|
||||
materials_idx = material_names.index(material_name)
|
||||
elif line.startswith("v "):
|
||||
# Line is a vertex.
|
||||
vert = [float(x) for x in line.split()[1:4]]
|
||||
@ -372,7 +413,7 @@ def _load(f_obj, data_dir, load_textures=True):
|
||||
raise ValueError(msg % (str(norm), str(line)))
|
||||
normals.append(norm)
|
||||
elif line.startswith("f "):
|
||||
# Line is a face.
|
||||
# Line is a face update face properties info.
|
||||
_parse_face(
|
||||
line,
|
||||
materials_idx,
|
||||
@ -382,30 +423,63 @@ def _load(f_obj, data_dir, load_textures=True):
|
||||
faces_materials_idx,
|
||||
)
|
||||
|
||||
verts = _make_tensor(verts, cols=3, dtype=torch.float32) # (V, 3)
|
||||
normals = _make_tensor(normals, cols=3, dtype=torch.float32) # (N, 3)
|
||||
verts_uvs = _make_tensor(verts_uvs, cols=2, dtype=torch.float32) # (T, 2)
|
||||
verts = _make_tensor(verts, cols=3, dtype=torch.float32, device=device) # (V, 3)
|
||||
normals = _make_tensor(
|
||||
normals, cols=3, dtype=torch.float32, device=device
|
||||
) # (N, 3)
|
||||
verts_uvs = _make_tensor(
|
||||
verts_uvs, cols=2, dtype=torch.float32, device=device
|
||||
) # (T, 2)
|
||||
|
||||
faces_verts_idx = _format_faces_indices(faces_verts_idx, verts.shape[0])
|
||||
faces_verts_idx = _format_faces_indices(
|
||||
faces_verts_idx, verts.shape[0], device=device
|
||||
)
|
||||
|
||||
# Repeat for normals and textures if present.
|
||||
if len(faces_normals_idx) > 0:
|
||||
faces_normals_idx = _format_faces_indices(faces_normals_idx, normals.shape[0])
|
||||
faces_normals_idx = _format_faces_indices(
|
||||
faces_normals_idx, normals.shape[0], device=device, pad_value=-1
|
||||
)
|
||||
if len(faces_textures_idx) > 0:
|
||||
faces_textures_idx = _format_faces_indices(
|
||||
faces_textures_idx, verts_uvs.shape[0]
|
||||
faces_textures_idx, verts_uvs.shape[0], device=device, pad_value=-1
|
||||
)
|
||||
if len(faces_materials_idx) > 0:
|
||||
faces_materials_idx = torch.tensor(faces_materials_idx, dtype=torch.int64)
|
||||
faces_materials_idx = torch.tensor(
|
||||
faces_materials_idx, dtype=torch.int64, device=device
|
||||
)
|
||||
|
||||
# Load materials
|
||||
material_colors, texture_images = None, None
|
||||
material_colors, texture_images, texture_atlas = None, None, None
|
||||
if load_textures:
|
||||
if (len(material_names) > 0) and (f_mtl is not None):
|
||||
if os.path.isfile(f_mtl):
|
||||
# Texture mode uv wrap
|
||||
material_colors, texture_images = load_mtl(
|
||||
f_mtl, material_names, data_dir
|
||||
f_mtl, material_names, data_dir, device=device
|
||||
)
|
||||
if create_texture_atlas:
|
||||
# Using the images and properties from the
|
||||
# material file make a per face texture map.
|
||||
|
||||
# Create an array of strings of material names for each face.
|
||||
# If faces_materials_idx == -1 then that face doesn't have a material.
|
||||
idx = faces_materials_idx.cpu().numpy()
|
||||
face_material_names = np.array(material_names)[idx] # (F,)
|
||||
face_material_names[idx == -1] = ""
|
||||
|
||||
# Get the uv coords for each vert in each face
|
||||
faces_verts_uvs = verts_uvs[faces_textures_idx] # (F, 3, 2)
|
||||
|
||||
# Construct the atlas.
|
||||
texture_atlas = make_mesh_texture_atlas(
|
||||
material_colors,
|
||||
texture_images,
|
||||
face_material_names,
|
||||
faces_verts_uvs,
|
||||
texture_atlas_size,
|
||||
texture_wrap,
|
||||
)
|
||||
else:
|
||||
warnings.warn(f"Mtl file does not exist: {f_mtl}")
|
||||
elif len(material_names) > 0:
|
||||
@ -423,99 +497,11 @@ def _load(f_obj, data_dir, load_textures=True):
|
||||
verts_uvs=verts_uvs if len(verts_uvs) > 0 else None,
|
||||
material_colors=material_colors,
|
||||
texture_images=texture_images,
|
||||
texture_atlas=texture_atlas,
|
||||
)
|
||||
return verts, faces, aux
|
||||
|
||||
|
||||
def load_mtl(f_mtl, material_names: List, data_dir: str):
|
||||
"""
|
||||
Load texture images and material reflectivity values for ambient, diffuse
|
||||
and specular light (Ka, Kd, Ks, Ns).
|
||||
|
||||
Args:
|
||||
f_mtl: a file like object of the material information.
|
||||
material_names: a list of the material names found in the .obj file.
|
||||
data_dir: the directory where the material texture files are located.
|
||||
|
||||
Returns:
|
||||
material_colors: dict of properties for each material. If a material
|
||||
does not have any properties it will have an emtpy dict.
|
||||
{
|
||||
material_name_1: {
|
||||
"ambient_color": tensor of shape (1, 3),
|
||||
"diffuse_color": tensor of shape (1, 3),
|
||||
"specular_color": tensor of shape (1, 3),
|
||||
"shininess": tensor of shape (1)
|
||||
},
|
||||
material_name_2: {},
|
||||
...
|
||||
}
|
||||
texture_images: dict of material names and texture images
|
||||
{
|
||||
material_name_1: (H, W, 3) image,
|
||||
...
|
||||
}
|
||||
"""
|
||||
texture_files = {}
|
||||
material_colors = {}
|
||||
material_properties = {}
|
||||
texture_images = {}
|
||||
material_name = ""
|
||||
|
||||
f_mtl, new_f = _open_file(f_mtl)
|
||||
lines = [line.strip() for line in f_mtl]
|
||||
for line in lines:
|
||||
if len(line.split()) != 0:
|
||||
if line.split()[0] == "newmtl":
|
||||
material_name = line.split()[1]
|
||||
material_colors[material_name] = {}
|
||||
if line.split()[0] == "map_Kd":
|
||||
# Texture map.
|
||||
texture_files[material_name] = line.split()[1]
|
||||
if line.split()[0] == "Kd":
|
||||
# RGB diffuse reflectivity
|
||||
kd = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
kd = torch.from_numpy(kd)
|
||||
material_colors[material_name]["diffuse_color"] = kd
|
||||
if line.split()[0] == "Ka":
|
||||
# RGB ambient reflectivity
|
||||
ka = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
ka = torch.from_numpy(ka)
|
||||
material_colors[material_name]["ambient_color"] = ka
|
||||
if line.split()[0] == "Ks":
|
||||
# RGB specular reflectivity
|
||||
ks = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
ks = torch.from_numpy(ks)
|
||||
material_colors[material_name]["specular_color"] = ks
|
||||
if line.split()[0] == "Ns":
|
||||
# Specular exponent
|
||||
ns = np.array(list(line.split()[1:4])).astype(np.float32)
|
||||
ns = torch.from_numpy(ns)
|
||||
material_colors[material_name]["shininess"] = ns
|
||||
|
||||
if new_f:
|
||||
f_mtl.close()
|
||||
|
||||
# Only keep the materials referenced in the obj.
|
||||
for name in material_names:
|
||||
if name in texture_files:
|
||||
# Load the texture image.
|
||||
filename = texture_files[name]
|
||||
filename_texture = os.path.join(data_dir, filename)
|
||||
if os.path.isfile(filename_texture):
|
||||
image = _read_image(filename_texture, format="RGB") / 255.0
|
||||
image = torch.from_numpy(image)
|
||||
texture_images[name] = image
|
||||
else:
|
||||
msg = f"Texture file does not exist: {filename_texture}"
|
||||
warnings.warn(msg)
|
||||
|
||||
if name in material_colors:
|
||||
material_properties[name] = material_colors[name]
|
||||
|
||||
return material_properties, texture_images
|
||||
|
||||
|
||||
def save_obj(f, verts, faces, decimal_places: Optional[int] = None):
|
||||
"""
|
||||
Save a mesh to an .obj file.
|
||||
|
41
pytorch3d/io/utils.py
Normal file
41
pytorch3d/io/utils.py
Normal file
@ -0,0 +1,41 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import pathlib
|
||||
|
||||
import numpy as np
|
||||
from fvcore.common.file_io import PathManager
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def _open_file(f):
|
||||
new_f = False
|
||||
if isinstance(f, str):
|
||||
new_f = True
|
||||
f = open(f, "r")
|
||||
elif isinstance(f, pathlib.Path):
|
||||
new_f = True
|
||||
f = f.open("r")
|
||||
return f, new_f
|
||||
|
||||
|
||||
def _read_image(file_name: str, format=None):
|
||||
"""
|
||||
Read an image from a file using Pillow.
|
||||
Args:
|
||||
file_name: image file path.
|
||||
format: one of ["RGB", "BGR"]
|
||||
Returns:
|
||||
image: an image of shape (H, W, C).
|
||||
"""
|
||||
if format not in ["RGB", "BGR"]:
|
||||
raise ValueError("format can only be one of [RGB, BGR]; got %s", format)
|
||||
with PathManager.open(file_name, "rb") as f:
|
||||
image = Image.open(f)
|
||||
if format is not None:
|
||||
# PIL only supports RGB. First convert to RGB and flip channels
|
||||
# below for BGR.
|
||||
image = image.convert("RGB")
|
||||
image = np.asarray(image).astype(np.float32)
|
||||
if format == "BGR":
|
||||
image = image[:, :, ::-1]
|
||||
return image
|
@ -1,5 +1,7 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
from itertools import product
|
||||
|
||||
from fvcore.common.benchmark import benchmark
|
||||
from test_obj_io import TestMeshObjIO
|
||||
from test_ply_io import TestMeshPlyIO
|
||||
@ -61,3 +63,35 @@ def bm_save_load() -> None:
|
||||
complex_kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
# Texture loading benchmarks
|
||||
kwargs_list = [{"R": 2}, {"R": 4}, {"R": 10}, {"R": 15}, {"R": 20}]
|
||||
benchmark(
|
||||
TestMeshObjIO.bm_load_texture_atlas,
|
||||
"PYTORCH3D_TEXTURE_ATLAS",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
||||
kwargs_list = []
|
||||
S = [64, 256, 1024]
|
||||
F = [100, 1000, 10000]
|
||||
R = [5, 10, 20]
|
||||
test_cases = product(S, F, R)
|
||||
|
||||
for case in test_cases:
|
||||
s, f, r = case
|
||||
kwargs_list.append({"S": s, "F": f, "R": r})
|
||||
|
||||
benchmark(
|
||||
TestMeshObjIO.bm_bilinear_sampling_vectorized,
|
||||
"BILINEAR_VECTORIZED",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
benchmark(
|
||||
TestMeshObjIO.bm_bilinear_sampling_grid_sample,
|
||||
"BILINEAR_GRID_SAMPLE",
|
||||
kwargs_list,
|
||||
warmup_iters=1,
|
||||
)
|
||||
|
@ -2,12 +2,17 @@
|
||||
|
||||
import os
|
||||
import unittest
|
||||
import warnings
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from common_testing import TestCaseMixin
|
||||
from pytorch3d.io import load_obj, load_objs_as_meshes, save_obj
|
||||
from pytorch3d.io.mtl_io import (
|
||||
_bilinear_interpolation_grid_sample,
|
||||
_bilinear_interpolation_vectorized,
|
||||
)
|
||||
from pytorch3d.structures import Meshes, Textures, join_meshes_as_batch
|
||||
from pytorch3d.utils import torus
|
||||
|
||||
@ -47,8 +52,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(torch.all(verts == expected_verts))
|
||||
self.assertTrue(torch.all(faces.verts_idx == expected_faces))
|
||||
self.assertTrue(faces.normals_idx == [])
|
||||
self.assertTrue(faces.textures_idx == [])
|
||||
padded_vals = -torch.ones_like(faces.verts_idx)
|
||||
self.assertTrue(torch.all(faces.normals_idx == padded_vals))
|
||||
self.assertTrue(torch.all(faces.textures_idx == padded_vals))
|
||||
self.assertTrue(
|
||||
torch.all(faces.materials_idx == -torch.ones(len(expected_faces)))
|
||||
)
|
||||
@ -118,8 +124,12 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
[[0.749279, 0.501284], [0.999110, 0.501077], [0.999455, 0.750380]],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
expected_faces_normals_idx = torch.tensor([[1, 1, 1]], dtype=torch.int64)
|
||||
expected_faces_textures_idx = torch.tensor([[0, 0, 1]], dtype=torch.int64)
|
||||
expected_faces_normals_idx = -torch.ones_like(expected_faces, dtype=torch.int64)
|
||||
expected_faces_normals_idx[4, :] = torch.tensor([1, 1, 1], dtype=torch.int64)
|
||||
expected_faces_textures_idx = -torch.ones_like(
|
||||
expected_faces, dtype=torch.int64
|
||||
)
|
||||
expected_faces_textures_idx[4, :] = torch.tensor([0, 0, 1], dtype=torch.int64)
|
||||
|
||||
self.assertTrue(torch.all(verts == expected_verts))
|
||||
self.assertTrue(torch.all(faces.verts_idx == expected_faces))
|
||||
@ -160,7 +170,8 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(faces.normals_idx, expected_faces_normals_idx)
|
||||
self.assertClose(normals, expected_normals)
|
||||
self.assertClose(verts, expected_verts)
|
||||
self.assertTrue(faces.textures_idx == [])
|
||||
# Textures idx padded with -1.
|
||||
self.assertClose(faces.textures_idx, torch.ones_like(faces.verts_idx) * -1)
|
||||
self.assertTrue(textures is None)
|
||||
self.assertTrue(materials is None)
|
||||
self.assertTrue(tex_maps is None)
|
||||
@ -195,7 +206,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
self.assertClose(faces.textures_idx, expected_faces_textures_idx)
|
||||
self.assertClose(expected_textures, textures)
|
||||
self.assertClose(expected_verts, verts)
|
||||
self.assertTrue(faces.normals_idx == [])
|
||||
self.assertTrue(
|
||||
torch.all(faces.normals_idx == -torch.ones_like(faces.textures_idx))
|
||||
)
|
||||
self.assertTrue(normals is None)
|
||||
self.assertTrue(materials is None)
|
||||
self.assertTrue(tex_maps is None)
|
||||
@ -408,6 +421,9 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
"shininess": torch.tensor([10.0], dtype=dtype),
|
||||
}
|
||||
}
|
||||
# Texture atlas is not created as `create_texture_atlas=True` was
|
||||
# not set in the load_obj args
|
||||
self.assertTrue(aux.texture_atlas is None)
|
||||
# Check that there is an image with material name material_1.
|
||||
self.assertTrue(tuple(tex_maps.keys()) == ("material_1",))
|
||||
self.assertTrue(torch.is_tensor(tuple(tex_maps.values())[0]))
|
||||
@ -423,6 +439,36 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
torch.allclose(materials[n1][k1], expected_materials[n2][k2])
|
||||
)
|
||||
|
||||
def test_load_mtl_texture_atlas_compare_softras(self):
|
||||
# Load saved texture atlas created with SoftRas.
|
||||
device = torch.device("cuda:0")
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent
|
||||
obj_filename = DATA_DIR / "docs/tutorials/data/cow_mesh/cow.obj"
|
||||
expected_atlas_fname = DATA_DIR / "tests/data/cow_texture_atlas_softras.pt"
|
||||
|
||||
# Note, the reference texture atlas generated using SoftRas load_obj function
|
||||
# is too large to check in to the repo. Download the file to run the test locally.
|
||||
if not os.path.exists(expected_atlas_fname):
|
||||
url = "https://dl.fbaipublicfiles.com/pytorch3d/data/tests/cow_texture_atlas_softras.pt"
|
||||
msg = (
|
||||
"cow_texture_atlas_softras.pt not found, download from %s, save it at the path %s, and rerun"
|
||||
% (url, expected_atlas_fname)
|
||||
)
|
||||
warnings.warn(msg)
|
||||
return True
|
||||
|
||||
expected_atlas = torch.load(expected_atlas_fname)
|
||||
_, _, aux = load_obj(
|
||||
obj_filename,
|
||||
load_textures=True,
|
||||
device=device,
|
||||
create_texture_atlas=True,
|
||||
texture_atlas_size=15,
|
||||
texture_wrap="repeat",
|
||||
)
|
||||
|
||||
self.assertClose(expected_atlas, aux.texture_atlas, atol=5e-5)
|
||||
|
||||
def test_load_mtl_noload(self):
|
||||
DATA_DIR = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
||||
obj_filename = "cow_mesh/cow.obj"
|
||||
@ -629,3 +675,51 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
||||
meshes = torus(r=0.25, R=1.0, sides=N, rings=2 * N)
|
||||
[verts], [faces] = meshes.verts_list(), meshes.faces_list()
|
||||
return TestMeshObjIO._bm_load_obj(verts, faces, decimal_places=5)
|
||||
|
||||
@staticmethod
|
||||
def bm_load_texture_atlas(R: int):
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.set_device(device)
|
||||
DATA_DIR = "/data/users/nikhilar/fbsource/fbcode/vision/fair/pytorch3d/docs/"
|
||||
obj_filename = os.path.join(DATA_DIR, "tutorials/data/cow_mesh/cow.obj")
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def load():
|
||||
load_obj(
|
||||
obj_filename,
|
||||
load_textures=True,
|
||||
device=device,
|
||||
create_texture_atlas=True,
|
||||
texture_atlas_size=R,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return load
|
||||
|
||||
@staticmethod
|
||||
def bm_bilinear_sampling_vectorized(S: int, F: int, R: int):
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.set_device(device)
|
||||
image = torch.rand((S, S, 3))
|
||||
grid = torch.rand((F, R, R, 2))
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def load():
|
||||
_bilinear_interpolation_vectorized(image, grid)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return load
|
||||
|
||||
@staticmethod
|
||||
def bm_bilinear_sampling_grid_sample(S: int, F: int, R: int):
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.set_device(device)
|
||||
image = torch.rand((S, S, 3))
|
||||
grid = torch.rand((F, R, R, 2))
|
||||
torch.cuda.synchronize()
|
||||
|
||||
def load():
|
||||
_bilinear_interpolation_grid_sample(image, grid)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return load
|
||||
|
Loading…
x
Reference in New Issue
Block a user