mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Texture Atlas sampling bug fix
Summary: Fixes the index out of bound errors for texture sampling from a texture atlas: when barycentric coordinates are 1.0, the integer index into the (R, R) per face texture map is R (max can only be R-1). Reviewed By: gkioxari Differential Revision: D25543803 fbshipit-source-id: 82d0935b981352b49c1d95d5a17f9cc88bad0a82
This commit is contained in:
parent
3d769a66cb
commit
01759d8ffb
@ -84,7 +84,8 @@ For mesh texturing we offer several options (in `pytorch3d/renderer/mesh/texturi
|
||||
|
||||
1. **Vertex Textures**: D dimensional textures for each vertex (for example an RGB color) which can be interpolated across the face. This can be represented as an `(N, V, D)` tensor. This is a fairly simple representation though and cannot model complex textures if the mesh faces are large.
|
||||
2. **UV Textures**: vertex UV coordinates and **one** texture map for the whole mesh. For a point on a face with given barycentric coordinates, the face color can be computed by interpolating the vertex uv coordinates and then sampling from the texture map. This representation requires two tensors (UVs: `(N, V, 2), Texture map: `(N, H, W, 3)`), and is limited to only support one texture map per mesh.
|
||||
3. **Face Textures**: In more complex cases such as ShapeNet meshes, there are multiple texture maps per mesh and some faces have texture while other do not. For these cases, a more flexible representation is a texture atlas, where each face is represented as an `(RxR)` texture map where R is the texture resolution. For a given point on the face, the texture value can be sampled from the per face texture map using the barycentric coordinates of the point. This representation requires one tensor of shape `(N, F, R, R, 3)`. This texturing method is inspired by the SoftRasterizer implementation. For more details refer to the [`make_material_atlas`](https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/io/mtl_io.py#L123) and [`sample_textures`](https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/renderer/mesh/textures.py#L452) functions.
|
||||
3. **Face Textures**: In more complex cases such as ShapeNet meshes, there are multiple texture maps per mesh and some faces have texture while other do not. For these cases, a more flexible representation is a texture atlas, where each face is represented as an `(RxR)` texture map where R is the texture resolution. For a given point on the face, the texture value can be sampled from the per face texture map using the barycentric coordinates of the point. This representation requires one tensor of shape `(N, F, R, R, 3)`. This texturing method is inspired by the SoftRasterizer implementation. For more details refer to the [`make_material_atlas`](https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/io/mtl_io.py#L123) and [`sample_textures`](https://github.com/facebookresearch/pytorch3d/blob/master/pytorch3d/renderer/mesh/textures.py#L452) functions. **NOTE:**: The `TextureAtlas` texture sampling is only differentiable with respect to the texture atlas but not differentiable with respect to the barycentric coordinates.
|
||||
|
||||
|
||||
<img src="assets/texturing.jpg" width="1000">
|
||||
|
||||
|
@ -479,6 +479,18 @@ class TexturesAtlas(TexturesBase):
|
||||
|
||||
def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
|
||||
"""
|
||||
This is similar to a nearest neighbor sampling and involves a
|
||||
discretization step. The barycentric coordinates from
|
||||
rasterization are used to find the nearest grid cell in the texture
|
||||
atlas and the RGB is returned as the color.
|
||||
This means that this step is differentiable with respect to the RGB
|
||||
values of the texture atlas but not differentiable with respect to the
|
||||
barycentric coordinates.
|
||||
|
||||
TODO: Add a different sampling mode which interpolates the barycentric
|
||||
coordinates to sample the texture and will be differentiable w.r.t
|
||||
the barycentric coordinates.
|
||||
|
||||
Args:
|
||||
fragments:
|
||||
The outputs of rasterization. From this we use
|
||||
@ -504,7 +516,10 @@ class TexturesAtlas(TexturesBase):
|
||||
# pyre-fixme[16]: `bool` has no attribute `__getitem__`.
|
||||
mask = (pix_to_face < 0)[..., None]
|
||||
bary_w01 = torch.where(mask, torch.zeros_like(bary_w01), bary_w01)
|
||||
w_xy = (bary_w01 * R).to(torch.int64) # (N, H, W, K, 2)
|
||||
# If barycentric coordinates are > 1.0 (in the case of
|
||||
# blur_radius > 0.0), wxy might be > R. We need to clamp this
|
||||
# index to R-1 to index into the texture atlas.
|
||||
w_xy = (bary_w01 * R).to(torch.int64).clamp(max=R - 1) # (N, H, W, K, 2)
|
||||
|
||||
below_diag = (
|
||||
bary_w01.sum(dim=-1) * R - w_xy.float().sum(dim=-1)
|
||||
|
@ -956,6 +956,7 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
def test_texture_map_atlas(self):
|
||||
"""
|
||||
Test a mesh with a texture map as a per face atlas is loaded and rendered correctly.
|
||||
Also check that the backward pass for texture atlas rendering is differentiable.
|
||||
"""
|
||||
device = torch.device("cuda:0")
|
||||
obj_dir = Path(__file__).resolve().parent.parent / "docs/tutorials/data"
|
||||
@ -970,10 +971,11 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
texture_atlas_size=8,
|
||||
texture_wrap=None,
|
||||
)
|
||||
atlas = aux.texture_atlas
|
||||
mesh = Meshes(
|
||||
verts=[verts],
|
||||
faces=[faces.verts_idx],
|
||||
textures=TexturesAtlas(atlas=[aux.texture_atlas]),
|
||||
textures=TexturesAtlas(atlas=[atlas]),
|
||||
)
|
||||
|
||||
# Init rasterizer settings
|
||||
@ -981,7 +983,10 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
|
||||
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512, blur_radius=0.0, faces_per_pixel=1, cull_backfaces=True
|
||||
image_size=512,
|
||||
blur_radius=0.0,
|
||||
faces_per_pixel=1,
|
||||
cull_backfaces=True,
|
||||
)
|
||||
|
||||
# Init shader settings
|
||||
@ -993,23 +998,52 @@ class TestRenderMeshes(TestCaseMixin, unittest.TestCase):
|
||||
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
|
||||
|
||||
# The HardPhongShader can be used directly with atlas textures.
|
||||
rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
|
||||
rasterizer=rasterizer,
|
||||
shader=HardPhongShader(lights=lights, cameras=cameras, materials=materials),
|
||||
)
|
||||
|
||||
images = renderer(mesh)
|
||||
rgb = images[0, ..., :3].squeeze().cpu()
|
||||
rgb = images[0, ..., :3].squeeze()
|
||||
|
||||
# Load reference image
|
||||
image_ref = load_rgb_image("test_texture_atlas_8x8_back.png", DATA_DIR)
|
||||
|
||||
if DEBUG:
|
||||
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
|
||||
Image.fromarray((rgb.detach().cpu().numpy() * 255).astype(np.uint8)).save(
|
||||
DATA_DIR / "DEBUG_texture_atlas_8x8_back.png"
|
||||
)
|
||||
|
||||
self.assertClose(rgb, image_ref, atol=0.05)
|
||||
self.assertClose(rgb.cpu(), image_ref, atol=0.05)
|
||||
|
||||
# Check gradients are propagated
|
||||
# correctly back to the texture atlas.
|
||||
# Because of how texture sampling is implemented
|
||||
# for the texture atlas it is not possible to get
|
||||
# gradients back to the vertices.
|
||||
atlas.requires_grad = True
|
||||
mesh = Meshes(
|
||||
verts=[verts],
|
||||
faces=[faces.verts_idx],
|
||||
textures=TexturesAtlas(atlas=[atlas]),
|
||||
)
|
||||
raster_settings = RasterizationSettings(
|
||||
image_size=512,
|
||||
blur_radius=0.0001,
|
||||
faces_per_pixel=5,
|
||||
cull_backfaces=True,
|
||||
clip_barycentric_coords=True,
|
||||
)
|
||||
images = renderer(mesh, raster_settings=raster_settings)
|
||||
images[0, ...].sum().backward()
|
||||
|
||||
fragments = rasterizer(mesh, raster_settings=raster_settings)
|
||||
# Some of the bary coordinates are outisde the
|
||||
# [0, 1] range as expected because the blur is > 0
|
||||
self.assertTrue(fragments.bary_coords.ge(1.0).any())
|
||||
self.assertIsNotNone(atlas.grad)
|
||||
self.assertTrue(atlas.grad.sum().abs() > 0.0)
|
||||
|
||||
def test_simple_sphere_outside_zfar(self):
|
||||
"""
|
||||
|
Loading…
x
Reference in New Issue
Block a user