mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Fix a few linting warnings
Summary: Fix a few linting warnings Reviewed By: nikhilaravi Differential Revision: D20720810 fbshipit-source-id: c5b6a25fdd7971cc8743b54bbe162464a874071d
This commit is contained in:
parent
4f8a2f1979
commit
8219a52ccc
@ -5,7 +5,7 @@
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -345,42 +345,26 @@ def _parse_face(
|
|||||||
faces_materials_idx.append(material_idx)
|
faces_materials_idx.append(material_idx)
|
||||||
|
|
||||||
|
|
||||||
def _load_obj(
|
def _parse_obj(f, data_dir: str):
|
||||||
f_obj,
|
|
||||||
data_dir,
|
|
||||||
load_textures: bool = True,
|
|
||||||
create_texture_atlas: bool = False,
|
|
||||||
texture_atlas_size: int = 4,
|
|
||||||
texture_wrap: Optional[str] = "repeat",
|
|
||||||
device="cpu",
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Load a mesh from a file-like object. See load_obj function more details.
|
Load a mesh from a file-like object. See load_obj function for more details
|
||||||
Any material files associated with the obj are expected to be in the
|
about the return values.
|
||||||
directory given by data_dir.
|
|
||||||
"""
|
"""
|
||||||
|
verts, normals, verts_uvs = [], [], []
|
||||||
if texture_wrap is not None and texture_wrap not in ["repeat", "clamp"]:
|
faces_verts_idx, faces_normals_idx, faces_textures_idx = [], [], []
|
||||||
msg = "texture_wrap must be one of ['repeat', 'clamp'] or None, got %s"
|
|
||||||
raise ValueError(msg % texture_wrap)
|
|
||||||
|
|
||||||
lines = [line.strip() for line in f_obj]
|
|
||||||
verts = []
|
|
||||||
normals = []
|
|
||||||
verts_uvs = []
|
|
||||||
faces_verts_idx = []
|
|
||||||
faces_normals_idx = []
|
|
||||||
faces_textures_idx = []
|
|
||||||
material_names = []
|
|
||||||
faces_materials_idx = []
|
faces_materials_idx = []
|
||||||
f_mtl = None
|
material_names = []
|
||||||
materials_idx = -1
|
mtl_path = None
|
||||||
|
|
||||||
|
lines = [line.strip() for line in f]
|
||||||
|
|
||||||
# startswith expects each line to be a string. If the file is read in as
|
# startswith expects each line to be a string. If the file is read in as
|
||||||
# bytes then first decode to strings.
|
# bytes then first decode to strings.
|
||||||
if lines and isinstance(lines[0], bytes):
|
if lines and isinstance(lines[0], bytes):
|
||||||
lines = [el.decode("utf-8") for el in lines]
|
lines = [el.decode("utf-8") for el in lines]
|
||||||
|
|
||||||
|
materials_idx = -1
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
tokens = line.strip().split()
|
tokens = line.strip().split()
|
||||||
if line.startswith("mtllib"):
|
if line.startswith("mtllib"):
|
||||||
@ -389,7 +373,8 @@ def _load_obj(
|
|||||||
# NOTE: only allow one .mtl file per .obj.
|
# NOTE: only allow one .mtl file per .obj.
|
||||||
# Definitions for multiple materials can be included
|
# Definitions for multiple materials can be included
|
||||||
# in this one .mtl file.
|
# in this one .mtl file.
|
||||||
f_mtl = os.path.join(data_dir, line.split()[1])
|
mtl_path = line[len(tokens[0]) :].strip() # Take the remainder of the line
|
||||||
|
mtl_path = os.path.join(data_dir, mtl_path)
|
||||||
elif len(tokens) and tokens[0] == "usemtl":
|
elif len(tokens) and tokens[0] == "usemtl":
|
||||||
material_name = tokens[1]
|
material_name = tokens[1]
|
||||||
# materials are often repeated for different parts
|
# materials are often repeated for different parts
|
||||||
@ -430,6 +415,83 @@ def _load_obj(
|
|||||||
faces_materials_idx,
|
faces_materials_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
verts,
|
||||||
|
normals,
|
||||||
|
verts_uvs,
|
||||||
|
faces_verts_idx,
|
||||||
|
faces_normals_idx,
|
||||||
|
faces_textures_idx,
|
||||||
|
faces_materials_idx,
|
||||||
|
material_names,
|
||||||
|
mtl_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_materials(
|
||||||
|
material_names: List[str], f, data_dir: str, *, load_textures: bool, device
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Load materials and optionally textures from the specified path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
material_names: a list of the material names found in the .obj file.
|
||||||
|
f: a file-like object of the material information.
|
||||||
|
data_dir: the directory where the material texture files are located.
|
||||||
|
load_textures: whether textures should be loaded.
|
||||||
|
device: string or torch.device on which to return the new tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
material_colors: dict of properties for each material.
|
||||||
|
texture_images: dict of material names and texture images.
|
||||||
|
"""
|
||||||
|
if not load_textures:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if not material_names or f is None:
|
||||||
|
if material_names:
|
||||||
|
warnings.warn("No mtl file provided")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if not os.path.isfile(f):
|
||||||
|
warnings.warn(f"Mtl file does not exist: {f}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# Texture mode uv wrap
|
||||||
|
return load_mtl(f, material_names, data_dir, device=device)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_obj(
|
||||||
|
f_obj,
|
||||||
|
data_dir,
|
||||||
|
load_textures: bool = True,
|
||||||
|
create_texture_atlas: bool = False,
|
||||||
|
texture_atlas_size: int = 4,
|
||||||
|
texture_wrap: Optional[str] = "repeat",
|
||||||
|
device="cpu",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Load a mesh from a file-like object. See load_obj function more details.
|
||||||
|
Any material files associated with the obj are expected to be in the
|
||||||
|
directory given by data_dir.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if texture_wrap is not None and texture_wrap not in ["repeat", "clamp"]:
|
||||||
|
msg = "texture_wrap must be one of ['repeat', 'clamp'] or None, got %s"
|
||||||
|
raise ValueError(msg % texture_wrap)
|
||||||
|
|
||||||
|
(
|
||||||
|
verts,
|
||||||
|
normals,
|
||||||
|
verts_uvs,
|
||||||
|
faces_verts_idx,
|
||||||
|
faces_normals_idx,
|
||||||
|
faces_textures_idx,
|
||||||
|
faces_materials_idx,
|
||||||
|
material_names,
|
||||||
|
mtl_path,
|
||||||
|
) = _parse_obj(f_obj, data_dir)
|
||||||
|
|
||||||
verts = _make_tensor(verts, cols=3, dtype=torch.float32, device=device) # (V, 3)
|
verts = _make_tensor(verts, cols=3, dtype=torch.float32, device=device) # (V, 3)
|
||||||
normals = _make_tensor(
|
normals = _make_tensor(
|
||||||
normals, cols=3, dtype=torch.float32, device=device
|
normals, cols=3, dtype=torch.float32, device=device
|
||||||
@ -443,58 +505,47 @@ def _load_obj(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Repeat for normals and textures if present.
|
# Repeat for normals and textures if present.
|
||||||
if len(faces_normals_idx) > 0:
|
if len(faces_normals_idx):
|
||||||
faces_normals_idx = _format_faces_indices(
|
faces_normals_idx = _format_faces_indices(
|
||||||
faces_normals_idx, normals.shape[0], device=device, pad_value=-1
|
faces_normals_idx, normals.shape[0], device=device, pad_value=-1
|
||||||
)
|
)
|
||||||
if len(faces_textures_idx) > 0:
|
if len(faces_textures_idx):
|
||||||
faces_textures_idx = _format_faces_indices(
|
faces_textures_idx = _format_faces_indices(
|
||||||
faces_textures_idx, verts_uvs.shape[0], device=device, pad_value=-1
|
faces_textures_idx, verts_uvs.shape[0], device=device, pad_value=-1
|
||||||
)
|
)
|
||||||
if len(faces_materials_idx) > 0:
|
if len(faces_materials_idx):
|
||||||
faces_materials_idx = torch.tensor(
|
faces_materials_idx = torch.tensor(
|
||||||
faces_materials_idx, dtype=torch.int64, device=device
|
faces_materials_idx, dtype=torch.int64, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load materials
|
texture_atlas = None
|
||||||
material_colors, texture_images, texture_atlas = None, None, None
|
material_colors, texture_images = _load_materials(
|
||||||
if load_textures:
|
material_names, mtl_path, data_dir, load_textures=load_textures, device=device
|
||||||
if (len(material_names) > 0) and (f_mtl is not None):
|
)
|
||||||
# pyre-fixme[6]: Expected `Union[_PathLike[typing.Any], bytes, str]` for
|
|
||||||
# 1st param but got `Optional[str]`.
|
|
||||||
if os.path.isfile(f_mtl):
|
|
||||||
# Texture mode uv wrap
|
|
||||||
material_colors, texture_images = load_mtl(
|
|
||||||
f_mtl, material_names, data_dir, device=device
|
|
||||||
)
|
|
||||||
if create_texture_atlas:
|
|
||||||
# Using the images and properties from the
|
|
||||||
# material file make a per face texture map.
|
|
||||||
|
|
||||||
# Create an array of strings of material names for each face.
|
if create_texture_atlas:
|
||||||
# If faces_materials_idx == -1 then that face doesn't have a material.
|
# Using the images and properties from the
|
||||||
idx = faces_materials_idx.cpu().numpy()
|
# material file make a per face texture map.
|
||||||
face_material_names = np.array(material_names)[idx] # (F,)
|
|
||||||
face_material_names[idx == -1] = ""
|
|
||||||
|
|
||||||
texture_atlas = None
|
# Create an array of strings of material names for each face.
|
||||||
if len(verts_uvs) > 0:
|
# If faces_materials_idx == -1 then that face doesn't have a material.
|
||||||
# Get the uv coords for each vert in each face
|
idx = faces_materials_idx.cpu().numpy()
|
||||||
faces_verts_uvs = verts_uvs[faces_textures_idx] # (F, 3, 2)
|
face_material_names = np.array(material_names)[idx] # (F,)
|
||||||
|
face_material_names[idx == -1] = ""
|
||||||
|
|
||||||
# Construct the atlas.
|
if len(verts_uvs) > 0:
|
||||||
texture_atlas = make_mesh_texture_atlas(
|
# Get the uv coords for each vert in each face
|
||||||
material_colors,
|
faces_verts_uvs = verts_uvs[faces_textures_idx] # (F, 3, 2)
|
||||||
texture_images,
|
|
||||||
face_material_names,
|
# Construct the atlas.
|
||||||
faces_verts_uvs,
|
texture_atlas = make_mesh_texture_atlas(
|
||||||
texture_atlas_size,
|
material_colors,
|
||||||
texture_wrap,
|
texture_images,
|
||||||
)
|
face_material_names,
|
||||||
else:
|
faces_verts_uvs,
|
||||||
warnings.warn(f"Mtl file does not exist: {f_mtl}")
|
texture_atlas_size,
|
||||||
elif len(material_names) > 0:
|
texture_wrap,
|
||||||
warnings.warn("No mtl file provided")
|
)
|
||||||
|
|
||||||
faces = _Faces(
|
faces = _Faces(
|
||||||
verts_idx=faces_verts_idx,
|
verts_idx=faces_verts_idx,
|
||||||
@ -502,10 +553,9 @@ def _load_obj(
|
|||||||
textures_idx=faces_textures_idx,
|
textures_idx=faces_textures_idx,
|
||||||
materials_idx=faces_materials_idx,
|
materials_idx=faces_materials_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
aux = _Aux(
|
aux = _Aux(
|
||||||
normals=normals if len(normals) > 0 else None,
|
normals=normals if len(normals) else None,
|
||||||
verts_uvs=verts_uvs if len(verts_uvs) > 0 else None,
|
verts_uvs=verts_uvs if len(verts_uvs) else None,
|
||||||
material_colors=material_colors,
|
material_colors=material_colors,
|
||||||
texture_images=texture_images,
|
texture_images=texture_images,
|
||||||
texture_atlas=texture_atlas,
|
texture_atlas=texture_atlas,
|
||||||
|
@ -282,7 +282,7 @@ def _try_read_ply_constant_list_ascii(f, definition: _PlyElementType):
|
|||||||
return None
|
return None
|
||||||
if not len(data): # np.loadtxt() seeks even on empty data
|
if not len(data): # np.loadtxt() seeks even on empty data
|
||||||
f.seek(old_offset)
|
f.seek(old_offset)
|
||||||
if (data.shape[1] - 1 != data[:, 0]).any():
|
if (data[:, 0] != data.shape[1] - 1).any():
|
||||||
msg = "A line of %s data did not have the specified length."
|
msg = "A line of %s data did not have the specified length."
|
||||||
raise ValueError(msg % definition.name)
|
raise ValueError(msg % definition.name)
|
||||||
if data.shape[0] != definition.count:
|
if data.shape[0] != definition.count:
|
||||||
|
@ -657,10 +657,6 @@ class TexturesUV(TexturesBase):
|
|||||||
|
|
||||||
if verts_uvs.device != self.device:
|
if verts_uvs.device != self.device:
|
||||||
raise ValueError("verts_uvs and faces_uvs must be on the same device")
|
raise ValueError("verts_uvs and faces_uvs must be on the same device")
|
||||||
|
|
||||||
# These values may be overridden when textures is
|
|
||||||
# passed into the Meshes constructor.
|
|
||||||
max_V = verts_uvs.shape[1]
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Expected verts_uvs to be a tensor or list")
|
raise ValueError("Expected verts_uvs to be a tensor or list")
|
||||||
|
|
||||||
|
@ -498,10 +498,13 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
|
|||||||
# Note, the reference texture atlas generated using SoftRas load_obj function
|
# Note, the reference texture atlas generated using SoftRas load_obj function
|
||||||
# is too large to check in to the repo. Download the file to run the test locally.
|
# is too large to check in to the repo. Download the file to run the test locally.
|
||||||
if not os.path.exists(expected_atlas_fname):
|
if not os.path.exists(expected_atlas_fname):
|
||||||
url = "https://dl.fbaipublicfiles.com/pytorch3d/data/tests/cow_texture_atlas_softras.pt"
|
url = (
|
||||||
|
"https://dl.fbaipublicfiles.com/pytorch3d/data/"
|
||||||
|
"tests/cow_texture_atlas_softras.pt"
|
||||||
|
)
|
||||||
msg = (
|
msg = (
|
||||||
"cow_texture_atlas_softras.pt not found, download from %s, save it at the path %s, and rerun"
|
"cow_texture_atlas_softras.pt not found, download from %s, "
|
||||||
% (url, expected_atlas_fname)
|
"save it at the path %s, and rerun" % (url, expected_atlas_fname)
|
||||||
)
|
)
|
||||||
warnings.warn(msg)
|
warnings.warn(msg)
|
||||||
return True
|
return True
|
||||||
|
@ -80,7 +80,10 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
|
|||||||
# Note, this file is too large to check in to the repo.
|
# Note, this file is too large to check in to the repo.
|
||||||
# Download the file to run the test locally.
|
# Download the file to run the test locally.
|
||||||
if not path.exists(pointcloud_filename):
|
if not path.exists(pointcloud_filename):
|
||||||
url = "https://dl.fbaipublicfiles.com/pytorch3d/data/PittsburghBridge/pointcloud.npz"
|
url = (
|
||||||
|
"https://dl.fbaipublicfiles.com/pytorch3d/data/"
|
||||||
|
"PittsburghBridge/pointcloud.npz"
|
||||||
|
)
|
||||||
msg = (
|
msg = (
|
||||||
"pointcloud.npz not found, download from %s, save it at the path %s, and rerun"
|
"pointcloud.npz not found, download from %s, save it at the path %s, and rerun"
|
||||||
% (url, pointcloud_filename)
|
% (url, pointcloud_filename)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user