mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Fix batching bug from TexturesUV packed ambiguity, other textures tidyup
Summary: faces_uvs_packed and verts_uvs_packed were only used in one place and the definition of the former was ambiguous. This meant that the wrong coordinates could be used for meshes other than the first in the batch. I have therefore removed both functions and build their common result inline. Added a test that a simple batch of two meshes is rendered consistently with the rendering of each alone. This test would have failed before. I hope this fixes https://github.com/facebookresearch/pytorch3d/issues/283. Some other small improvements to the textures code. Reviewed By: nikhilaravi Differential Revision: D23161936 fbshipit-source-id: f99b560a46f6b30262e07028b049812bc04350a7
This commit is contained in:
parent
9aaba0483c
commit
9a50cf800e
@ -3,9 +3,9 @@
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x "must be a CUDA tensor.")
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.")
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x "must be contiguous.")
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
|
||||
#define CHECK_CONTIGUOUS_CUDA(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
@ -7,10 +7,10 @@ from typing import Dict, List, Optional
|
||||
import torch
|
||||
from pytorch3d.io import load_objs_as_meshes
|
||||
from pytorch3d.renderer import (
|
||||
FoVPerspectiveCameras,
|
||||
HardPhongShader,
|
||||
MeshRasterizer,
|
||||
MeshRenderer,
|
||||
FoVPerspectiveCameras,
|
||||
PointLights,
|
||||
RasterizationSettings,
|
||||
TexturesVertex,
|
||||
|
@ -45,7 +45,10 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
|
||||
# Mask for the background.
|
||||
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)
|
||||
|
||||
background_color = colors.new_tensor(blend_params.background_color) # (3)
|
||||
if torch.is_tensor(blend_params.background_color):
|
||||
background_color = blend_params.background_color
|
||||
else:
|
||||
background_color = colors.new_tensor(blend_params.background_color) # (3)
|
||||
|
||||
# Find out how much background_color needs to be expanded to be used for masked_scatter.
|
||||
num_background_pixels = is_background.sum()
|
||||
|
@ -137,7 +137,7 @@ def _pad_texture_maps(
|
||||
# This is also useful to have so that inside `Meshes`
|
||||
# we can allow the input textures to be any texture
|
||||
# type which is an instance of the base class.
|
||||
class TexturesBase(object):
|
||||
class TexturesBase:
|
||||
def __init__(self):
|
||||
self._N = 0
|
||||
self.valid = None
|
||||
@ -262,9 +262,6 @@ class TexturesBase(object):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def __repr__(self):
|
||||
return "TexturesBase"
|
||||
|
||||
|
||||
def Textures(
|
||||
maps: Union[List, torch.Tensor, None] = None,
|
||||
@ -385,14 +382,6 @@ class TexturesAtlas(TexturesBase):
|
||||
# refer to the __init__ of Meshes.
|
||||
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
|
||||
|
||||
# This is a hack to allow the child classes to also have the same representation
|
||||
# as the parent. In meshes.py we check that the input textures have the correct
|
||||
# type. However due to circular imports issues, we can't import the texture
|
||||
# classes into any files in pytorch3d.structures. Instead we check
|
||||
# for repr(textures) == "TexturesBase".
|
||||
def __repr__(self):
|
||||
return super().__repr__()
|
||||
|
||||
def clone(self):
|
||||
tex = self.__class__(atlas=self.atlas_padded().clone())
|
||||
if self._atlas_list is not None:
|
||||
@ -556,10 +545,7 @@ class TexturesUV(TexturesBase):
|
||||
[(H, W, 3)] or a padded tensor of shape (N, H, W, 3)
|
||||
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each face
|
||||
verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
|
||||
|
||||
Note: only the padded and list representation of the textures are stored
|
||||
and the packed representations is computed on the fly and
|
||||
not cached.
|
||||
(a FloatTensor with values between 0 and 1)
|
||||
"""
|
||||
super().__init__()
|
||||
if isinstance(faces_uvs, (list, tuple)):
|
||||
@ -611,9 +597,6 @@ class TexturesUV(TexturesBase):
|
||||
"verts_uvs and faces_uvs must have the same batch dimension"
|
||||
)
|
||||
if not all(v.device == self.device for v in verts_uvs):
|
||||
import pdb
|
||||
|
||||
pdb.set_trace()
|
||||
raise ValueError("verts_uvs and faces_uvs must be on the same device")
|
||||
|
||||
# These values may be overridden when textures is
|
||||
@ -669,9 +652,6 @@ class TexturesUV(TexturesBase):
|
||||
|
||||
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
|
||||
|
||||
def __repr__(self):
|
||||
return super().__repr__()
|
||||
|
||||
def clone(self):
|
||||
tex = self.__class__(
|
||||
self.maps_padded().clone(),
|
||||
@ -759,12 +739,6 @@ class TexturesUV(TexturesBase):
|
||||
)
|
||||
return self._faces_uvs_list
|
||||
|
||||
def faces_uvs_packed(self) -> torch.Tensor:
|
||||
if self.isempty():
|
||||
return torch.zeros((self._N, 3), dtype=torch.float32, device=self.device)
|
||||
faces_uvs_list = self.faces_uvs_list()
|
||||
return list_to_packed(faces_uvs_list)[0]
|
||||
|
||||
def verts_uvs_padded(self) -> torch.Tensor:
|
||||
if self._verts_uvs_padded is None:
|
||||
if self.isempty():
|
||||
@ -789,12 +763,6 @@ class TexturesUV(TexturesBase):
|
||||
)
|
||||
return self._verts_uvs_list
|
||||
|
||||
def verts_uvs_packed(self) -> torch.Tensor:
|
||||
if self.isempty():
|
||||
return torch.zeros((self._N, 2), dtype=torch.float32, device=self.device)
|
||||
verts_uvs_list = self.verts_uvs_list()
|
||||
return list_to_packed(verts_uvs_list)[0]
|
||||
|
||||
# Currently only the padded maps are used.
|
||||
def maps_padded(self) -> torch.Tensor:
|
||||
return self._maps_padded
|
||||
@ -850,9 +818,15 @@ class TexturesUV(TexturesBase):
|
||||
texels: tensor of shape (N, H, W, K, C) giving the interpolated
|
||||
texture for each pixel in the rasterized image.
|
||||
"""
|
||||
verts_uvs = self.verts_uvs_packed()
|
||||
faces_uvs = self.faces_uvs_packed()
|
||||
faces_verts_uvs = verts_uvs[faces_uvs]
|
||||
if self.isempty():
|
||||
faces_verts_uvs = torch.zeros(
|
||||
(self._N, 3, 2), dtype=torch.float32, device=self.device
|
||||
)
|
||||
else:
|
||||
packing_list = [
|
||||
i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list())
|
||||
]
|
||||
faces_verts_uvs = torch.cat(packing_list)
|
||||
texture_maps = self.maps_padded()
|
||||
|
||||
# pixel_uvs: (N, H, W, K, 2)
|
||||
@ -890,6 +864,7 @@ class TexturesUV(TexturesBase):
|
||||
if texture_maps.device != pixel_uvs.device:
|
||||
texture_maps = texture_maps.to(pixel_uvs.device)
|
||||
texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
|
||||
# texels now has shape (NK, C, H_out, W_out)
|
||||
texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
|
||||
return texels
|
||||
|
||||
@ -990,9 +965,6 @@ class TexturesVertex(TexturesBase):
|
||||
# refer to the __init__ of Meshes.
|
||||
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
|
||||
|
||||
def __repr__(self):
|
||||
return super().__repr__()
|
||||
|
||||
def clone(self):
|
||||
tex = self.__class__(self.verts_features_padded().clone())
|
||||
if self._verts_features_list is not None:
|
||||
@ -1048,7 +1020,7 @@ class TexturesVertex(TexturesBase):
|
||||
if self._verts_features_list is None:
|
||||
if self.isempty():
|
||||
self._verts_features_list = [
|
||||
torch.empty((0, 3, 0), dtype=torch.float32, device=self.device)
|
||||
torch.empty((0, 3), dtype=torch.float32, device=self.device)
|
||||
] * self._N
|
||||
else:
|
||||
self._verts_features_list = padded_to_list(
|
||||
|
@ -233,9 +233,9 @@ class Meshes(object):
|
||||
Refer to comments above for descriptions of List and Padded representations.
|
||||
"""
|
||||
self.device = None
|
||||
if textures is not None and not repr(textures) == "TexturesBase":
|
||||
if textures is not None and not hasattr(textures, "sample_textures"):
|
||||
msg = "Expected textures to be an instance of type TexturesBase; got %r"
|
||||
raise ValueError(msg % repr(textures))
|
||||
raise ValueError(msg % type(textures))
|
||||
self.textures = textures
|
||||
|
||||
# Indicates whether the meshes in the list/batch have the same number
|
||||
|
@ -33,8 +33,9 @@ from pytorch3d.renderer.mesh.shader import (
|
||||
SoftSilhouetteShader,
|
||||
TexturedSoftPhongShader,
|
||||
)
|
||||
from pytorch3d.structures.meshes import Meshes, join_mesh
|
||||
from pytorch3d.structures.meshes import Meshes, join_mesh, join_meshes_as_batch
|
||||
from pytorch3d.utils.ico_sphere import ico_sphere
|
||||
from pytorch3d.utils.torus import torus
|
||||
|
||||
|
||||
# If DEBUG=True, save out images generated in the tests for debugging.
|
||||
@ -490,6 +491,86 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
|
||||
def test_batch_uvs(self):
|
||||
"""Test that two random tori with TexturesUV render the same as each individually."""
|
||||
torch.manual_seed(1)
|
||||
device = torch.device("cuda:0")
|
||||
plain_torus = torus(r=1, R=4, sides=10, rings=10, device=device)
|
||||
[verts] = plain_torus.verts_list()
|
||||
[faces] = plain_torus.faces_list()
|
||||
nocolor = torch.zeros((100, 100), device=device)
|
||||
color_gradient = torch.linspace(0, 1, steps=100, device=device)
|
||||
color_gradient1 = color_gradient[None].expand_as(nocolor)
|
||||
color_gradient2 = color_gradient[:, None].expand_as(nocolor)
|
||||
colors1 = torch.stack([nocolor, color_gradient1, color_gradient2], dim=2)
|
||||
colors2 = torch.stack([color_gradient1, color_gradient2, nocolor], dim=2)
|
||||
verts_uvs1 = torch.rand(size=(verts.shape[0], 2), device=device)
|
||||
verts_uvs2 = torch.rand(size=(verts.shape[0], 2), device=device)
|
||||
|
||||
textures1 = TexturesUV(
|
||||
maps=[colors1], faces_uvs=[faces], verts_uvs=[verts_uvs1]
|
||||
)
|
||||
textures2 = TexturesUV(
|
||||
maps=[colors2], faces_uvs=[faces], verts_uvs=[verts_uvs2]
|
||||
)
|
||||
mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
|
||||
mesh2 = Meshes(verts=[verts], faces=[faces], textures=textures2)
|
||||
mesh_both = join_meshes_as_batch([mesh1, mesh2])
|
||||
|
||||
R, T = look_at_view_transform(10, 10, 0)
|
||||
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
|
||||
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=128, blur_radius=0.0, faces_per_pixel=1
|
||||
)
|
||||
|
||||
# Init shader settings
|
||||
lights = PointLights(device=device)
|
||||
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
|
||||
|
||||
blend_params = BlendParams(
|
||||
sigma=1e-1,
|
||||
gamma=1e-4,
|
||||
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
|
||||
)
|
||||
# Init renderer
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
|
||||
shader=HardPhongShader(
|
||||
device=device, lights=lights, cameras=cameras, blend_params=blend_params
|
||||
),
|
||||
)
|
||||
|
||||
outputs = []
|
||||
for meshes in [mesh_both, mesh1, mesh2]:
|
||||
outputs.append(renderer(meshes))
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray(
|
||||
(outputs[0][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
|
||||
).save(DATA_DIR / "test_batch_uvs0.png")
|
||||
Image.fromarray(
|
||||
(outputs[1][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
|
||||
).save(DATA_DIR / "test_batch_uvs1.png")
|
||||
Image.fromarray(
|
||||
(outputs[0][1, ..., :3].cpu().numpy() * 255).astype(np.uint8)
|
||||
).save(DATA_DIR / "test_batch_uvs2.png")
|
||||
Image.fromarray(
|
||||
(outputs[2][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
|
||||
).save(DATA_DIR / "test_batch_uvs3.png")
|
||||
|
||||
diff = torch.abs(outputs[0][0, ..., :3] - outputs[1][0, ..., :3])
|
||||
Image.fromarray(((diff > 1e-5).cpu().numpy().astype(np.uint8) * 255)).save(
|
||||
DATA_DIR / "test_batch_uvs01.png"
|
||||
)
|
||||
diff = torch.abs(outputs[0][1, ..., :3] - outputs[2][0, ..., :3])
|
||||
Image.fromarray(((diff > 1e-5).cpu().numpy().astype(np.uint8) * 255)).save(
|
||||
DATA_DIR / "test_batch_uvs23.png"
|
||||
)
|
||||
|
||||
self.assertClose(outputs[0][0, ..., :3], outputs[1][0, ..., :3], atol=1e-5)
|
||||
self.assertClose(outputs[0][1, ..., :3], outputs[2][0, ..., :3], atol=1e-5)
|
||||
|
||||
def test_joined_spheres(self):
|
||||
"""
|
||||
Test a list of Meshes can be joined as a single mesh and
|
||||
|
@ -29,8 +29,8 @@ def tryindex(self, index, tex, meshes, source):
|
||||
basic = basic[None]
|
||||
|
||||
if len(basic) == 0:
|
||||
self.assertEquals(len(from_texture), 0)
|
||||
self.assertEquals(len(from_meshes), 0)
|
||||
self.assertEqual(len(from_texture), 0)
|
||||
self.assertEqual(len(from_meshes), 0)
|
||||
else:
|
||||
self.assertClose(basic, from_texture)
|
||||
self.assertClose(basic, from_meshes)
|
||||
@ -608,12 +608,8 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
||||
[
|
||||
tex_init.faces_uvs_padded(),
|
||||
new_tex.faces_uvs_padded(),
|
||||
tex_init.faces_uvs_packed(),
|
||||
new_tex.faces_uvs_packed(),
|
||||
tex_init.verts_uvs_padded(),
|
||||
new_tex.verts_uvs_padded(),
|
||||
tex_init.verts_uvs_packed(),
|
||||
new_tex.verts_uvs_packed(),
|
||||
tex_init.maps_padded(),
|
||||
new_tex.maps_padded(),
|
||||
]
|
||||
@ -646,11 +642,9 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
||||
tex1 = tex.clone()
|
||||
tex1._num_faces_per_mesh = num_faces_per_mesh
|
||||
tex1._num_verts_per_mesh = num_verts_per_mesh
|
||||
verts_packed = tex1.verts_uvs_packed()
|
||||
verts_list = tex1.verts_uvs_list()
|
||||
verts_padded = tex1.verts_uvs_padded()
|
||||
|
||||
faces_packed = tex1.faces_uvs_packed()
|
||||
faces_list = tex1.faces_uvs_list()
|
||||
faces_padded = tex1.faces_uvs_padded()
|
||||
|
||||
@ -660,9 +654,7 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
||||
for f1, f2 in zip(verts_list, verts_uvs_list):
|
||||
self.assertTrue((f1 == f2).all().item())
|
||||
|
||||
self.assertTrue(faces_packed.shape == (3 + 2, 3))
|
||||
self.assertTrue(faces_padded.shape == (2, 3, 3))
|
||||
self.assertTrue(verts_packed.shape == (9 + 6, 2))
|
||||
self.assertTrue(verts_padded.shape == (2, 9, 2))
|
||||
|
||||
# Case where num_faces_per_mesh is not set and faces_verts_uvs
|
||||
@ -672,16 +664,9 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
|
||||
verts_uvs=verts_padded,
|
||||
faces_uvs=faces_padded,
|
||||
)
|
||||
faces_packed = tex2.faces_uvs_packed()
|
||||
faces_list = tex2.faces_uvs_list()
|
||||
verts_packed = tex2.verts_uvs_packed()
|
||||
verts_list = tex2.verts_uvs_list()
|
||||
|
||||
# Packed is just flattened padded as num_faces_per_mesh
|
||||
# has not been provided.
|
||||
self.assertTrue(faces_packed.shape == (3 * 2, 3))
|
||||
self.assertTrue(verts_packed.shape == (9 * 2, 2))
|
||||
|
||||
for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)):
|
||||
n = num_faces_per_mesh[i]
|
||||
self.assertTrue((f1[:n] == f2).all().item())
|
||||
|
Loading…
x
Reference in New Issue
Block a user