mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Rendering texturing fixes
Summary: Fix errors raised by issue on GitHub - extending mesh textures + rendering with Gourad and Phong shaders. https://github.com/facebookresearch/pytorch3d/issues/97 Reviewed By: gkioxari Differential Revision: D20319610 fbshipit-source-id: d1c692ff0b9397a77a9b829c5c731790de70c09f
This commit is contained in:
parent
f580ce1385
commit
5d3cc3569a
@ -107,7 +107,9 @@ def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
|
|||||||
There will be one C dimensional value for each element in
|
There will be one C dimensional value for each element in
|
||||||
fragments.pix_to_face.
|
fragments.pix_to_face.
|
||||||
"""
|
"""
|
||||||
vertex_textures = meshes.textures.verts_rgb_padded().view(-1, 3) # (V, C)
|
vertex_textures = meshes.textures.verts_rgb_padded().reshape(
|
||||||
|
-1, 3
|
||||||
|
) # (V, C)
|
||||||
vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
|
vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
|
||||||
faces_packed = meshes.faces_packed()
|
faces_packed = meshes.faces_packed()
|
||||||
faces_textures = vertex_textures[faces_packed] # (F, 3, C)
|
faces_textures = vertex_textures[faces_packed] # (F, 3, C)
|
||||||
|
@ -223,27 +223,32 @@ class TensorProperties(object):
|
|||||||
self with all properties reshaped. e.g. a property with shape (N, 3)
|
self with all properties reshaped. e.g. a property with shape (N, 3)
|
||||||
is transformed to shape (B, 3).
|
is transformed to shape (B, 3).
|
||||||
"""
|
"""
|
||||||
|
# Iterate through the attributes of the class which are tensors.
|
||||||
for k in dir(self):
|
for k in dir(self):
|
||||||
v = getattr(self, k)
|
v = getattr(self, k)
|
||||||
if torch.is_tensor(v):
|
if torch.is_tensor(v):
|
||||||
if v.shape[0] > 1:
|
if v.shape[0] > 1:
|
||||||
# There are different values for each batch element
|
# There are different values for each batch element
|
||||||
# so gather these using the batch_idx
|
# so gather these using the batch_idx.
|
||||||
idx_dims = batch_idx.shape
|
# First clone the input batch_idx tensor before
|
||||||
|
# modifying it.
|
||||||
|
_batch_idx = batch_idx.clone()
|
||||||
|
idx_dims = _batch_idx.shape
|
||||||
tensor_dims = v.shape
|
tensor_dims = v.shape
|
||||||
if len(idx_dims) > len(tensor_dims):
|
if len(idx_dims) > len(tensor_dims):
|
||||||
msg = "batch_idx cannot have more dimensions than %s. "
|
msg = "batch_idx cannot have more dimensions than %s. "
|
||||||
msg += "got shape %r and %s has shape %r"
|
msg += "got shape %r and %s has shape %r"
|
||||||
raise ValueError(msg % (k, idx_dims, k, tensor_dims))
|
raise ValueError(msg % (k, idx_dims, k, tensor_dims))
|
||||||
if idx_dims != tensor_dims:
|
if idx_dims != tensor_dims:
|
||||||
# To use torch.gather the index tensor (batch_idx) has
|
# To use torch.gather the index tensor (_batch_idx) has
|
||||||
# to have the same shape as the input tensor.
|
# to have the same shape as the input tensor.
|
||||||
new_dims = len(tensor_dims) - len(idx_dims)
|
new_dims = len(tensor_dims) - len(idx_dims)
|
||||||
new_shape = idx_dims + (1,) * new_dims
|
new_shape = idx_dims + (1,) * new_dims
|
||||||
expand_dims = (-1,) + tensor_dims[1:]
|
expand_dims = (-1,) + tensor_dims[1:]
|
||||||
batch_idx = batch_idx.view(*new_shape)
|
_batch_idx = _batch_idx.view(*new_shape)
|
||||||
batch_idx = batch_idx.expand(*expand_dims)
|
_batch_idx = _batch_idx.expand(*expand_dims)
|
||||||
v = v.gather(0, batch_idx)
|
|
||||||
|
v = v.gather(0, _batch_idx)
|
||||||
setattr(self, k, v)
|
setattr(self, k, v)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -324,14 +324,14 @@ class Meshes(object):
|
|||||||
)
|
)
|
||||||
if self._N > 0:
|
if self._N > 0:
|
||||||
self.device = self._verts_list[0].device
|
self.device = self._verts_list[0].device
|
||||||
num_verts_per_mesh = torch.tensor(
|
self._num_verts_per_mesh = torch.tensor(
|
||||||
[len(v) for v in self._verts_list], device=self.device
|
[len(v) for v in self._verts_list], device=self.device
|
||||||
)
|
)
|
||||||
self._V = num_verts_per_mesh.max()
|
self._V = self._num_verts_per_mesh.max()
|
||||||
num_faces_per_mesh = torch.tensor(
|
self._num_faces_per_mesh = torch.tensor(
|
||||||
[len(f) for f in self._faces_list], device=self.device
|
[len(f) for f in self._faces_list], device=self.device
|
||||||
)
|
)
|
||||||
self._F = num_faces_per_mesh.max()
|
self._F = self._num_faces_per_mesh.max()
|
||||||
self.valid = torch.tensor(
|
self.valid = torch.tensor(
|
||||||
[
|
[
|
||||||
len(v) > 0 and len(f) > 0
|
len(v) > 0 and len(f) > 0
|
||||||
@ -341,8 +341,8 @@ class Meshes(object):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (len(num_verts_per_mesh.unique()) == 1) and (
|
if (len(self._num_verts_per_mesh.unique()) == 1) and (
|
||||||
len(num_faces_per_mesh.unique()) == 1
|
len(self._num_faces_per_mesh.unique()) == 1
|
||||||
):
|
):
|
||||||
self.equisized = True
|
self.equisized = True
|
||||||
|
|
||||||
@ -355,6 +355,7 @@ class Meshes(object):
|
|||||||
self._faces_padded = faces.to(torch.int64)
|
self._faces_padded = faces.to(torch.int64)
|
||||||
self._N = self._verts_padded.shape[0]
|
self._N = self._verts_padded.shape[0]
|
||||||
self._V = self._verts_padded.shape[1]
|
self._V = self._verts_padded.shape[1]
|
||||||
|
|
||||||
self.device = self._verts_padded.device
|
self.device = self._verts_padded.device
|
||||||
self.valid = torch.zeros(
|
self.valid = torch.zeros(
|
||||||
(self._N,), dtype=torch.bool, device=self.device
|
(self._N,), dtype=torch.bool, device=self.device
|
||||||
@ -363,18 +364,25 @@ class Meshes(object):
|
|||||||
# Check that padded faces - which have value -1 - are at the
|
# Check that padded faces - which have value -1 - are at the
|
||||||
# end of the tensors
|
# end of the tensors
|
||||||
faces_not_padded = self._faces_padded.gt(-1).all(2)
|
faces_not_padded = self._faces_padded.gt(-1).all(2)
|
||||||
num_faces = faces_not_padded.sum(1)
|
self._num_faces_per_mesh = faces_not_padded.sum(1)
|
||||||
if (faces_not_padded[:, :-1] < faces_not_padded[:, 1:]).any():
|
if (faces_not_padded[:, :-1] < faces_not_padded[:, 1:]).any():
|
||||||
raise ValueError("Padding of faces must be at the end")
|
raise ValueError("Padding of faces must be at the end")
|
||||||
|
|
||||||
# NOTE that we don't check for the ordering of padded verts
|
# NOTE that we don't check for the ordering of padded verts
|
||||||
# as long as the faces index correspond to the right vertices.
|
# as long as the faces index correspond to the right vertices.
|
||||||
|
|
||||||
self.valid = num_faces > 0
|
self.valid = self._num_faces_per_mesh > 0
|
||||||
self._F = num_faces.max()
|
self._F = self._num_faces_per_mesh.max()
|
||||||
if len(num_faces.unique()) == 1:
|
if len(self._num_faces_per_mesh.unique()) == 1:
|
||||||
self.equisized = True
|
self.equisized = True
|
||||||
|
|
||||||
|
self._num_verts_per_mesh = torch.full(
|
||||||
|
size=(self._N,),
|
||||||
|
fill_value=self._V,
|
||||||
|
dtype=torch.int64,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Verts and Faces must be either a list or a tensor with \
|
"Verts and Faces must be either a list or a tensor with \
|
||||||
@ -382,6 +390,23 @@ class Meshes(object):
|
|||||||
number of verts or faces respectively."
|
number of verts or faces respectively."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.isempty():
|
||||||
|
self._num_verts_per_mesh = torch.zeros(
|
||||||
|
(0,), dtype=torch.int64, device=self.device
|
||||||
|
)
|
||||||
|
self._num_faces_per_mesh = torch.zeros(
|
||||||
|
(0,), dtype=torch.int64, device=self.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set the num verts/faces on the textures if present.
|
||||||
|
if self.textures is not None:
|
||||||
|
self.textures._num_faces_per_mesh = (
|
||||||
|
self._num_faces_per_mesh.tolist()
|
||||||
|
)
|
||||||
|
self.textures._num_verts_per_mesh = (
|
||||||
|
self._num_verts_per_mesh.tolist()
|
||||||
|
)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self._N
|
return self._N
|
||||||
|
|
||||||
@ -893,11 +918,9 @@ class Meshes(object):
|
|||||||
self._verts_packed,
|
self._verts_packed,
|
||||||
self._verts_packed_to_mesh_idx,
|
self._verts_packed_to_mesh_idx,
|
||||||
self._mesh_to_verts_packed_first_idx,
|
self._mesh_to_verts_packed_first_idx,
|
||||||
self._num_verts_per_mesh,
|
|
||||||
self._faces_packed,
|
self._faces_packed,
|
||||||
self._faces_packed_to_mesh_idx,
|
self._faces_packed_to_mesh_idx,
|
||||||
self._mesh_to_faces_packed_first_idx,
|
self._mesh_to_faces_packed_first_idx,
|
||||||
self._num_faces_per_mesh,
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
@ -920,7 +943,6 @@ class Meshes(object):
|
|||||||
self._num_verts_per_mesh = torch.zeros(
|
self._num_verts_per_mesh = torch.zeros(
|
||||||
(0,), dtype=torch.int64, device=self.device
|
(0,), dtype=torch.int64, device=self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
self._faces_packed = -torch.ones(
|
self._faces_packed = -torch.ones(
|
||||||
(0, 3), dtype=torch.int64, device=self.device
|
(0, 3), dtype=torch.int64, device=self.device
|
||||||
)
|
)
|
||||||
@ -1354,6 +1376,7 @@ class Meshes(object):
|
|||||||
tex = None
|
tex = None
|
||||||
if self.textures is not None:
|
if self.textures is not None:
|
||||||
tex = self.textures.extend(N)
|
tex = self.textures.extend(N)
|
||||||
|
|
||||||
return Meshes(verts=new_verts_list, faces=new_faces_list, textures=tex)
|
return Meshes(verts=new_verts_list, faces=new_faces_list, textures=tex)
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,7 +4,7 @@ from typing import List, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
|
|
||||||
from .utils import list_to_packed, padded_to_list
|
from .utils import padded_to_list, padded_to_packed
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -92,14 +92,19 @@ class Textures(object):
|
|||||||
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each
|
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each
|
||||||
vertex in the face. Padding value is assumed to be -1.
|
vertex in the face. Padding value is assumed to be -1.
|
||||||
verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex.
|
verts_uvs: (N, V, 2) tensor giving the uv coordinate per vertex.
|
||||||
verts_rgb: (N, V, 3) tensor giving the rgb color per vertex.
|
verts_rgb: (N, V, 3) tensor giving the rgb color per vertex. Padding
|
||||||
|
value is assumed to be -1.
|
||||||
|
|
||||||
|
Note: only the padded representation of the textures is stored
|
||||||
|
and the packed/list representations are computed on the fly and
|
||||||
|
not cached.
|
||||||
"""
|
"""
|
||||||
if faces_uvs is not None and faces_uvs.ndim != 3:
|
if faces_uvs is not None and faces_uvs.ndim != 3:
|
||||||
msg = "Expected faces_uvs to be of shape (N, F, 3); got %r"
|
msg = "Expected faces_uvs to be of shape (N, F, 3); got %r"
|
||||||
raise ValueError(msg % repr(faces_uvs.shape))
|
raise ValueError(msg % repr(faces_uvs.shape))
|
||||||
if verts_uvs is not None and verts_uvs.ndim != 3:
|
if verts_uvs is not None and verts_uvs.ndim != 3:
|
||||||
msg = "Expected verts_uvs to be of shape (N, V, 2); got %r"
|
msg = "Expected verts_uvs to be of shape (N, V, 2); got %r"
|
||||||
raise ValueError(msg % repr(faces_uvs.shape))
|
raise ValueError(msg % repr(verts_uvs.shape))
|
||||||
if verts_rgb is not None and verts_rgb.ndim != 3:
|
if verts_rgb is not None and verts_rgb.ndim != 3:
|
||||||
msg = "Expected verts_rgb to be of shape (N, V, 3); got %r"
|
msg = "Expected verts_rgb to be of shape (N, V, 3); got %r"
|
||||||
raise ValueError(msg % repr(verts_rgb.shape))
|
raise ValueError(msg % repr(verts_rgb.shape))
|
||||||
@ -109,20 +114,20 @@ class Textures(object):
|
|||||||
raise ValueError(msg % repr(maps.shape))
|
raise ValueError(msg % repr(maps.shape))
|
||||||
elif isinstance(maps, list):
|
elif isinstance(maps, list):
|
||||||
maps = _pad_texture_maps(maps)
|
maps = _pad_texture_maps(maps)
|
||||||
|
if faces_uvs is None or verts_uvs is None:
|
||||||
|
msg = "To use maps, faces_uvs and verts_uvs are required"
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
self._faces_uvs_padded = faces_uvs
|
self._faces_uvs_padded = faces_uvs
|
||||||
self._verts_uvs_padded = verts_uvs
|
self._verts_uvs_padded = verts_uvs
|
||||||
self._verts_rgb_padded = verts_rgb
|
self._verts_rgb_padded = verts_rgb
|
||||||
self._maps_padded = maps
|
self._maps_padded = maps
|
||||||
self._num_faces_per_mesh = None
|
|
||||||
self._set_num_faces_per_mesh()
|
|
||||||
|
|
||||||
def _set_num_faces_per_mesh(self) -> None:
|
# The number of faces/verts for each mesh is
|
||||||
"""
|
# set inside the Meshes object when textures is
|
||||||
Determines and sets the number of textured faces for each mesh.
|
# passed into the Meshes constructor.
|
||||||
"""
|
self._num_faces_per_mesh = None
|
||||||
if self._faces_uvs_padded is not None:
|
self._num_verts_per_mesh = None
|
||||||
faces_uvs = self._faces_uvs_padded
|
|
||||||
self._num_faces_per_mesh = faces_uvs.gt(-1).all(-1).sum(-1).tolist()
|
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
other = Textures()
|
other = Textures()
|
||||||
@ -148,41 +153,67 @@ class Textures(object):
|
|||||||
setattr(other, key, value[index][None])
|
setattr(other, key, value[index][None])
|
||||||
else:
|
else:
|
||||||
setattr(other, key, value[index])
|
setattr(other, key, value[index])
|
||||||
other._set_num_faces_per_mesh()
|
|
||||||
return other
|
return other
|
||||||
|
|
||||||
def faces_uvs_padded(self) -> torch.Tensor:
|
def faces_uvs_padded(self) -> torch.Tensor:
|
||||||
return self._faces_uvs_padded
|
return self._faces_uvs_padded
|
||||||
|
|
||||||
def faces_uvs_list(self) -> List[torch.Tensor]:
|
def faces_uvs_list(self) -> Union[List[torch.Tensor], None]:
|
||||||
if self._faces_uvs_padded is not None:
|
if self._faces_uvs_padded is None:
|
||||||
return padded_to_list(
|
return None
|
||||||
self._faces_uvs_padded, split_size=self._num_faces_per_mesh
|
return padded_to_list(
|
||||||
)
|
self._faces_uvs_padded, split_size=self._num_faces_per_mesh
|
||||||
|
)
|
||||||
|
|
||||||
def faces_uvs_packed(self) -> torch.Tensor:
|
def faces_uvs_packed(self) -> Union[torch.Tensor, None]:
|
||||||
return list_to_packed(self.faces_uvs_list())[0]
|
if self._faces_uvs_padded is None:
|
||||||
|
return None
|
||||||
|
return padded_to_packed(
|
||||||
|
self._faces_uvs_padded, split_size=self._num_faces_per_mesh
|
||||||
|
)
|
||||||
|
|
||||||
def verts_uvs_padded(self) -> torch.Tensor:
|
def verts_uvs_padded(self) -> Union[torch.Tensor, None]:
|
||||||
return self._verts_uvs_padded
|
return self._verts_uvs_padded
|
||||||
|
|
||||||
def verts_uvs_list(self) -> List[torch.Tensor]:
|
def verts_uvs_list(self) -> Union[List[torch.Tensor], None]:
|
||||||
|
if self._verts_uvs_padded is None:
|
||||||
|
return None
|
||||||
|
# Vertices shared between multiple faces
|
||||||
|
# may have a different uv coordinate for
|
||||||
|
# each face so the num_verts_uvs_per_mesh
|
||||||
|
# may be different from num_verts_per_mesh.
|
||||||
|
# Therefore don't use any split_size.
|
||||||
return padded_to_list(self._verts_uvs_padded)
|
return padded_to_list(self._verts_uvs_padded)
|
||||||
|
|
||||||
def verts_uvs_packed(self) -> torch.Tensor:
|
def verts_uvs_packed(self) -> Union[torch.Tensor, None]:
|
||||||
return list_to_packed(self.verts_uvs_list())[0]
|
if self._verts_uvs_padded is None:
|
||||||
|
return None
|
||||||
|
# Vertices shared between multiple faces
|
||||||
|
# may have a different uv coordinate for
|
||||||
|
# each face so the num_verts_uvs_per_mesh
|
||||||
|
# may be different from num_verts_per_mesh.
|
||||||
|
# Therefore don't use any split_size.
|
||||||
|
return padded_to_packed(self._verts_uvs_padded)
|
||||||
|
|
||||||
def verts_rgb_padded(self) -> torch.Tensor:
|
def verts_rgb_padded(self) -> Union[torch.Tensor, None]:
|
||||||
return self._verts_rgb_padded
|
return self._verts_rgb_padded
|
||||||
|
|
||||||
def verts_rgb_list(self) -> List[torch.Tensor]:
|
def verts_rgb_list(self) -> Union[List[torch.Tensor], None]:
|
||||||
return padded_to_list(self._verts_rgb_padded)
|
if self._verts_rgb_padded is None:
|
||||||
|
return None
|
||||||
|
return padded_to_list(
|
||||||
|
self._verts_rgb_padded, split_size=self._num_verts_per_mesh
|
||||||
|
)
|
||||||
|
|
||||||
def verts_rgb_packed(self) -> torch.Tensor:
|
def verts_rgb_packed(self) -> Union[torch.Tensor, None]:
|
||||||
return list_to_packed(self.verts_rgb_list())[0]
|
if self._verts_rgb_padded is None:
|
||||||
|
return None
|
||||||
|
return padded_to_packed(
|
||||||
|
self._verts_rgb_padded, split_size=self._num_verts_per_mesh
|
||||||
|
)
|
||||||
|
|
||||||
# Currently only the padded maps are used.
|
# Currently only the padded maps are used.
|
||||||
def maps_padded(self) -> torch.Tensor:
|
def maps_padded(self) -> Union[torch.Tensor, None]:
|
||||||
return self._maps_padded
|
return self._maps_padded
|
||||||
|
|
||||||
def extend(self, N: int) -> "Textures":
|
def extend(self, N: int) -> "Textures":
|
||||||
|
Before Width: | Height: | Size: 21 KiB After Width: | Height: | Size: 21 KiB |
Before Width: | Height: | Size: 10 KiB After Width: | Height: | Size: 10 KiB |
@ -135,6 +135,15 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
def test_simple(self):
|
def test_simple(self):
|
||||||
mesh = TestMeshes.init_simple_mesh("cuda:0")
|
mesh = TestMeshes.init_simple_mesh("cuda:0")
|
||||||
|
|
||||||
|
# Check that faces/verts per mesh are set in init:
|
||||||
|
self.assertClose(
|
||||||
|
mesh._num_faces_per_mesh.cpu(), torch.tensor([1, 2, 7])
|
||||||
|
)
|
||||||
|
self.assertClose(
|
||||||
|
mesh._num_verts_per_mesh.cpu(), torch.tensor([3, 4, 5])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check computed tensors
|
||||||
self.assertClose(
|
self.assertClose(
|
||||||
mesh.verts_packed_to_mesh_idx().cpu(),
|
mesh.verts_packed_to_mesh_idx().cpu(),
|
||||||
torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]),
|
torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2]),
|
||||||
@ -142,9 +151,6 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertClose(
|
self.assertClose(
|
||||||
mesh.mesh_to_verts_packed_first_idx().cpu(), torch.tensor([0, 3, 7])
|
mesh.mesh_to_verts_packed_first_idx().cpu(), torch.tensor([0, 3, 7])
|
||||||
)
|
)
|
||||||
self.assertClose(
|
|
||||||
mesh.num_verts_per_mesh().cpu(), torch.tensor([3, 4, 5])
|
|
||||||
)
|
|
||||||
self.assertClose(
|
self.assertClose(
|
||||||
mesh.verts_padded_to_packed_idx().cpu(),
|
mesh.verts_padded_to_packed_idx().cpu(),
|
||||||
torch.tensor([0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14]),
|
torch.tensor([0, 1, 2, 5, 6, 7, 8, 10, 11, 12, 13, 14]),
|
||||||
@ -156,9 +162,6 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertClose(
|
self.assertClose(
|
||||||
mesh.mesh_to_faces_packed_first_idx().cpu(), torch.tensor([0, 1, 3])
|
mesh.mesh_to_faces_packed_first_idx().cpu(), torch.tensor([0, 1, 3])
|
||||||
)
|
)
|
||||||
self.assertClose(
|
|
||||||
mesh.num_faces_per_mesh().cpu(), torch.tensor([1, 2, 7])
|
|
||||||
)
|
|
||||||
self.assertClose(
|
self.assertClose(
|
||||||
mesh.num_edges_per_mesh().cpu(),
|
mesh.num_edges_per_mesh().cpu(),
|
||||||
torch.tensor([3, 5, 10], dtype=torch.int32),
|
torch.tensor([3, 5, 10], dtype=torch.int32),
|
||||||
@ -249,6 +252,8 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertEqual(mesh.faces_padded().shape[0], 0)
|
self.assertEqual(mesh.faces_padded().shape[0], 0)
|
||||||
self.assertEqual(mesh.verts_packed().shape[0], 0)
|
self.assertEqual(mesh.verts_packed().shape[0], 0)
|
||||||
self.assertEqual(mesh.faces_packed().shape[0], 0)
|
self.assertEqual(mesh.faces_packed().shape[0], 0)
|
||||||
|
self.assertEqual(mesh.num_faces_per_mesh().shape[0], 0)
|
||||||
|
self.assertEqual(mesh.num_verts_per_mesh().shape[0], 0)
|
||||||
|
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
N, V, F = 10, 100, 300
|
N, V, F = 10, 100, 300
|
||||||
@ -323,9 +328,11 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
mesh = Meshes(verts=torch.stack(verts), faces=torch.stack(faces))
|
mesh = Meshes(verts=torch.stack(verts), faces=torch.stack(faces))
|
||||||
|
|
||||||
|
# Check verts/faces per mesh are set correctly in init.
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
mesh.num_faces_per_mesh().tolist(), num_faces.tolist()
|
mesh._num_faces_per_mesh.tolist(), num_faces.tolist()
|
||||||
)
|
)
|
||||||
|
self.assertListEqual(mesh._num_verts_per_mesh.tolist(), [V] * N)
|
||||||
|
|
||||||
for n, (vv, ff) in enumerate(zip(mesh.verts_list(), mesh.faces_list())):
|
for n, (vv, ff) in enumerate(zip(mesh.verts_list(), mesh.faces_list())):
|
||||||
self.assertClose(ff, faces[n][: num_faces[n]])
|
self.assertClose(ff, faces[n][: num_faces[n]])
|
||||||
@ -364,7 +371,6 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
mesh._num_verts_per_mesh = torch.randint_like(
|
mesh._num_verts_per_mesh = torch.randint_like(
|
||||||
mesh.num_verts_per_mesh(), high=10
|
mesh.num_verts_per_mesh(), high=10
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check cloned and original Meshes objects do not share tensors.
|
# Check cloned and original Meshes objects do not share tensors.
|
||||||
self.assertFalse(
|
self.assertFalse(
|
||||||
torch.allclose(new_mesh._verts_list[0], mesh._verts_list[0])
|
torch.allclose(new_mesh._verts_list[0], mesh._verts_list[0])
|
||||||
|
@ -34,7 +34,7 @@ from pytorch3d.renderer.mesh.texturing import Textures
|
|||||||
from pytorch3d.structures.meshes import Meshes
|
from pytorch3d.structures.meshes import Meshes
|
||||||
from pytorch3d.utils.ico_sphere import ico_sphere
|
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||||
|
|
||||||
# Save out images generated in the tests for debugging
|
# If DEBUG=True, save out images generated in the tests for debugging.
|
||||||
# All saved images have prefix DEBUG_
|
# All saved images have prefix DEBUG_
|
||||||
DEBUG = False
|
DEBUG = False
|
||||||
DATA_DIR = Path(__file__).resolve().parent / "data"
|
DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||||
@ -90,30 +90,31 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
raster_settings = RasterizationSettings(
|
raster_settings = RasterizationSettings(
|
||||||
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
|
image_size=512, blur_radius=0.0, faces_per_pixel=1, bin_size=0
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init renderer
|
|
||||||
rasterizer = MeshRasterizer(
|
rasterizer = MeshRasterizer(
|
||||||
cameras=cameras, raster_settings=raster_settings
|
cameras=cameras, raster_settings=raster_settings
|
||||||
)
|
)
|
||||||
renderer = MeshRenderer(
|
|
||||||
rasterizer=rasterizer,
|
|
||||||
shader=HardPhongShader(
|
|
||||||
lights=lights, cameras=cameras, materials=materials
|
|
||||||
),
|
|
||||||
)
|
|
||||||
images = renderer(sphere_mesh)
|
|
||||||
rgb = images[0, ..., :3].squeeze().cpu()
|
|
||||||
if DEBUG:
|
|
||||||
filename = "DEBUG_simple_sphere_light%s.png" % postfix
|
|
||||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
|
||||||
DATA_DIR / filename
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load reference image
|
# Test several shaders
|
||||||
image_ref_phong = load_rgb_image(
|
shaders = {
|
||||||
"test_simple_sphere_light%s.png" % postfix
|
"phong": HardPhongShader,
|
||||||
)
|
"gouraud": HardGouraudShader,
|
||||||
self.assertTrue(torch.allclose(rgb, image_ref_phong, atol=0.05))
|
"flat": HardFlatShader,
|
||||||
|
}
|
||||||
|
for (name, shader_init) in shaders.items():
|
||||||
|
shader = shader_init(
|
||||||
|
lights=lights, cameras=cameras, materials=materials
|
||||||
|
)
|
||||||
|
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||||
|
images = renderer(sphere_mesh)
|
||||||
|
filename = "simple_sphere_light_%s%s.png" % (name, postfix)
|
||||||
|
image_ref = load_rgb_image("test_%s" % filename)
|
||||||
|
rgb = images[0, ..., :3].squeeze().cpu()
|
||||||
|
if DEBUG:
|
||||||
|
filename = "DEBUG_" % filename
|
||||||
|
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||||
|
DATA_DIR / filename
|
||||||
|
)
|
||||||
|
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
# Move the light to the +z axis in world space so it is
|
# Move the light to the +z axis in world space so it is
|
||||||
@ -121,7 +122,13 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
# +X left for both world and camera space.
|
# +X left for both world and camera space.
|
||||||
########################################################
|
########################################################
|
||||||
lights.location[..., 2] = -2.0
|
lights.location[..., 2] = -2.0
|
||||||
images = renderer(sphere_mesh, lights=lights)
|
phong_shader = HardPhongShader(
|
||||||
|
lights=lights, cameras=cameras, materials=materials
|
||||||
|
)
|
||||||
|
phong_renderer = MeshRenderer(
|
||||||
|
rasterizer=rasterizer, shader=phong_shader
|
||||||
|
)
|
||||||
|
images = phong_renderer(sphere_mesh, lights=lights)
|
||||||
rgb = images[0, ..., :3].squeeze().cpu()
|
rgb = images[0, ..., :3].squeeze().cpu()
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
filename = "DEBUG_simple_sphere_dark%s.png" % postfix
|
filename = "DEBUG_simple_sphere_dark%s.png" % postfix
|
||||||
@ -135,53 +142,6 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05))
|
self.assertTrue(torch.allclose(rgb, image_ref_phong_dark, atol=0.05))
|
||||||
|
|
||||||
######################################
|
|
||||||
# Change the shader to a GouraudShader
|
|
||||||
######################################
|
|
||||||
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
|
||||||
renderer = MeshRenderer(
|
|
||||||
rasterizer=rasterizer,
|
|
||||||
shader=HardGouraudShader(
|
|
||||||
lights=lights, cameras=cameras, materials=materials
|
|
||||||
),
|
|
||||||
)
|
|
||||||
images = renderer(sphere_mesh)
|
|
||||||
rgb = images[0, ..., :3].squeeze().cpu()
|
|
||||||
if DEBUG:
|
|
||||||
filename = "DEBUG_simple_sphere_light_gouraud%s.png" % postfix
|
|
||||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
|
||||||
DATA_DIR / filename
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load reference image
|
|
||||||
image_ref_gouraud = load_rgb_image(
|
|
||||||
"test_simple_sphere_light_gouraud%s.png" % postfix
|
|
||||||
)
|
|
||||||
self.assertTrue(torch.allclose(rgb, image_ref_gouraud, atol=0.005))
|
|
||||||
|
|
||||||
######################################
|
|
||||||
# Change the shader to a HardFlatShader
|
|
||||||
######################################
|
|
||||||
renderer = MeshRenderer(
|
|
||||||
rasterizer=rasterizer,
|
|
||||||
shader=HardFlatShader(
|
|
||||||
lights=lights, cameras=cameras, materials=materials
|
|
||||||
),
|
|
||||||
)
|
|
||||||
images = renderer(sphere_mesh)
|
|
||||||
rgb = images[0, ..., :3].squeeze().cpu()
|
|
||||||
if DEBUG:
|
|
||||||
filename = "DEBUG_simple_sphere_light_flat%s.png" % postfix
|
|
||||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
|
||||||
DATA_DIR / filename
|
|
||||||
)
|
|
||||||
|
|
||||||
# Load reference image
|
|
||||||
image_ref_flat = load_rgb_image(
|
|
||||||
"test_simple_sphere_light_flat%s.png" % postfix
|
|
||||||
)
|
|
||||||
self.assertTrue(torch.allclose(rgb, image_ref_flat, atol=0.005))
|
|
||||||
|
|
||||||
def test_simple_sphere_elevated_camera(self):
|
def test_simple_sphere_elevated_camera(self):
|
||||||
"""
|
"""
|
||||||
Test output of phong and gouraud shading matches a reference image using
|
Test output of phong and gouraud shading matches a reference image using
|
||||||
@ -193,13 +153,13 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
|
|
||||||
def test_simple_sphere_batched(self):
|
def test_simple_sphere_batched(self):
|
||||||
"""
|
"""
|
||||||
Test output of phong shading matches a reference image using
|
Test a mesh with vertex textures can be extended to form a batch, and
|
||||||
the default values for the light sources.
|
is rendered correctly with Phong, Gouraud and Flat Shaders.
|
||||||
"""
|
"""
|
||||||
batch_size = 5
|
batch_size = 20
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
|
|
||||||
# Init mesh
|
# Init mesh with vertex textures.
|
||||||
sphere_meshes = ico_sphere(5, device).extend(batch_size)
|
sphere_meshes = ico_sphere(5, device).extend(batch_size)
|
||||||
verts_padded = sphere_meshes.verts_padded()
|
verts_padded = sphere_meshes.verts_padded()
|
||||||
faces_padded = sphere_meshes.faces_padded()
|
faces_padded = sphere_meshes.faces_padded()
|
||||||
@ -224,26 +184,24 @@ class TestRenderingMeshes(unittest.TestCase):
|
|||||||
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
|
||||||
|
|
||||||
# Init renderer
|
# Init renderer
|
||||||
renderer = MeshRenderer(
|
rasterizer = MeshRasterizer(
|
||||||
rasterizer=MeshRasterizer(
|
cameras=cameras, raster_settings=raster_settings
|
||||||
cameras=cameras, raster_settings=raster_settings
|
|
||||||
),
|
|
||||||
shader=HardPhongShader(
|
|
||||||
lights=lights, cameras=cameras, materials=materials
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
images = renderer(sphere_meshes)
|
shaders = {
|
||||||
|
"phong": HardGouraudShader,
|
||||||
# Load ref image
|
"gouraud": HardGouraudShader,
|
||||||
image_ref = load_rgb_image("test_simple_sphere_light.png")
|
"flat": HardFlatShader,
|
||||||
|
}
|
||||||
for i in range(batch_size):
|
for (name, shader_init) in shaders.items():
|
||||||
rgb = images[i, ..., :3].squeeze().cpu()
|
shader = shader_init(
|
||||||
if DEBUG:
|
lights=lights, cameras=cameras, materials=materials
|
||||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
)
|
||||||
DATA_DIR / f"DEBUG_simple_sphere_{i}.png"
|
renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
|
||||||
)
|
images = renderer(sphere_meshes)
|
||||||
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
|
image_ref = load_rgb_image("test_simple_sphere_light_%s.png" % name)
|
||||||
|
for i in range(batch_size):
|
||||||
|
rgb = images[i, ..., :3].squeeze().cpu()
|
||||||
|
self.assertTrue(torch.allclose(rgb, image_ref, atol=0.05))
|
||||||
|
|
||||||
def test_silhouette_with_grad(self):
|
def test_silhouette_with_grad(self):
|
||||||
"""
|
"""
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import unittest
|
import unittest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -61,3 +62,28 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase):
|
|||||||
example = TensorPropertiesTestClass(x=(), y=())
|
example = TensorPropertiesTestClass(x=(), y=())
|
||||||
self.assertTrue(len(example) == 0)
|
self.assertTrue(len(example) == 0)
|
||||||
self.assertTrue(example.isempty())
|
self.assertTrue(example.isempty())
|
||||||
|
|
||||||
|
def test_gather_props(self):
|
||||||
|
N = 4
|
||||||
|
x = torch.randn((N, 3, 4))
|
||||||
|
y = torch.randn((N, 5))
|
||||||
|
test_class = TensorPropertiesTestClass(x=x, y=y)
|
||||||
|
|
||||||
|
S = 15
|
||||||
|
idx = torch.tensor(np.random.choice(N, S))
|
||||||
|
test_class_gathered = test_class.gather_props(idx)
|
||||||
|
|
||||||
|
self.assertTrue(test_class_gathered.x.shape == (S, 3, 4))
|
||||||
|
self.assertTrue(test_class_gathered.y.shape == (S, 5))
|
||||||
|
|
||||||
|
for i in range(N):
|
||||||
|
inds = idx == i
|
||||||
|
if inds.sum() > 0:
|
||||||
|
# Check the gathered points in the output have the same value from
|
||||||
|
# the input.
|
||||||
|
self.assertClose(
|
||||||
|
test_class_gathered.x[inds].mean(dim=0), x[i, ...]
|
||||||
|
)
|
||||||
|
self.assertClose(
|
||||||
|
test_class_gathered.y[inds].mean(dim=0), y[i, ...]
|
||||||
|
)
|
||||||
|
@ -12,6 +12,7 @@ from pytorch3d.renderer.mesh.texturing import (
|
|||||||
interpolate_vertex_colors,
|
interpolate_vertex_colors,
|
||||||
)
|
)
|
||||||
from pytorch3d.structures import Meshes, Textures
|
from pytorch3d.structures import Meshes, Textures
|
||||||
|
from pytorch3d.structures.utils import list_to_padded
|
||||||
|
|
||||||
from common_testing import TestCaseMixin
|
from common_testing import TestCaseMixin
|
||||||
from test_meshes import TestMeshes
|
from test_meshes import TestMeshes
|
||||||
@ -154,6 +155,108 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
|
|||||||
torch.allclose(texels.squeeze(), expected_out.squeeze())
|
torch.allclose(texels.squeeze(), expected_out.squeeze())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_init_rgb_uv_fail(self):
|
||||||
|
V = 20
|
||||||
|
# Maps has wrong shape
|
||||||
|
with self.assertRaisesRegex(ValueError, "maps"):
|
||||||
|
Textures(
|
||||||
|
maps=torch.ones((5, 16, 16, 3, 4)),
|
||||||
|
faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
|
||||||
|
verts_uvs=torch.ones((5, V, 2)),
|
||||||
|
)
|
||||||
|
# faces_uvs has wrong shape
|
||||||
|
with self.assertRaisesRegex(ValueError, "faces_uvs"):
|
||||||
|
Textures(
|
||||||
|
maps=torch.ones((5, 16, 16, 3)),
|
||||||
|
faces_uvs=torch.randint(size=(5, 10, 3, 3), low=0, high=V),
|
||||||
|
verts_uvs=torch.ones((5, V, 2)),
|
||||||
|
)
|
||||||
|
# verts_uvs has wrong shape
|
||||||
|
with self.assertRaisesRegex(ValueError, "verts_uvs"):
|
||||||
|
Textures(
|
||||||
|
maps=torch.ones((5, 16, 16, 3)),
|
||||||
|
faces_uvs=torch.randint(size=(5, 10, 3), low=0, high=V),
|
||||||
|
verts_uvs=torch.ones((5, V, 2, 3)),
|
||||||
|
)
|
||||||
|
# verts_rgb has wrong shape
|
||||||
|
with self.assertRaisesRegex(ValueError, "verts_rgb"):
|
||||||
|
Textures(verts_rgb=torch.ones((5, 16, 16, 3)))
|
||||||
|
|
||||||
|
# maps provided without verts/faces uvs
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, "faces_uvs and verts_uvs are required"
|
||||||
|
):
|
||||||
|
Textures(maps=torch.ones((5, 16, 16, 3)))
|
||||||
|
|
||||||
|
def test_padded_to_packed(self):
|
||||||
|
N = 2
|
||||||
|
# Case where each face in the mesh has 3 unique uv vertex indices
|
||||||
|
# - i.e. even if a vertex is shared between multiple faces it will
|
||||||
|
# have a unique uv coordinate for each face.
|
||||||
|
faces_uvs_list = [
|
||||||
|
torch.tensor([[0, 1, 2], [3, 5, 4], [7, 6, 8]]),
|
||||||
|
torch.tensor([[0, 1, 2], [3, 4, 5]]),
|
||||||
|
] # (N, 3, 3)
|
||||||
|
verts_uvs_list = [torch.ones(9, 2), torch.ones(6, 2)]
|
||||||
|
faces_uvs_padded = list_to_padded(faces_uvs_list, pad_value=-1)
|
||||||
|
verts_uvs_padded = list_to_padded(verts_uvs_list)
|
||||||
|
tex = Textures(
|
||||||
|
maps=torch.ones((N, 16, 16, 3)),
|
||||||
|
faces_uvs=faces_uvs_padded,
|
||||||
|
verts_uvs=verts_uvs_padded,
|
||||||
|
)
|
||||||
|
|
||||||
|
# This is set inside Meshes when textures is passed as an input.
|
||||||
|
# Here we set _num_faces_per_mesh and _num_verts_per_mesh explicity.
|
||||||
|
tex1 = tex.clone()
|
||||||
|
tex1._num_faces_per_mesh = (
|
||||||
|
faces_uvs_padded.gt(-1).all(-1).sum(-1).tolist()
|
||||||
|
)
|
||||||
|
tex1._num_verts_per_mesh = torch.tensor([5, 4])
|
||||||
|
faces_packed = tex1.faces_uvs_packed()
|
||||||
|
verts_packed = tex1.verts_uvs_packed()
|
||||||
|
faces_list = tex1.faces_uvs_list()
|
||||||
|
verts_list = tex1.verts_uvs_list()
|
||||||
|
|
||||||
|
for f1, f2 in zip(faces_uvs_list, faces_list):
|
||||||
|
self.assertTrue((f1 == f2).all().item())
|
||||||
|
|
||||||
|
for f, v1, v2 in zip(faces_list, verts_list, verts_uvs_list):
|
||||||
|
idx = f.unique()
|
||||||
|
self.assertTrue((v1[idx] == v2).all().item())
|
||||||
|
|
||||||
|
self.assertTrue(faces_packed.shape == (3 + 2, 3))
|
||||||
|
|
||||||
|
# verts_packed is just flattened verts_padded.
|
||||||
|
# split sizes are not used for verts_uvs.
|
||||||
|
self.assertTrue(verts_packed.shape == (9 * 2, 2))
|
||||||
|
|
||||||
|
# Case where num_faces_per_mesh is not set
|
||||||
|
tex2 = tex.clone()
|
||||||
|
faces_packed = tex2.faces_uvs_packed()
|
||||||
|
verts_packed = tex2.verts_uvs_packed()
|
||||||
|
faces_list = tex2.faces_uvs_list()
|
||||||
|
verts_list = tex2.verts_uvs_list()
|
||||||
|
|
||||||
|
# Packed is just flattened padded as num_faces_per_mesh
|
||||||
|
# has not been provided.
|
||||||
|
self.assertTrue(verts_packed.shape == (9 * 2, 2))
|
||||||
|
self.assertTrue(faces_packed.shape == (3 * 2, 3))
|
||||||
|
|
||||||
|
for i in range(N):
|
||||||
|
self.assertTrue(
|
||||||
|
(faces_list[i] == faces_uvs_padded[i, ...].squeeze())
|
||||||
|
.all()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(N):
|
||||||
|
self.assertTrue(
|
||||||
|
(verts_list[i] == verts_uvs_padded[i, ...].squeeze())
|
||||||
|
.all()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
|
||||||
def test_clone(self):
|
def test_clone(self):
|
||||||
V = 20
|
V = 20
|
||||||
tex = Textures(
|
tex = Textures(
|
||||||
@ -233,13 +336,17 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
|
|||||||
mesh = TestMeshes.init_mesh(B, 30, 50)
|
mesh = TestMeshes.init_mesh(B, 30, 50)
|
||||||
V = mesh._V
|
V = mesh._V
|
||||||
F = mesh._F
|
F = mesh._F
|
||||||
tex = Textures(
|
|
||||||
|
# 1. Texture uvs
|
||||||
|
tex_uv = Textures(
|
||||||
maps=torch.randn((B, 16, 16, 3)),
|
maps=torch.randn((B, 16, 16, 3)),
|
||||||
faces_uvs=torch.randint(size=(B, F, 3), low=0, high=V),
|
faces_uvs=torch.randint(size=(B, F, 3), low=0, high=V),
|
||||||
verts_uvs=torch.randn((B, V, 2)),
|
verts_uvs=torch.randn((B, V, 2)),
|
||||||
)
|
)
|
||||||
tex_mesh = Meshes(
|
tex_mesh = Meshes(
|
||||||
verts=mesh.verts_padded(), faces=mesh.faces_padded(), textures=tex
|
verts=mesh.verts_padded(),
|
||||||
|
faces=mesh.faces_padded(),
|
||||||
|
textures=tex_uv,
|
||||||
)
|
)
|
||||||
N = 20
|
N = 20
|
||||||
new_mesh = tex_mesh.extend(N)
|
new_mesh = tex_mesh.extend(N)
|
||||||
@ -269,5 +376,43 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
|
|||||||
new_tex.maps_padded(),
|
new_tex.maps_padded(),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.assertIsNone(new_tex.verts_rgb_list())
|
||||||
|
self.assertIsNone(new_tex.verts_rgb_padded())
|
||||||
|
self.assertIsNone(new_tex.verts_rgb_packed())
|
||||||
|
|
||||||
|
# 2. Texture vertex RGB
|
||||||
|
tex_rgb = Textures(verts_rgb=torch.randn((B, V, 3)))
|
||||||
|
tex_mesh_rgb = Meshes(
|
||||||
|
verts=mesh.verts_padded(),
|
||||||
|
faces=mesh.faces_padded(),
|
||||||
|
textures=tex_rgb,
|
||||||
|
)
|
||||||
|
N = 20
|
||||||
|
new_mesh_rgb = tex_mesh_rgb.extend(N)
|
||||||
|
|
||||||
|
self.assertEqual(len(tex_mesh_rgb) * N, len(new_mesh_rgb))
|
||||||
|
|
||||||
|
tex_init = tex_mesh_rgb.textures
|
||||||
|
new_tex = new_mesh_rgb.textures
|
||||||
|
|
||||||
|
for i in range(len(tex_mesh_rgb)):
|
||||||
|
for n in range(N):
|
||||||
|
self.assertClose(
|
||||||
|
tex_init.verts_rgb_list()[i],
|
||||||
|
new_tex.verts_rgb_list()[i * N + n],
|
||||||
|
)
|
||||||
|
self.assertAllSeparate(
|
||||||
|
[tex_init.verts_rgb_padded(), new_tex.verts_rgb_padded()]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIsNone(new_tex.verts_uvs_padded())
|
||||||
|
self.assertIsNone(new_tex.verts_uvs_list())
|
||||||
|
self.assertIsNone(new_tex.verts_uvs_packed())
|
||||||
|
self.assertIsNone(new_tex.faces_uvs_padded())
|
||||||
|
self.assertIsNone(new_tex.faces_uvs_list())
|
||||||
|
self.assertIsNone(new_tex.faces_uvs_packed())
|
||||||
|
|
||||||
|
# 3. Error
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
tex_mesh.extend(N=-1)
|
tex_mesh.extend(N=-1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user