mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
Rendering texturing fixes
Summary: Fix errors raised by issue on GitHub - extending mesh textures + rendering with Gourad and Phong shaders. https://github.com/facebookresearch/pytorch3d/issues/97 Reviewed By: gkioxari Differential Revision: D20319610 fbshipit-source-id: d1c692ff0b9397a77a9b829c5c731790de70c09f
This commit is contained in:
committed by
Facebook GitHub Bot
parent
f580ce1385
commit
5d3cc3569a
@@ -107,7 +107,9 @@ def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
|
||||
There will be one C dimensional value for each element in
|
||||
fragments.pix_to_face.
|
||||
"""
|
||||
vertex_textures = meshes.textures.verts_rgb_padded().view(-1, 3) # (V, C)
|
||||
vertex_textures = meshes.textures.verts_rgb_padded().reshape(
|
||||
-1, 3
|
||||
) # (V, C)
|
||||
vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
|
||||
faces_packed = meshes.faces_packed()
|
||||
faces_textures = vertex_textures[faces_packed] # (F, 3, C)
|
||||
|
||||
@@ -223,27 +223,32 @@ class TensorProperties(object):
|
||||
self with all properties reshaped. e.g. a property with shape (N, 3)
|
||||
is transformed to shape (B, 3).
|
||||
"""
|
||||
# Iterate through the attributes of the class which are tensors.
|
||||
for k in dir(self):
|
||||
v = getattr(self, k)
|
||||
if torch.is_tensor(v):
|
||||
if v.shape[0] > 1:
|
||||
# There are different values for each batch element
|
||||
# so gather these using the batch_idx
|
||||
idx_dims = batch_idx.shape
|
||||
# so gather these using the batch_idx.
|
||||
# First clone the input batch_idx tensor before
|
||||
# modifying it.
|
||||
_batch_idx = batch_idx.clone()
|
||||
idx_dims = _batch_idx.shape
|
||||
tensor_dims = v.shape
|
||||
if len(idx_dims) > len(tensor_dims):
|
||||
msg = "batch_idx cannot have more dimensions than %s. "
|
||||
msg += "got shape %r and %s has shape %r"
|
||||
raise ValueError(msg % (k, idx_dims, k, tensor_dims))
|
||||
if idx_dims != tensor_dims:
|
||||
# To use torch.gather the index tensor (batch_idx) has
|
||||
# To use torch.gather the index tensor (_batch_idx) has
|
||||
# to have the same shape as the input tensor.
|
||||
new_dims = len(tensor_dims) - len(idx_dims)
|
||||
new_shape = idx_dims + (1,) * new_dims
|
||||
expand_dims = (-1,) + tensor_dims[1:]
|
||||
batch_idx = batch_idx.view(*new_shape)
|
||||
batch_idx = batch_idx.expand(*expand_dims)
|
||||
v = v.gather(0, batch_idx)
|
||||
_batch_idx = _batch_idx.view(*new_shape)
|
||||
_batch_idx = _batch_idx.expand(*expand_dims)
|
||||
|
||||
v = v.gather(0, _batch_idx)
|
||||
setattr(self, k, v)
|
||||
return self
|
||||
|
||||
|
||||
@@ -324,14 +324,14 @@ class Meshes(object):
|
||||
)
|
||||
if self._N > 0:
|
||||
self.device = self._verts_list[0].device
|
||||
num_verts_per_mesh = torch.tensor(
|
||||
self._num_verts_per_mesh = torch.tensor(
|
||||
[len(v) for v in self._verts_list], device=self.device
|
||||
)
|
||||
self._V = num_verts_per_mesh.max()
|
||||
num_faces_per_mesh = torch.tensor(
|
||||
self._V = self._num_verts_per_mesh.max()
|
||||
self._num_faces_per_mesh = torch.tensor(
|
||||
[len(f) for f in self._faces_list], device=self.device
|
||||
)
|
||||
self._F = num_faces_per_mesh.max()
|
||||
self._F = self._num_faces_per_mesh.max()
|
||||
self.valid = torch.tensor(
|
||||
[
|
||||
len(v) > 0 and len(f) > 0
|
||||
@@ -341,8 +341,8 @@ class Meshes(object):
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
if (len(num_verts_per_mesh.unique()) == 1) and (
|
||||
len(num_faces_per_mesh.unique()) == 1
|
||||
if (len(self._num_verts_per_mesh.unique()) == 1) and (
|
||||
len(self._num_faces_per_mesh.unique()) == 1
|
||||
):
|
||||
self.equisized = True
|
||||
|
||||
@@ -355,6 +355,7 @@ class Meshes(object):
|
||||
self._faces_padded = faces.to(torch.int64)
|
||||
self._N = self._verts_padded.shape[0]
|
||||
self._V = self._verts_padded.shape[1]
|
||||
|
||||
self.device = self._verts_padded.device
|
||||
self.valid = torch.zeros(
|
||||
(self._N,), dtype=torch.bool, device=self.device
|
||||
@@ -363,18 +364,25 @@ class Meshes(object):
|
||||
# Check that padded faces - which have value -1 - are at the
|
||||
# end of the tensors
|
||||
faces_not_padded = self._faces_padded.gt(-1).all(2)
|
||||
num_faces = faces_not_padded.sum(1)
|
||||
self._num_faces_per_mesh = faces_not_padded.sum(1)
|
||||
if (faces_not_padded[:, :-1] < faces_not_padded[:, 1:]).any():
|
||||
raise ValueError("Padding of faces must be at the end")
|
||||
|
||||
# NOTE that we don't check for the ordering of padded verts
|
||||
# as long as the faces index correspond to the right vertices.
|
||||
|
||||
self.valid = num_faces > 0
|
||||
self._F = num_faces.max()
|
||||
if len(num_faces.unique()) == 1:
|
||||
self.valid = self._num_faces_per_mesh > 0
|
||||
self._F = self._num_faces_per_mesh.max()
|
||||
if len(self._num_faces_per_mesh.unique()) == 1:
|
||||
self.equisized = True
|
||||
|
||||
self._num_verts_per_mesh = torch.full(
|
||||
size=(self._N,),
|
||||
fill_value=self._V,
|
||||
dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"Verts and Faces must be either a list or a tensor with \
|
||||
@@ -382,6 +390,23 @@ class Meshes(object):
|
||||
number of verts or faces respectively."
|
||||
)
|
||||
|
||||
if self.isempty():
|
||||
self._num_verts_per_mesh = torch.zeros(
|
||||
(0,), dtype=torch.int64, device=self.device
|
||||
)
|
||||
self._num_faces_per_mesh = torch.zeros(
|
||||
(0,), dtype=torch.int64, device=self.device
|
||||
)
|
||||
|
||||
# Set the num verts/faces on the textures if present.
|
||||
if self.textures is not None:
|
||||
self.textures._num_faces_per_mesh = (
|
||||
self._num_faces_per_mesh.tolist()
|
||||
)
|
||||
self.textures._num_verts_per_mesh = (
|
||||
self._num_verts_per_mesh.tolist()
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self._N
|
||||
|
||||
@@ -893,11 +918,9 @@ class Meshes(object):
|
||||
self._verts_packed,
|
||||
self._verts_packed_to_mesh_idx,
|
||||
self._mesh_to_verts_packed_first_idx,
|
||||
self._num_verts_per_mesh,
|
||||
self._faces_packed,
|
||||
self._faces_packed_to_mesh_idx,
|
||||
self._mesh_to_faces_packed_first_idx,
|
||||
self._num_faces_per_mesh,
|
||||
]
|
||||
)
|
||||
):
|
||||
@@ -920,7 +943,6 @@ class Meshes(object):
|
||||
self._num_verts_per_mesh = torch.zeros(
|
||||
(0,), dtype=torch.int64, device=self.device
|
||||
)
|
||||
|
||||
self._faces_packed = -torch.ones(
|
||||
(0, 3), dtype=torch.int64, device=self.device
|
||||
)
|
||||
@@ -1354,6 +1376,7 @@ class Meshes(object):
|
||||
tex = None
|
||||
if self.textures is not None:
|
||||
tex = self.textures.extend(N)
|
||||
|
||||
return Meshes(verts=new_verts_list, faces=new_faces_list, textures=tex)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Optional, Union
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
|
||||
from .utils import list_to_packed, padded_to_list
|
||||
from .utils import padded_to_list, padded_to_packed
|
||||
|
||||
|
||||
"""
|
||||
@@ -92,14 +92,19 @@ class Textures(object):
|
||||
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each
|
||||
vertex in the face. Padding value is assumed to be -1.
|
||||
verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex.
|
||||
verts_rgb: (N, V, 3) tensor giving the rgb color per vertex.
|
||||
verts_rgb: (N, V, 3) tensor giving the rgb color per vertex. Padding
|
||||
value is assumed to be -1.
|
||||
|
||||
Note: only the padded representation of the textures is stored
|
||||
and the packed/list representations are computed on the fly and
|
||||
not cached.
|
||||
"""
|
||||
if faces_uvs is not None and faces_uvs.ndim != 3:
|
||||
msg = "Expected faces_uvs to be of shape (N, F, 3); got %r"
|
||||
raise ValueError(msg % repr(faces_uvs.shape))
|
||||
if verts_uvs is not None and verts_uvs.ndim != 3:
|
||||
msg = "Expected verts_uvs to be of shape (N, V, 2); got %r"
|
||||
raise ValueError(msg % repr(faces_uvs.shape))
|
||||
raise ValueError(msg % repr(verts_uvs.shape))
|
||||
if verts_rgb is not None and verts_rgb.ndim != 3:
|
||||
msg = "Expected verts_rgb to be of shape (N, V, 3); got %r"
|
||||
raise ValueError(msg % repr(verts_rgb.shape))
|
||||
@@ -109,20 +114,20 @@ class Textures(object):
|
||||
raise ValueError(msg % repr(maps.shape))
|
||||
elif isinstance(maps, list):
|
||||
maps = _pad_texture_maps(maps)
|
||||
if faces_uvs is None or verts_uvs is None:
|
||||
msg = "To use maps, faces_uvs and verts_uvs are required"
|
||||
raise ValueError(msg)
|
||||
|
||||
self._faces_uvs_padded = faces_uvs
|
||||
self._verts_uvs_padded = verts_uvs
|
||||
self._verts_rgb_padded = verts_rgb
|
||||
self._maps_padded = maps
|
||||
self._num_faces_per_mesh = None
|
||||
self._set_num_faces_per_mesh()
|
||||
|
||||
def _set_num_faces_per_mesh(self) -> None:
|
||||
"""
|
||||
Determines and sets the number of textured faces for each mesh.
|
||||
"""
|
||||
if self._faces_uvs_padded is not None:
|
||||
faces_uvs = self._faces_uvs_padded
|
||||
self._num_faces_per_mesh = faces_uvs.gt(-1).all(-1).sum(-1).tolist()
|
||||
# The number of faces/verts for each mesh is
|
||||
# set inside the Meshes object when textures is
|
||||
# passed into the Meshes constructor.
|
||||
self._num_faces_per_mesh = None
|
||||
self._num_verts_per_mesh = None
|
||||
|
||||
def clone(self):
|
||||
other = Textures()
|
||||
@@ -148,41 +153,67 @@ class Textures(object):
|
||||
setattr(other, key, value[index][None])
|
||||
else:
|
||||
setattr(other, key, value[index])
|
||||
other._set_num_faces_per_mesh()
|
||||
return other
|
||||
|
||||
def faces_uvs_padded(self) -> torch.Tensor:
|
||||
return self._faces_uvs_padded
|
||||
|
||||
def faces_uvs_list(self) -> List[torch.Tensor]:
|
||||
if self._faces_uvs_padded is not None:
|
||||
return padded_to_list(
|
||||
self._faces_uvs_padded, split_size=self._num_faces_per_mesh
|
||||
)
|
||||
def faces_uvs_list(self) -> Union[List[torch.Tensor], None]:
|
||||
if self._faces_uvs_padded is None:
|
||||
return None
|
||||
return padded_to_list(
|
||||
self._faces_uvs_padded, split_size=self._num_faces_per_mesh
|
||||
)
|
||||
|
||||
def faces_uvs_packed(self) -> torch.Tensor:
|
||||
return list_to_packed(self.faces_uvs_list())[0]
|
||||
def faces_uvs_packed(self) -> Union[torch.Tensor, None]:
|
||||
if self._faces_uvs_padded is None:
|
||||
return None
|
||||
return padded_to_packed(
|
||||
self._faces_uvs_padded, split_size=self._num_faces_per_mesh
|
||||
)
|
||||
|
||||
def verts_uvs_padded(self) -> torch.Tensor:
|
||||
def verts_uvs_padded(self) -> Union[torch.Tensor, None]:
|
||||
return self._verts_uvs_padded
|
||||
|
||||
def verts_uvs_list(self) -> List[torch.Tensor]:
|
||||
def verts_uvs_list(self) -> Union[List[torch.Tensor], None]:
|
||||
if self._verts_uvs_padded is None:
|
||||
return None
|
||||
# Vertices shared between multiple faces
|
||||
# may have a different uv coordinate for
|
||||
# each face so the num_verts_uvs_per_mesh
|
||||
# may be different from num_verts_per_mesh.
|
||||
# Therefore don't use any split_size.
|
||||
return padded_to_list(self._verts_uvs_padded)
|
||||
|
||||
def verts_uvs_packed(self) -> torch.Tensor:
|
||||
return list_to_packed(self.verts_uvs_list())[0]
|
||||
def verts_uvs_packed(self) -> Union[torch.Tensor, None]:
|
||||
if self._verts_uvs_padded is None:
|
||||
return None
|
||||
# Vertices shared between multiple faces
|
||||
# may have a different uv coordinate for
|
||||
# each face so the num_verts_uvs_per_mesh
|
||||
# may be different from num_verts_per_mesh.
|
||||
# Therefore don't use any split_size.
|
||||
return padded_to_packed(self._verts_uvs_padded)
|
||||
|
||||
def verts_rgb_padded(self) -> torch.Tensor:
|
||||
def verts_rgb_padded(self) -> Union[torch.Tensor, None]:
|
||||
return self._verts_rgb_padded
|
||||
|
||||
def verts_rgb_list(self) -> List[torch.Tensor]:
|
||||
return padded_to_list(self._verts_rgb_padded)
|
||||
def verts_rgb_list(self) -> Union[List[torch.Tensor], None]:
|
||||
if self._verts_rgb_padded is None:
|
||||
return None
|
||||
return padded_to_list(
|
||||
self._verts_rgb_padded, split_size=self._num_verts_per_mesh
|
||||
)
|
||||
|
||||
def verts_rgb_packed(self) -> torch.Tensor:
|
||||
return list_to_packed(self.verts_rgb_list())[0]
|
||||
def verts_rgb_packed(self) -> Union[torch.Tensor, None]:
|
||||
if self._verts_rgb_padded is None:
|
||||
return None
|
||||
return padded_to_packed(
|
||||
self._verts_rgb_padded, split_size=self._num_verts_per_mesh
|
||||
)
|
||||
|
||||
# Currently only the padded maps are used.
|
||||
def maps_padded(self) -> torch.Tensor:
|
||||
def maps_padded(self) -> Union[torch.Tensor, None]:
|
||||
return self._maps_padded
|
||||
|
||||
def extend(self, N: int) -> "Textures":
|
||||
|
||||
Reference in New Issue
Block a user