Fix a few linting warnings

Summary: Fix a few linting warnings

Reviewed By: nikhilaravi

Differential Revision: D20720810

fbshipit-source-id: c5b6a25fdd7971cc8743b54bbe162464a874071d
This commit is contained in:
Patrick Labatut 2020-10-01 02:38:06 -07:00 committed by Facebook GitHub Bot
parent 4f8a2f1979
commit 8219a52ccc
5 changed files with 132 additions and 80 deletions

View File

@ -5,7 +5,7 @@
import os import os
import warnings import warnings
from collections import namedtuple from collections import namedtuple
from typing import Optional from typing import List, Optional
import numpy as np import numpy as np
import torch import torch
@ -345,42 +345,26 @@ def _parse_face(
faces_materials_idx.append(material_idx) faces_materials_idx.append(material_idx)
def _load_obj( def _parse_obj(f, data_dir: str):
f_obj,
data_dir,
load_textures: bool = True,
create_texture_atlas: bool = False,
texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat",
device="cpu",
):
""" """
Load a mesh from a file-like object. See load_obj function more details. Load a mesh from a file-like object. See load_obj function for more details
Any material files associated with the obj are expected to be in the about the return values.
directory given by data_dir.
""" """
verts, normals, verts_uvs = [], [], []
if texture_wrap is not None and texture_wrap not in ["repeat", "clamp"]: faces_verts_idx, faces_normals_idx, faces_textures_idx = [], [], []
msg = "texture_wrap must be one of ['repeat', 'clamp'] or None, got %s"
raise ValueError(msg % texture_wrap)
lines = [line.strip() for line in f_obj]
verts = []
normals = []
verts_uvs = []
faces_verts_idx = []
faces_normals_idx = []
faces_textures_idx = []
material_names = []
faces_materials_idx = [] faces_materials_idx = []
f_mtl = None material_names = []
materials_idx = -1 mtl_path = None
lines = [line.strip() for line in f]
# startswith expects each line to be a string. If the file is read in as # startswith expects each line to be a string. If the file is read in as
# bytes then first decode to strings. # bytes then first decode to strings.
if lines and isinstance(lines[0], bytes): if lines and isinstance(lines[0], bytes):
lines = [el.decode("utf-8") for el in lines] lines = [el.decode("utf-8") for el in lines]
materials_idx = -1
for line in lines: for line in lines:
tokens = line.strip().split() tokens = line.strip().split()
if line.startswith("mtllib"): if line.startswith("mtllib"):
@ -389,7 +373,8 @@ def _load_obj(
# NOTE: only allow one .mtl file per .obj. # NOTE: only allow one .mtl file per .obj.
# Definitions for multiple materials can be included # Definitions for multiple materials can be included
# in this one .mtl file. # in this one .mtl file.
f_mtl = os.path.join(data_dir, line.split()[1]) mtl_path = line[len(tokens[0]) :].strip() # Take the remainder of the line
mtl_path = os.path.join(data_dir, mtl_path)
elif len(tokens) and tokens[0] == "usemtl": elif len(tokens) and tokens[0] == "usemtl":
material_name = tokens[1] material_name = tokens[1]
# materials are often repeated for different parts # materials are often repeated for different parts
@ -430,6 +415,83 @@ def _load_obj(
faces_materials_idx, faces_materials_idx,
) )
return (
verts,
normals,
verts_uvs,
faces_verts_idx,
faces_normals_idx,
faces_textures_idx,
faces_materials_idx,
material_names,
mtl_path,
)
def _load_materials(
material_names: List[str], f, data_dir: str, *, load_textures: bool, device
):
"""
Load materials and optionally textures from the specified path.
Args:
material_names: a list of the material names found in the .obj file.
f: a file-like object of the material information.
data_dir: the directory where the material texture files are located.
load_textures: whether textures should be loaded.
device: string or torch.device on which to return the new tensors.
Returns:
material_colors: dict of properties for each material.
texture_images: dict of material names and texture images.
"""
if not load_textures:
return None, None
if not material_names or f is None:
if material_names:
warnings.warn("No mtl file provided")
return None, None
if not os.path.isfile(f):
warnings.warn(f"Mtl file does not exist: {f}")
return None, None
# Texture mode uv wrap
return load_mtl(f, material_names, data_dir, device=device)
def _load_obj(
f_obj,
data_dir,
load_textures: bool = True,
create_texture_atlas: bool = False,
texture_atlas_size: int = 4,
texture_wrap: Optional[str] = "repeat",
device="cpu",
):
"""
Load a mesh from a file-like object. See load_obj function more details.
Any material files associated with the obj are expected to be in the
directory given by data_dir.
"""
if texture_wrap is not None and texture_wrap not in ["repeat", "clamp"]:
msg = "texture_wrap must be one of ['repeat', 'clamp'] or None, got %s"
raise ValueError(msg % texture_wrap)
(
verts,
normals,
verts_uvs,
faces_verts_idx,
faces_normals_idx,
faces_textures_idx,
faces_materials_idx,
material_names,
mtl_path,
) = _parse_obj(f_obj, data_dir)
verts = _make_tensor(verts, cols=3, dtype=torch.float32, device=device) # (V, 3) verts = _make_tensor(verts, cols=3, dtype=torch.float32, device=device) # (V, 3)
normals = _make_tensor( normals = _make_tensor(
normals, cols=3, dtype=torch.float32, device=device normals, cols=3, dtype=torch.float32, device=device
@ -443,30 +505,24 @@ def _load_obj(
) )
# Repeat for normals and textures if present. # Repeat for normals and textures if present.
if len(faces_normals_idx) > 0: if len(faces_normals_idx):
faces_normals_idx = _format_faces_indices( faces_normals_idx = _format_faces_indices(
faces_normals_idx, normals.shape[0], device=device, pad_value=-1 faces_normals_idx, normals.shape[0], device=device, pad_value=-1
) )
if len(faces_textures_idx) > 0: if len(faces_textures_idx):
faces_textures_idx = _format_faces_indices( faces_textures_idx = _format_faces_indices(
faces_textures_idx, verts_uvs.shape[0], device=device, pad_value=-1 faces_textures_idx, verts_uvs.shape[0], device=device, pad_value=-1
) )
if len(faces_materials_idx) > 0: if len(faces_materials_idx):
faces_materials_idx = torch.tensor( faces_materials_idx = torch.tensor(
faces_materials_idx, dtype=torch.int64, device=device faces_materials_idx, dtype=torch.int64, device=device
) )
# Load materials texture_atlas = None
material_colors, texture_images, texture_atlas = None, None, None material_colors, texture_images = _load_materials(
if load_textures: material_names, mtl_path, data_dir, load_textures=load_textures, device=device
if (len(material_names) > 0) and (f_mtl is not None):
# pyre-fixme[6]: Expected `Union[_PathLike[typing.Any], bytes, str]` for
# 1st param but got `Optional[str]`.
if os.path.isfile(f_mtl):
# Texture mode uv wrap
material_colors, texture_images = load_mtl(
f_mtl, material_names, data_dir, device=device
) )
if create_texture_atlas: if create_texture_atlas:
# Using the images and properties from the # Using the images and properties from the
# material file make a per face texture map. # material file make a per face texture map.
@ -477,7 +533,6 @@ def _load_obj(
face_material_names = np.array(material_names)[idx] # (F,) face_material_names = np.array(material_names)[idx] # (F,)
face_material_names[idx == -1] = "" face_material_names[idx == -1] = ""
texture_atlas = None
if len(verts_uvs) > 0: if len(verts_uvs) > 0:
# Get the uv coords for each vert in each face # Get the uv coords for each vert in each face
faces_verts_uvs = verts_uvs[faces_textures_idx] # (F, 3, 2) faces_verts_uvs = verts_uvs[faces_textures_idx] # (F, 3, 2)
@ -491,10 +546,6 @@ def _load_obj(
texture_atlas_size, texture_atlas_size,
texture_wrap, texture_wrap,
) )
else:
warnings.warn(f"Mtl file does not exist: {f_mtl}")
elif len(material_names) > 0:
warnings.warn("No mtl file provided")
faces = _Faces( faces = _Faces(
verts_idx=faces_verts_idx, verts_idx=faces_verts_idx,
@ -502,10 +553,9 @@ def _load_obj(
textures_idx=faces_textures_idx, textures_idx=faces_textures_idx,
materials_idx=faces_materials_idx, materials_idx=faces_materials_idx,
) )
aux = _Aux( aux = _Aux(
normals=normals if len(normals) > 0 else None, normals=normals if len(normals) else None,
verts_uvs=verts_uvs if len(verts_uvs) > 0 else None, verts_uvs=verts_uvs if len(verts_uvs) else None,
material_colors=material_colors, material_colors=material_colors,
texture_images=texture_images, texture_images=texture_images,
texture_atlas=texture_atlas, texture_atlas=texture_atlas,

View File

@ -282,7 +282,7 @@ def _try_read_ply_constant_list_ascii(f, definition: _PlyElementType):
return None return None
if not len(data): # np.loadtxt() seeks even on empty data if not len(data): # np.loadtxt() seeks even on empty data
f.seek(old_offset) f.seek(old_offset)
if (data.shape[1] - 1 != data[:, 0]).any(): if (data[:, 0] != data.shape[1] - 1).any():
msg = "A line of %s data did not have the specified length." msg = "A line of %s data did not have the specified length."
raise ValueError(msg % definition.name) raise ValueError(msg % definition.name)
if data.shape[0] != definition.count: if data.shape[0] != definition.count:

View File

@ -657,10 +657,6 @@ class TexturesUV(TexturesBase):
if verts_uvs.device != self.device: if verts_uvs.device != self.device:
raise ValueError("verts_uvs and faces_uvs must be on the same device") raise ValueError("verts_uvs and faces_uvs must be on the same device")
# These values may be overridden when textures is
# passed into the Meshes constructor.
max_V = verts_uvs.shape[1]
else: else:
raise ValueError("Expected verts_uvs to be a tensor or list") raise ValueError("Expected verts_uvs to be a tensor or list")

View File

@ -498,10 +498,13 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
# Note, the reference texture atlas generated using SoftRas load_obj function # Note, the reference texture atlas generated using SoftRas load_obj function
# is too large to check in to the repo. Download the file to run the test locally. # is too large to check in to the repo. Download the file to run the test locally.
if not os.path.exists(expected_atlas_fname): if not os.path.exists(expected_atlas_fname):
url = "https://dl.fbaipublicfiles.com/pytorch3d/data/tests/cow_texture_atlas_softras.pt" url = (
"https://dl.fbaipublicfiles.com/pytorch3d/data/"
"tests/cow_texture_atlas_softras.pt"
)
msg = ( msg = (
"cow_texture_atlas_softras.pt not found, download from %s, save it at the path %s, and rerun" "cow_texture_atlas_softras.pt not found, download from %s, "
% (url, expected_atlas_fname) "save it at the path %s, and rerun" % (url, expected_atlas_fname)
) )
warnings.warn(msg) warnings.warn(msg)
return True return True

View File

@ -80,7 +80,10 @@ class TestRenderPoints(TestCaseMixin, unittest.TestCase):
# Note, this file is too large to check in to the repo. # Note, this file is too large to check in to the repo.
# Download the file to run the test locally. # Download the file to run the test locally.
if not path.exists(pointcloud_filename): if not path.exists(pointcloud_filename):
url = "https://dl.fbaipublicfiles.com/pytorch3d/data/PittsburghBridge/pointcloud.npz" url = (
"https://dl.fbaipublicfiles.com/pytorch3d/data/"
"PittsburghBridge/pointcloud.npz"
)
msg = ( msg = (
"pointcloud.npz not found, download from %s, save it at the path %s, and rerun" "pointcloud.npz not found, download from %s, save it at the path %s, and rerun"
% (url, pointcloud_filename) % (url, pointcloud_filename)