mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 12:22:49 +08:00
textures dimension check
Summary: When textures are set on `Meshes` we need to check if the dimensions actually match that of the verts/faces in the mesh. There was a github issue where someone tried to set the attribute after construction of the `Meshes` object and ran into an error when trying to sample textures. The desired usage is to initialize the class with the textures (not set an attribute afterwards) but in any case we need to check the dimensions match before sampling textures. Reviewed By: bottler Differential Revision: D29020249 fbshipit-source-id: 9fb8a5368b83c9ec53652df92b96fc8b2613f591
This commit is contained in:
parent
1cd1436460
commit
ef16253953
@ -574,6 +574,15 @@ class TexturesAtlas(TexturesBase):
|
|||||||
"""
|
"""
|
||||||
return self.__class__(atlas=[torch.cat(self.atlas_list())])
|
return self.__class__(atlas=[torch.cat(self.atlas_list())])
|
||||||
|
|
||||||
|
def check_shapes(
|
||||||
|
self, batch_size: int, max_num_verts: int, max_num_faces: int
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the dimensions of the atlas match that of the mesh faces
|
||||||
|
"""
|
||||||
|
# (N, F) should be the same
|
||||||
|
return self.atlas_padded().shape[0:2] == (batch_size, max_num_faces)
|
||||||
|
|
||||||
|
|
||||||
class TexturesUV(TexturesBase):
|
class TexturesUV(TexturesBase):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -1213,6 +1222,18 @@ class TexturesUV(TexturesBase):
|
|||||||
centers = centers[0, :, 0].T
|
centers = centers[0, :, 0].T
|
||||||
return centers
|
return centers
|
||||||
|
|
||||||
|
def check_shapes(
|
||||||
|
self, batch_size: int, max_num_verts: int, max_num_faces: int
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the dimensions of the verts/faces uvs match that of the mesh
|
||||||
|
"""
|
||||||
|
# (N, F) should be the same
|
||||||
|
# (N, V) is not guaranteed to be the same
|
||||||
|
return (self.faces_uvs_padded().shape[0:2] == (batch_size, max_num_faces)) and (
|
||||||
|
self.verts_uvs_padded().shape[0] == batch_size
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TexturesVertex(TexturesBase):
|
class TexturesVertex(TexturesBase):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -1292,6 +1313,13 @@ class TexturesVertex(TexturesBase):
|
|||||||
new_props = self._getitem(index, props)
|
new_props = self._getitem(index, props)
|
||||||
verts_features = new_props["verts_features_list"]
|
verts_features = new_props["verts_features_list"]
|
||||||
if isinstance(verts_features, list):
|
if isinstance(verts_features, list):
|
||||||
|
# Handle the case of an empty list
|
||||||
|
if len(verts_features) == 0:
|
||||||
|
verts_features = torch.empty(
|
||||||
|
size=(0, 0, 3),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=self.verts_features_padded().device,
|
||||||
|
)
|
||||||
new_tex = self.__class__(verts_features=verts_features)
|
new_tex = self.__class__(verts_features=verts_features)
|
||||||
elif torch.is_tensor(verts_features):
|
elif torch.is_tensor(verts_features):
|
||||||
new_tex = self.__class__(verts_features=[verts_features])
|
new_tex = self.__class__(verts_features=[verts_features])
|
||||||
@ -1410,3 +1438,12 @@ class TexturesVertex(TexturesBase):
|
|||||||
Return a new TexturesVertex amalgamating the batch.
|
Return a new TexturesVertex amalgamating the batch.
|
||||||
"""
|
"""
|
||||||
return self.__class__(verts_features=[torch.cat(self.verts_features_list())])
|
return self.__class__(verts_features=[torch.cat(self.verts_features_list())])
|
||||||
|
|
||||||
|
def check_shapes(
|
||||||
|
self, batch_size: int, max_num_verts: int, max_num_faces: int
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the dimensions of the verts features match that of the mesh verts
|
||||||
|
"""
|
||||||
|
# (N, V) should be the same
|
||||||
|
return self.verts_features_padded().shape[:-1] == (batch_size, max_num_verts)
|
||||||
|
@ -255,6 +255,7 @@ class Meshes:
|
|||||||
if textures is not None and not hasattr(textures, "sample_textures"):
|
if textures is not None and not hasattr(textures, "sample_textures"):
|
||||||
msg = "Expected textures to be an instance of type TexturesBase; got %r"
|
msg = "Expected textures to be an instance of type TexturesBase; got %r"
|
||||||
raise ValueError(msg % type(textures))
|
raise ValueError(msg % type(textures))
|
||||||
|
|
||||||
self.textures = textures
|
self.textures = textures
|
||||||
|
|
||||||
# Indicates whether the meshes in the list/batch have the same number
|
# Indicates whether the meshes in the list/batch have the same number
|
||||||
@ -424,10 +425,14 @@ class Meshes:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Set the num verts/faces on the textures if present.
|
# Set the num verts/faces on the textures if present.
|
||||||
if self.textures is not None:
|
if textures is not None:
|
||||||
|
shape_ok = self.textures.check_shapes(self._N, self._V, self._F)
|
||||||
|
if not shape_ok:
|
||||||
|
msg = "Textures do not match the dimensions of Meshes."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist()
|
self.textures._num_faces_per_mesh = self._num_faces_per_mesh.tolist()
|
||||||
self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist()
|
self.textures._num_verts_per_mesh = self._num_verts_per_mesh.tolist()
|
||||||
self.textures._N = self._N
|
|
||||||
self.textures.valid = self.valid
|
self.textures.valid = self.valid
|
||||||
|
|
||||||
if verts_normals is not None:
|
if verts_normals is not None:
|
||||||
@ -1560,6 +1565,13 @@ class Meshes:
|
|||||||
|
|
||||||
def sample_textures(self, fragments):
|
def sample_textures(self, fragments):
|
||||||
if self.textures is not None:
|
if self.textures is not None:
|
||||||
|
|
||||||
|
# Check dimensions of textures match that of meshes
|
||||||
|
shape_ok = self.textures.check_shapes(self._N, self._V, self._F)
|
||||||
|
if not shape_ok:
|
||||||
|
msg = "Textures do not match the dimensions of Meshes."
|
||||||
|
raise ValueError(msg)
|
||||||
|
|
||||||
# Pass in faces packed. If the textures are defined per
|
# Pass in faces packed. If the textures are defined per
|
||||||
# vertex, the face indices are needed in order to interpolate
|
# vertex, the face indices are needed in order to interpolate
|
||||||
# the vertex attributes across the face.
|
# the vertex attributes across the face.
|
||||||
|
@ -251,7 +251,7 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase):
|
|||||||
def test_getitem(self):
|
def test_getitem(self):
|
||||||
N = 5
|
N = 5
|
||||||
V = 20
|
V = 20
|
||||||
source = {"verts_features": torch.randn(size=(N, 10, 128))}
|
source = {"verts_features": torch.randn(size=(N, V, 128))}
|
||||||
tex = TexturesVertex(verts_features=source["verts_features"])
|
tex = TexturesVertex(verts_features=source["verts_features"])
|
||||||
|
|
||||||
verts = torch.rand(size=(N, V, 3))
|
verts = torch.rand(size=(N, V, 3))
|
||||||
@ -268,6 +268,30 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase):
|
|||||||
tryindex(self, index, tex, meshes, source)
|
tryindex(self, index, tex, meshes, source)
|
||||||
tryindex(self, [2, 4], tex, meshes, source)
|
tryindex(self, [2, 4], tex, meshes, source)
|
||||||
|
|
||||||
|
def test_sample_textures_error(self):
|
||||||
|
N = 5
|
||||||
|
V = 20
|
||||||
|
verts = torch.rand(size=(N, V, 3))
|
||||||
|
faces = torch.randint(size=(N, 10, 3), high=V)
|
||||||
|
tex = TexturesVertex(verts_features=torch.randn(size=(N, 10, 128)))
|
||||||
|
|
||||||
|
# Verts features have the wrong number of verts
|
||||||
|
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
|
||||||
|
Meshes(verts=verts, faces=faces, textures=tex)
|
||||||
|
|
||||||
|
# Verts features have the wrong batch dim
|
||||||
|
tex = TexturesVertex(verts_features=torch.randn(size=(1, V, 128)))
|
||||||
|
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
|
||||||
|
Meshes(verts=verts, faces=faces, textures=tex)
|
||||||
|
|
||||||
|
meshes = Meshes(verts=verts, faces=faces)
|
||||||
|
meshes.textures = tex
|
||||||
|
|
||||||
|
# Cannot use the texture attribute set on meshes for sampling
|
||||||
|
# textures if the dimensions don't match
|
||||||
|
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
|
||||||
|
meshes.sample_textures(None)
|
||||||
|
|
||||||
|
|
||||||
class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
|
class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
|
||||||
def test_sample_texture_atlas(self):
|
def test_sample_texture_atlas(self):
|
||||||
@ -456,11 +480,12 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
|
|||||||
def test_getitem(self):
|
def test_getitem(self):
|
||||||
N = 5
|
N = 5
|
||||||
V = 20
|
V = 20
|
||||||
source = {"atlas": torch.randn(size=(N, 10, 4, 4, 3))}
|
F = 10
|
||||||
|
source = {"atlas": torch.randn(size=(N, F, 4, 4, 3))}
|
||||||
tex = TexturesAtlas(atlas=source["atlas"])
|
tex = TexturesAtlas(atlas=source["atlas"])
|
||||||
|
|
||||||
verts = torch.rand(size=(N, V, 3))
|
verts = torch.rand(size=(N, V, 3))
|
||||||
faces = torch.randint(size=(N, 10, 3), high=V)
|
faces = torch.randint(size=(N, F, 3), high=V)
|
||||||
meshes = Meshes(verts=verts, faces=faces, textures=tex)
|
meshes = Meshes(verts=verts, faces=faces, textures=tex)
|
||||||
|
|
||||||
tryindex(self, 2, tex, meshes, source)
|
tryindex(self, 2, tex, meshes, source)
|
||||||
@ -473,6 +498,32 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
|
|||||||
tryindex(self, index, tex, meshes, source)
|
tryindex(self, index, tex, meshes, source)
|
||||||
tryindex(self, [2, 4], tex, meshes, source)
|
tryindex(self, [2, 4], tex, meshes, source)
|
||||||
|
|
||||||
|
def test_sample_textures_error(self):
|
||||||
|
N = 1
|
||||||
|
V = 20
|
||||||
|
F = 10
|
||||||
|
verts = torch.rand(size=(5, V, 3))
|
||||||
|
faces = torch.randint(size=(5, F, 3), high=V)
|
||||||
|
meshes = Meshes(verts=verts, faces=faces)
|
||||||
|
|
||||||
|
# TexturesAtlas have the wrong batch dim
|
||||||
|
tex = TexturesAtlas(atlas=torch.randn(size=(1, F, 4, 4, 3)))
|
||||||
|
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
|
||||||
|
Meshes(verts=verts, faces=faces, textures=tex)
|
||||||
|
|
||||||
|
# TexturesAtlas have the wrong number of faces
|
||||||
|
tex = TexturesAtlas(atlas=torch.randn(size=(N, 15, 4, 4, 3)))
|
||||||
|
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
|
||||||
|
Meshes(verts=verts, faces=faces, textures=tex)
|
||||||
|
|
||||||
|
meshes = Meshes(verts=verts, faces=faces)
|
||||||
|
meshes.textures = tex
|
||||||
|
|
||||||
|
# Cannot use the texture attribute set on meshes for sampling
|
||||||
|
# textures if the dimensions don't match
|
||||||
|
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
|
||||||
|
meshes.sample_textures(None)
|
||||||
|
|
||||||
|
|
||||||
class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
@ -824,9 +875,10 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
|||||||
def test_getitem(self):
|
def test_getitem(self):
|
||||||
N = 5
|
N = 5
|
||||||
V = 20
|
V = 20
|
||||||
|
F = 10
|
||||||
source = {
|
source = {
|
||||||
"maps": torch.rand(size=(N, 1, 1, 3)),
|
"maps": torch.rand(size=(N, 1, 1, 3)),
|
||||||
"faces_uvs": torch.randint(size=(N, 10, 3), high=V),
|
"faces_uvs": torch.randint(size=(N, F, 3), high=V),
|
||||||
"verts_uvs": torch.randn(size=(N, V, 2)),
|
"verts_uvs": torch.randn(size=(N, V, 2)),
|
||||||
}
|
}
|
||||||
tex = TexturesUV(
|
tex = TexturesUV(
|
||||||
@ -836,7 +888,7 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
verts = torch.rand(size=(N, V, 3))
|
verts = torch.rand(size=(N, V, 3))
|
||||||
faces = torch.randint(size=(N, 10, 3), high=V)
|
faces = torch.randint(size=(N, F, 3), high=V)
|
||||||
meshes = Meshes(verts=verts, faces=faces, textures=tex)
|
meshes = Meshes(verts=verts, faces=faces, textures=tex)
|
||||||
|
|
||||||
tryindex(self, 2, tex, meshes, source)
|
tryindex(self, 2, tex, meshes, source)
|
||||||
@ -858,6 +910,46 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
|||||||
expected = torch.FloatTensor([[32, 224], [64, 96], [64, 128]])
|
expected = torch.FloatTensor([[32, 224], [64, 96], [64, 128]])
|
||||||
self.assertClose(tex.centers_for_image(0), expected)
|
self.assertClose(tex.centers_for_image(0), expected)
|
||||||
|
|
||||||
|
def test_sample_textures_error(self):
|
||||||
|
N = 1
|
||||||
|
V = 20
|
||||||
|
F = 10
|
||||||
|
maps = torch.rand(size=(N, 1, 1, 3))
|
||||||
|
verts_uvs = torch.randn(size=(N, V, 2))
|
||||||
|
tex = TexturesUV(
|
||||||
|
maps=maps,
|
||||||
|
faces_uvs=torch.randint(size=(N, 15, 3), high=V),
|
||||||
|
verts_uvs=verts_uvs,
|
||||||
|
)
|
||||||
|
verts = torch.rand(size=(5, V, 3))
|
||||||
|
faces = torch.randint(size=(5, 10, 3), high=V)
|
||||||
|
meshes = Meshes(verts=verts, faces=faces)
|
||||||
|
|
||||||
|
# Wrong number of faces
|
||||||
|
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
|
||||||
|
Meshes(verts=verts, faces=faces, textures=tex)
|
||||||
|
|
||||||
|
# Wrong batch dim for faces
|
||||||
|
tex = TexturesUV(
|
||||||
|
maps=maps,
|
||||||
|
faces_uvs=torch.randint(size=(1, F, 3), high=V),
|
||||||
|
verts_uvs=verts_uvs,
|
||||||
|
)
|
||||||
|
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
|
||||||
|
Meshes(verts=verts, faces=faces, textures=tex)
|
||||||
|
|
||||||
|
# Wrong batch dim for verts_uvs is not necessary to check as
|
||||||
|
# there is already a check inside TexturesUV for a batch dim
|
||||||
|
# mismatch with faces_uvs
|
||||||
|
|
||||||
|
meshes = Meshes(verts=verts, faces=faces)
|
||||||
|
meshes.textures = tex
|
||||||
|
|
||||||
|
# Cannot use the texture attribute set on meshes for sampling
|
||||||
|
# textures if the dimensions don't match
|
||||||
|
with self.assertRaisesRegex(ValueError, "do not match the dimensions"):
|
||||||
|
meshes.sample_textures(None)
|
||||||
|
|
||||||
|
|
||||||
class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
|
class TestRectanglePacking(TestCaseMixin, unittest.TestCase):
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user