mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
detach for meshes, pointclouds, textures
Summary: Add `detach` for Meshes, Pointclouds, Textures Reviewed By: nikhilaravi Differential Revision: D23070418 fbshipit-source-id: 68671124ce114c4495d7ef3c944c9aac3d0db2d8
This commit is contained in:
parent
5852b74d12
commit
7f2f95f225
@ -242,6 +242,13 @@ class TexturesBase(object):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def detach(self):
|
||||||
|
"""
|
||||||
|
Each texture class should implement a method
|
||||||
|
to detach all necessary internal tensors.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
"""
|
"""
|
||||||
Each texture class should implement a method
|
Each texture class should implement a method
|
||||||
@ -388,6 +395,8 @@ class TexturesAtlas(TexturesBase):
|
|||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
tex = self.__class__(atlas=self.atlas_padded().clone())
|
tex = self.__class__(atlas=self.atlas_padded().clone())
|
||||||
|
if self._atlas_list is not None:
|
||||||
|
tex._atlas_list = [atlas.clone() for atlas in self._atlas_list]
|
||||||
num_faces = (
|
num_faces = (
|
||||||
self._num_faces_per_mesh.clone()
|
self._num_faces_per_mesh.clone()
|
||||||
if torch.is_tensor(self._num_faces_per_mesh)
|
if torch.is_tensor(self._num_faces_per_mesh)
|
||||||
@ -397,6 +406,19 @@ class TexturesAtlas(TexturesBase):
|
|||||||
tex._num_faces_per_mesh = num_faces
|
tex._num_faces_per_mesh = num_faces
|
||||||
return tex
|
return tex
|
||||||
|
|
||||||
|
def detach(self):
|
||||||
|
tex = self.__class__(atlas=self.atlas_padded().detach())
|
||||||
|
if self._atlas_list is not None:
|
||||||
|
tex._atlas_list = [atlas.detach() for atlas in self._atlas_list]
|
||||||
|
num_faces = (
|
||||||
|
self._num_faces_per_mesh.detach()
|
||||||
|
if torch.is_tensor(self._num_faces_per_mesh)
|
||||||
|
else self._num_faces_per_mesh
|
||||||
|
)
|
||||||
|
tex.valid = self.valid.detach()
|
||||||
|
tex._num_faces_per_mesh = num_faces
|
||||||
|
return tex
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
props = ["atlas_list", "_num_faces_per_mesh"]
|
props = ["atlas_list", "_num_faces_per_mesh"]
|
||||||
new_props = self._getitem(index, props=props)
|
new_props = self._getitem(index, props=props)
|
||||||
@ -656,6 +678,12 @@ class TexturesUV(TexturesBase):
|
|||||||
self.faces_uvs_padded().clone(),
|
self.faces_uvs_padded().clone(),
|
||||||
self.verts_uvs_padded().clone(),
|
self.verts_uvs_padded().clone(),
|
||||||
)
|
)
|
||||||
|
if self._maps_list is not None:
|
||||||
|
tex._maps_list = [m.clone() for m in self._maps_list]
|
||||||
|
if self._verts_uvs_list is not None:
|
||||||
|
tex._verts_uvs_list = [v.clone() for v in self._verts_uvs_list]
|
||||||
|
if self._faces_uvs_list is not None:
|
||||||
|
tex._faces_uvs_list = [f.clone() for f in self._faces_uvs_list]
|
||||||
num_faces = (
|
num_faces = (
|
||||||
self._num_faces_per_mesh.clone()
|
self._num_faces_per_mesh.clone()
|
||||||
if torch.is_tensor(self._num_faces_per_mesh)
|
if torch.is_tensor(self._num_faces_per_mesh)
|
||||||
@ -665,6 +693,27 @@ class TexturesUV(TexturesBase):
|
|||||||
tex.valid = self.valid.clone()
|
tex.valid = self.valid.clone()
|
||||||
return tex
|
return tex
|
||||||
|
|
||||||
|
def detach(self):
|
||||||
|
tex = self.__class__(
|
||||||
|
self.maps_padded().detach(),
|
||||||
|
self.faces_uvs_padded().detach(),
|
||||||
|
self.verts_uvs_padded().detach(),
|
||||||
|
)
|
||||||
|
if self._maps_list is not None:
|
||||||
|
tex._maps_list = [m.detach() for m in self._maps_list]
|
||||||
|
if self._verts_uvs_list is not None:
|
||||||
|
tex._verts_uvs_list = [v.detach() for v in self._verts_uvs_list]
|
||||||
|
if self._faces_uvs_list is not None:
|
||||||
|
tex._faces_uvs_list = [f.detach() for f in self._faces_uvs_list]
|
||||||
|
num_faces = (
|
||||||
|
self._num_faces_per_mesh.detach()
|
||||||
|
if torch.is_tensor(self._num_faces_per_mesh)
|
||||||
|
else self._num_faces_per_mesh
|
||||||
|
)
|
||||||
|
tex._num_faces_per_mesh = num_faces
|
||||||
|
tex.valid = self.valid.detach()
|
||||||
|
return tex
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"]
|
props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"]
|
||||||
new_props = self._getitem(index, props)
|
new_props = self._getitem(index, props)
|
||||||
@ -892,8 +941,8 @@ class TexturesVertex(TexturesBase):
|
|||||||
has a D dimensional feature vector.
|
has a D dimensional feature vector.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
verts_features: (N, V, D) tensor giving a feature vector with
|
verts_features: list of (Vi, D) or (N, V, D) tensor giving a feature
|
||||||
artbitrary dimensions for each vertex.
|
vector with artbitrary dimensions for each vertex.
|
||||||
"""
|
"""
|
||||||
if isinstance(verts_features, (tuple, list)):
|
if isinstance(verts_features, (tuple, list)):
|
||||||
correct_shape = all(
|
correct_shape = all(
|
||||||
@ -948,15 +997,28 @@ class TexturesVertex(TexturesBase):
|
|||||||
tex = self.__class__(self.verts_features_padded().clone())
|
tex = self.__class__(self.verts_features_padded().clone())
|
||||||
if self._verts_features_list is not None:
|
if self._verts_features_list is not None:
|
||||||
tex._verts_features_list = [f.clone() for f in self._verts_features_list]
|
tex._verts_features_list = [f.clone() for f in self._verts_features_list]
|
||||||
num_faces = (
|
num_verts = (
|
||||||
self._num_verts_per_mesh.clone()
|
self._num_verts_per_mesh.clone()
|
||||||
if torch.is_tensor(self._num_verts_per_mesh)
|
if torch.is_tensor(self._num_verts_per_mesh)
|
||||||
else self._num_verts_per_mesh
|
else self._num_verts_per_mesh
|
||||||
)
|
)
|
||||||
tex._num_verts_per_mesh = num_faces
|
tex._num_verts_per_mesh = num_verts
|
||||||
tex.valid = self.valid.clone()
|
tex.valid = self.valid.clone()
|
||||||
return tex
|
return tex
|
||||||
|
|
||||||
|
def detach(self):
|
||||||
|
tex = self.__class__(self.verts_features_padded().detach())
|
||||||
|
if self._verts_features_list is not None:
|
||||||
|
tex._verts_features_list = [f.detach() for f in self._verts_features_list]
|
||||||
|
num_verts = (
|
||||||
|
self._num_verts_per_mesh.detach()
|
||||||
|
if torch.is_tensor(self._num_verts_per_mesh)
|
||||||
|
else self._num_verts_per_mesh
|
||||||
|
)
|
||||||
|
tex._num_verts_per_mesh = num_verts
|
||||||
|
tex.valid = self.valid.detach()
|
||||||
|
return tex
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
props = ["verts_features_list", "_num_verts_per_mesh"]
|
props = ["verts_features_list", "_num_verts_per_mesh"]
|
||||||
new_props = self._getitem(index, props)
|
new_props = self._getitem(index, props)
|
||||||
|
@ -1138,6 +1138,28 @@ class Meshes(object):
|
|||||||
other.textures = self.textures.clone()
|
other.textures = self.textures.clone()
|
||||||
return other
|
return other
|
||||||
|
|
||||||
|
def detach(self):
|
||||||
|
"""
|
||||||
|
Detach Meshes object. All internal tensors are detached individually.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
new Meshes object.
|
||||||
|
"""
|
||||||
|
verts_list = self.verts_list()
|
||||||
|
faces_list = self.faces_list()
|
||||||
|
new_verts_list = [v.detach() for v in verts_list]
|
||||||
|
new_faces_list = [f.detach() for f in faces_list]
|
||||||
|
other = self.__class__(verts=new_verts_list, faces=new_faces_list)
|
||||||
|
for k in self._INTERNAL_TENSORS:
|
||||||
|
v = getattr(self, k)
|
||||||
|
if torch.is_tensor(v):
|
||||||
|
setattr(other, k, v.detach())
|
||||||
|
|
||||||
|
# Textures is not a tensor but has a detach method
|
||||||
|
if self.textures is not None:
|
||||||
|
other.textures = self.textures.detach()
|
||||||
|
return other
|
||||||
|
|
||||||
def to(self, device, copy: bool = False):
|
def to(self, device, copy: bool = False):
|
||||||
"""
|
"""
|
||||||
Match functionality of torch.Tensor.to()
|
Match functionality of torch.Tensor.to()
|
||||||
|
@ -655,6 +655,42 @@ class Pointclouds(object):
|
|||||||
setattr(other, k, v.clone())
|
setattr(other, k, v.clone())
|
||||||
return other
|
return other
|
||||||
|
|
||||||
|
def detach(self):
|
||||||
|
"""
|
||||||
|
Detach Pointclouds object. All internal tensors are detached
|
||||||
|
individually.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
new Pointclouds object.
|
||||||
|
"""
|
||||||
|
# instantiate new pointcloud with the representation which is not None
|
||||||
|
# (either list or tensor) to save compute.
|
||||||
|
new_points, new_normals, new_features = None, None, None
|
||||||
|
if self._points_list is not None:
|
||||||
|
new_points = [v.detach() for v in self.points_list()]
|
||||||
|
normals_list = self.normals_list()
|
||||||
|
features_list = self.features_list()
|
||||||
|
if normals_list is not None:
|
||||||
|
new_normals = [n.detach() for n in normals_list]
|
||||||
|
if features_list is not None:
|
||||||
|
new_features = [f.detach() for f in features_list]
|
||||||
|
elif self._points_padded is not None:
|
||||||
|
new_points = self.points_padded().detach()
|
||||||
|
normals_padded = self.normals_padded()
|
||||||
|
features_padded = self.features_padded()
|
||||||
|
if normals_padded is not None:
|
||||||
|
new_normals = self.normals_padded().detach()
|
||||||
|
if features_padded is not None:
|
||||||
|
new_features = self.features_padded().detach()
|
||||||
|
other = self.__class__(
|
||||||
|
points=new_points, normals=new_normals, features=new_features
|
||||||
|
)
|
||||||
|
for k in self._INTERNAL_TENSORS:
|
||||||
|
v = getattr(self, k)
|
||||||
|
if torch.is_tensor(v):
|
||||||
|
setattr(other, k, v.detach())
|
||||||
|
return other
|
||||||
|
|
||||||
def to(self, device, copy: bool = False):
|
def to(self, device, copy: bool = False):
|
||||||
"""
|
"""
|
||||||
Match functionality of torch.Tensor.to()
|
Match functionality of torch.Tensor.to()
|
||||||
|
@ -20,6 +20,7 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
max_f: int = 300,
|
max_f: int = 300,
|
||||||
lists_to_tensors: bool = False,
|
lists_to_tensors: bool = False,
|
||||||
device: str = "cpu",
|
device: str = "cpu",
|
||||||
|
requires_grad: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Function to generate a Meshes object of N meshes with
|
Function to generate a Meshes object of N meshes with
|
||||||
@ -57,7 +58,12 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
# Generate the actual vertices and faces.
|
# Generate the actual vertices and faces.
|
||||||
for i in range(num_meshes):
|
for i in range(num_meshes):
|
||||||
verts = torch.rand((v[i], 3), dtype=torch.float32, device=device)
|
verts = torch.rand(
|
||||||
|
(v[i], 3),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
requires_grad=requires_grad,
|
||||||
|
)
|
||||||
faces = torch.randint(
|
faces = torch.randint(
|
||||||
v[i], size=(f[i], 3), dtype=torch.int64, device=device
|
v[i], size=(f[i], 3), dtype=torch.int64, device=device
|
||||||
)
|
)
|
||||||
@ -353,6 +359,26 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertSeparate(new_mesh.faces_padded(), mesh.faces_padded())
|
self.assertSeparate(new_mesh.faces_padded(), mesh.faces_padded())
|
||||||
self.assertSeparate(new_mesh.edges_packed(), mesh.edges_packed())
|
self.assertSeparate(new_mesh.edges_packed(), mesh.edges_packed())
|
||||||
|
|
||||||
|
def test_detach(self):
|
||||||
|
N = 5
|
||||||
|
mesh = TestMeshes.init_mesh(N, 10, 100, requires_grad=True)
|
||||||
|
for force in [0, 1]:
|
||||||
|
if force:
|
||||||
|
# force mesh to have computed attributes
|
||||||
|
mesh.verts_packed()
|
||||||
|
mesh.edges_packed()
|
||||||
|
mesh.verts_padded()
|
||||||
|
|
||||||
|
new_mesh = mesh.detach()
|
||||||
|
|
||||||
|
self.assertFalse(new_mesh.verts_packed().requires_grad)
|
||||||
|
self.assertClose(new_mesh.verts_packed(), mesh.verts_packed())
|
||||||
|
self.assertTrue(new_mesh.verts_padded().requires_grad == False)
|
||||||
|
self.assertClose(new_mesh.verts_padded(), mesh.verts_padded())
|
||||||
|
for v, newv in zip(mesh.verts_list(), new_mesh.verts_list()):
|
||||||
|
self.assertTrue(newv.requires_grad == False)
|
||||||
|
self.assertClose(newv, v)
|
||||||
|
|
||||||
def test_laplacian_packed(self):
|
def test_laplacian_packed(self):
|
||||||
def naive_laplacian_packed(meshes):
|
def naive_laplacian_packed(meshes):
|
||||||
verts_packed = meshes.verts_packed()
|
verts_packed = meshes.verts_packed()
|
||||||
|
@ -24,6 +24,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
|||||||
with_normals: bool = True,
|
with_normals: bool = True,
|
||||||
with_features: bool = True,
|
with_features: bool = True,
|
||||||
min_points: int = 0,
|
min_points: int = 0,
|
||||||
|
requires_grad: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Function to generate a Pointclouds object of N meshes with
|
Function to generate a Pointclouds object of N meshes with
|
||||||
@ -49,16 +50,31 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
|||||||
p.fill_(p[0])
|
p.fill_(p[0])
|
||||||
|
|
||||||
points_list = [
|
points_list = [
|
||||||
torch.rand((i, 3), device=device, dtype=torch.float32) for i in p
|
torch.rand(
|
||||||
|
(i, 3), device=device, dtype=torch.float32, requires_grad=requires_grad
|
||||||
|
)
|
||||||
|
for i in p
|
||||||
]
|
]
|
||||||
normals_list, features_list = None, None
|
normals_list, features_list = None, None
|
||||||
if with_normals:
|
if with_normals:
|
||||||
normals_list = [
|
normals_list = [
|
||||||
torch.rand((i, 3), device=device, dtype=torch.float32) for i in p
|
torch.rand(
|
||||||
|
(i, 3),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
requires_grad=requires_grad,
|
||||||
|
)
|
||||||
|
for i in p
|
||||||
]
|
]
|
||||||
if with_features:
|
if with_features:
|
||||||
features_list = [
|
features_list = [
|
||||||
torch.rand((i, channels), device=device, dtype=torch.float32) for i in p
|
torch.rand(
|
||||||
|
(i, channels),
|
||||||
|
device=device,
|
||||||
|
dtype=torch.float32,
|
||||||
|
requires_grad=requires_grad,
|
||||||
|
)
|
||||||
|
for i in p
|
||||||
]
|
]
|
||||||
|
|
||||||
if lists_to_tensors:
|
if lists_to_tensors:
|
||||||
@ -382,6 +398,39 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.assertCloudsEqual(clouds, new_clouds)
|
self.assertCloudsEqual(clouds, new_clouds)
|
||||||
|
|
||||||
|
def test_detach(self):
|
||||||
|
N = 5
|
||||||
|
for lists_to_tensors in (True, False):
|
||||||
|
clouds = self.init_cloud(
|
||||||
|
N, 100, 5, lists_to_tensors=lists_to_tensors, requires_grad=True
|
||||||
|
)
|
||||||
|
for force in (False, True):
|
||||||
|
if force:
|
||||||
|
clouds.points_packed()
|
||||||
|
|
||||||
|
new_clouds = clouds.detach()
|
||||||
|
|
||||||
|
for cloud in new_clouds.points_list():
|
||||||
|
self.assertTrue(cloud.requires_grad == False)
|
||||||
|
for normal in new_clouds.normals_list():
|
||||||
|
self.assertTrue(normal.requires_grad == False)
|
||||||
|
for feats in new_clouds.features_list():
|
||||||
|
self.assertTrue(feats.requires_grad == False)
|
||||||
|
|
||||||
|
for attrib in [
|
||||||
|
"points_packed",
|
||||||
|
"normals_packed",
|
||||||
|
"features_packed",
|
||||||
|
"points_padded",
|
||||||
|
"normals_padded",
|
||||||
|
"features_padded",
|
||||||
|
]:
|
||||||
|
self.assertTrue(
|
||||||
|
getattr(new_clouds, attrib)().requires_grad == False
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertCloudsEqual(clouds, new_clouds)
|
||||||
|
|
||||||
def assertCloudsEqual(self, cloud1, cloud2):
|
def assertCloudsEqual(self, cloud1, cloud2):
|
||||||
N = len(cloud1)
|
N = len(cloud1)
|
||||||
self.assertEqual(N, len(cloud2))
|
self.assertEqual(N, len(cloud2))
|
||||||
|
@ -113,11 +113,37 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_clone(self):
|
def test_clone(self):
|
||||||
tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128)))
|
tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128)))
|
||||||
|
tex.verts_features_list()
|
||||||
tex_cloned = tex.clone()
|
tex_cloned = tex.clone()
|
||||||
self.assertSeparate(
|
self.assertSeparate(
|
||||||
tex._verts_features_padded, tex_cloned._verts_features_padded
|
tex._verts_features_padded, tex_cloned._verts_features_padded
|
||||||
)
|
)
|
||||||
|
self.assertClose(tex._verts_features_padded, tex_cloned._verts_features_padded)
|
||||||
self.assertSeparate(tex.valid, tex_cloned.valid)
|
self.assertSeparate(tex.valid, tex_cloned.valid)
|
||||||
|
self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
|
||||||
|
for i in range(tex._N):
|
||||||
|
self.assertSeparate(
|
||||||
|
tex._verts_features_list[i], tex_cloned._verts_features_list[i]
|
||||||
|
)
|
||||||
|
self.assertClose(
|
||||||
|
tex._verts_features_list[i], tex_cloned._verts_features_list[i]
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_detach(self):
|
||||||
|
tex = TexturesVertex(
|
||||||
|
verts_features=torch.rand(size=(10, 100, 128), requires_grad=True)
|
||||||
|
)
|
||||||
|
tex.verts_features_list()
|
||||||
|
tex_detached = tex.detach()
|
||||||
|
self.assertFalse(tex_detached._verts_features_padded.requires_grad)
|
||||||
|
self.assertClose(
|
||||||
|
tex_detached._verts_features_padded, tex._verts_features_padded
|
||||||
|
)
|
||||||
|
for i in range(tex._N):
|
||||||
|
self.assertClose(
|
||||||
|
tex._verts_features_list[i], tex_detached._verts_features_list[i]
|
||||||
|
)
|
||||||
|
self.assertFalse(tex_detached._verts_features_list[i].requires_grad)
|
||||||
|
|
||||||
def test_extend(self):
|
def test_extend(self):
|
||||||
B = 10
|
B = 10
|
||||||
@ -278,9 +304,25 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_clone(self):
|
def test_clone(self):
|
||||||
tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3)))
|
tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3)))
|
||||||
|
tex.atlas_list()
|
||||||
tex_cloned = tex.clone()
|
tex_cloned = tex.clone()
|
||||||
self.assertSeparate(tex._atlas_padded, tex_cloned._atlas_padded)
|
self.assertSeparate(tex._atlas_padded, tex_cloned._atlas_padded)
|
||||||
|
self.assertClose(tex._atlas_padded, tex_cloned._atlas_padded)
|
||||||
self.assertSeparate(tex.valid, tex_cloned.valid)
|
self.assertSeparate(tex.valid, tex_cloned.valid)
|
||||||
|
self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
|
||||||
|
for i in range(tex._N):
|
||||||
|
self.assertSeparate(tex._atlas_list[i], tex_cloned._atlas_list[i])
|
||||||
|
self.assertClose(tex._atlas_list[i], tex_cloned._atlas_list[i])
|
||||||
|
|
||||||
|
def test_detach(self):
|
||||||
|
tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3), requires_grad=True))
|
||||||
|
tex.atlas_list()
|
||||||
|
tex_detached = tex.detach()
|
||||||
|
self.assertFalse(tex_detached._atlas_padded.requires_grad)
|
||||||
|
self.assertClose(tex_detached._atlas_padded, tex._atlas_padded)
|
||||||
|
for i in range(tex._N):
|
||||||
|
self.assertFalse(tex_detached._atlas_list[i].requires_grad)
|
||||||
|
self.assertClose(tex._atlas_list[i], tex_detached._atlas_list[i])
|
||||||
|
|
||||||
def test_extend(self):
|
def test_extend(self):
|
||||||
B = 10
|
B = 10
|
||||||
@ -478,11 +520,49 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
|||||||
faces_uvs=torch.rand(size=(5, 10, 3)),
|
faces_uvs=torch.rand(size=(5, 10, 3)),
|
||||||
verts_uvs=torch.rand(size=(5, 15, 2)),
|
verts_uvs=torch.rand(size=(5, 15, 2)),
|
||||||
)
|
)
|
||||||
|
tex.faces_uvs_list()
|
||||||
|
tex.verts_uvs_list()
|
||||||
tex_cloned = tex.clone()
|
tex_cloned = tex.clone()
|
||||||
self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
|
self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
|
||||||
|
self.assertClose(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
|
||||||
self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
|
self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
|
||||||
|
self.assertClose(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
|
||||||
self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)
|
self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)
|
||||||
|
self.assertClose(tex._maps_padded, tex_cloned._maps_padded)
|
||||||
self.assertSeparate(tex.valid, tex_cloned.valid)
|
self.assertSeparate(tex.valid, tex_cloned.valid)
|
||||||
|
self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
|
||||||
|
for i in range(tex._N):
|
||||||
|
self.assertSeparate(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
|
||||||
|
self.assertClose(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
|
||||||
|
self.assertSeparate(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
|
||||||
|
self.assertClose(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
|
||||||
|
# tex._maps_list is not use anywhere so it's not stored. We call it explicitly
|
||||||
|
self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i])
|
||||||
|
self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i])
|
||||||
|
|
||||||
|
def test_detach(self):
|
||||||
|
tex = TexturesUV(
|
||||||
|
maps=torch.ones((5, 16, 16, 3), requires_grad=True),
|
||||||
|
faces_uvs=torch.rand(size=(5, 10, 3)),
|
||||||
|
verts_uvs=torch.rand(size=(5, 15, 2)),
|
||||||
|
)
|
||||||
|
tex.faces_uvs_list()
|
||||||
|
tex.verts_uvs_list()
|
||||||
|
tex_detached = tex.detach()
|
||||||
|
self.assertFalse(tex_detached._maps_padded.requires_grad)
|
||||||
|
self.assertClose(tex._maps_padded, tex_detached._maps_padded)
|
||||||
|
self.assertFalse(tex_detached._verts_uvs_padded.requires_grad)
|
||||||
|
self.assertClose(tex._verts_uvs_padded, tex_detached._verts_uvs_padded)
|
||||||
|
self.assertFalse(tex_detached._faces_uvs_padded.requires_grad)
|
||||||
|
self.assertClose(tex._faces_uvs_padded, tex_detached._faces_uvs_padded)
|
||||||
|
for i in range(tex._N):
|
||||||
|
self.assertFalse(tex_detached._verts_uvs_list[i].requires_grad)
|
||||||
|
self.assertClose(tex._verts_uvs_list[i], tex_detached._verts_uvs_list[i])
|
||||||
|
self.assertFalse(tex_detached._faces_uvs_list[i].requires_grad)
|
||||||
|
self.assertClose(tex._faces_uvs_list[i], tex_detached._faces_uvs_list[i])
|
||||||
|
# tex._maps_list is not use anywhere so it's not stored. We call it explicitly
|
||||||
|
self.assertFalse(tex_detached.maps_list()[i].requires_grad)
|
||||||
|
self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i])
|
||||||
|
|
||||||
def test_extend(self):
|
def test_extend(self):
|
||||||
B = 5
|
B = 5
|
||||||
|
Loading…
x
Reference in New Issue
Block a user