Texturing API updates

Summary:
A fairly big refactor of the texturing API with some breaking changes to how textures are defined.

Main changes:
- There are now 3 types of texture classes: `TexturesUV`, `TexturesAtlas` and `TexturesVertex`. Each class:
   - has a `sample_textures` function which accepts the `fragments` from rasterization and returns `texels`. This means that the shaders will not need to know the type of the mesh texture which will resolve several issues people were reporting on GitHub.
  -  has a `join_batch` method for joining multiple textures of the same type into a batch

Reviewed By: gkioxari

Differential Revision: D21067427

fbshipit-source-id: 4b346500a60181e72fdd1b0dd89b5505c7a33926
This commit is contained in:
Nikhila Ravi
2020-07-29 16:06:58 -07:00
committed by Facebook GitHub Bot
parent b73d3d6ed9
commit a3932960b3
19 changed files with 1872 additions and 785 deletions

View File

@@ -8,9 +8,9 @@ from pytorch3d.ops.interp_face_attrs import (
interpolate_face_attributes,
interpolate_face_attributes_python,
)
from pytorch3d.renderer.mesh import TexturesVertex
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.texturing import interpolate_vertex_colors
from pytorch3d.structures import Meshes, Textures
from pytorch3d.structures import Meshes
class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
@@ -96,16 +96,12 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
self.assertClose(grad_face_attrs_py, grad_face_attrs_cu, rtol=1e-3)
def test_interpolate_attributes(self):
"""
This tests both interpolate_vertex_colors as well as
interpolate_face_attributes.
"""
verts = torch.randn((4, 3), dtype=torch.float32)
faces = torch.tensor([[2, 1, 0], [3, 1, 0]], dtype=torch.int64)
vert_tex = torch.tensor(
[[0, 1, 0], [0, 1, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.float32
)
tex = Textures(verts_rgb=vert_tex[None, :])
tex = TexturesVertex(verts_features=vert_tex[None, :])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
@@ -120,7 +116,13 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
texels = interpolate_vertex_colors(fragments, mesh)
verts_features_packed = mesh.textures.verts_features_packed()
faces_verts_features = verts_features_packed[mesh.faces_packed()]
texels = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts_features
)
self.assertTrue(torch.allclose(texels, expected_vals[None, :]))
def test_interpolate_attributes_grad(self):
@@ -131,7 +133,7 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
dtype=torch.float32,
requires_grad=True,
)
tex = Textures(verts_rgb=vert_tex[None, :])
tex = TexturesVertex(verts_features=vert_tex[None, :])
mesh = Meshes(verts=[verts], faces=[faces], textures=tex)
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
@@ -147,7 +149,12 @@ class TestInterpolateFaceAttributes(TestCaseMixin, unittest.TestCase):
[[0.3, 0.3, 0.3], [0.9, 0.9, 0.9], [0.5, 0.5, 0.5], [0.3, 0.3, 0.3]],
dtype=torch.float32,
)
texels = interpolate_vertex_colors(fragments, mesh)
verts_features_packed = mesh.textures.verts_features_packed()
faces_verts_features = verts_features_packed[mesh.faces_packed()]
texels = interpolate_face_attributes(
fragments.pix_to_face, fragments.bary_coords, faces_verts_features
)
texels.sum().backward()
self.assertTrue(hasattr(vert_tex, "grad"))
self.assertTrue(torch.allclose(vert_tex.grad, grad_vert_tex[None, :]))