Increase code coverage of shader

Summary: Increase code coverage of shader and re-include them in code coverage test

Reviewed By: nikhilaravi

Differential Revision: D29097503

fbshipit-source-id: 2791989ee1562cfa193f3addea0ce72d6840614a
This commit is contained in:
Patrick Labatut 2021-06-17 01:34:49 -07:00 committed by Facebook GitHub Bot
parent c75ca04cf7
commit a8610e9da4
2 changed files with 104 additions and 12 deletions

View File

@ -26,7 +26,7 @@ from .shading import flat_shading, gouraud_shading, phong_shading
# - blend colors across top K faces per pixel. # - blend colors across top K faces per pixel.
class HardPhongShader(nn.Module): # pragma: no cover class HardPhongShader(nn.Module):
""" """
Per pixel lighting - the lighting model is applied using the interpolated Per pixel lighting - the lighting model is applied using the interpolated
coordinates and normals for each pixel. The blending function hard assigns coordinates and normals for each pixel. The blending function hard assigns
@ -58,7 +58,8 @@ class HardPhongShader(nn.Module): # pragma: no cover
def to(self, device: Device): def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module # Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device) if self.cameras is not None:
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device) self.materials = self.materials.to(device)
self.lights = self.lights.to(device) self.lights = self.lights.to(device)
return self return self
@ -86,7 +87,7 @@ class HardPhongShader(nn.Module): # pragma: no cover
return images return images
class SoftPhongShader(nn.Module): # pragma: no cover class SoftPhongShader(nn.Module):
""" """
Per pixel lighting - the lighting model is applied using the interpolated Per pixel lighting - the lighting model is applied using the interpolated
coordinates and normals for each pixel. The blending function returns the coordinates and normals for each pixel. The blending function returns the
@ -118,7 +119,8 @@ class SoftPhongShader(nn.Module): # pragma: no cover
def to(self, device: Device): def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module # Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device) if self.cameras is not None:
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device) self.materials = self.materials.to(device)
self.lights = self.lights.to(device) self.lights = self.lights.to(device)
return self return self
@ -150,7 +152,7 @@ class SoftPhongShader(nn.Module): # pragma: no cover
return images return images
class HardGouraudShader(nn.Module): # pragma: no cover class HardGouraudShader(nn.Module):
""" """
Per vertex lighting - the lighting model is applied to the vertex colors and Per vertex lighting - the lighting model is applied to the vertex colors and
the colors are then interpolated using the barycentric coordinates to the colors are then interpolated using the barycentric coordinates to
@ -183,7 +185,8 @@ class HardGouraudShader(nn.Module): # pragma: no cover
def to(self, device: Device): def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module # Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device) if self.cameras is not None:
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device) self.materials = self.materials.to(device)
self.lights = self.lights.to(device) self.lights = self.lights.to(device)
return self return self
@ -214,7 +217,7 @@ class HardGouraudShader(nn.Module): # pragma: no cover
return images return images
class SoftGouraudShader(nn.Module): # pragma: no cover class SoftGouraudShader(nn.Module):
""" """
Per vertex lighting - the lighting model is applied to the vertex colors and Per vertex lighting - the lighting model is applied to the vertex colors and
the colors are then interpolated using the barycentric coordinates to the colors are then interpolated using the barycentric coordinates to
@ -247,7 +250,8 @@ class SoftGouraudShader(nn.Module): # pragma: no cover
def to(self, device: Device): def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module # Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device) if self.cameras is not None:
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device) self.materials = self.materials.to(device)
self.lights = self.lights.to(device) self.lights = self.lights.to(device)
return self return self
@ -277,7 +281,7 @@ class SoftGouraudShader(nn.Module): # pragma: no cover
def TexturedSoftPhongShader( def TexturedSoftPhongShader(
device: Device = "cpu", cameras=None, lights=None, materials=None, blend_params=None device: Device = "cpu", cameras=None, lights=None, materials=None, blend_params=None
): # pragma: no cover ):
""" """
TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead. TexturedSoftPhongShader class has been DEPRECATED. Use SoftPhongShader instead.
Preserving TexturedSoftPhongShader as a function for backwards compatibility. Preserving TexturedSoftPhongShader as a function for backwards compatibility.
@ -296,7 +300,7 @@ def TexturedSoftPhongShader(
) )
class HardFlatShader(nn.Module): # pragma: no cover class HardFlatShader(nn.Module):
""" """
Per face lighting - the lighting model is applied using the average face Per face lighting - the lighting model is applied using the average face
position and the face normal. The blending function hard assigns position and the face normal. The blending function hard assigns
@ -328,7 +332,8 @@ class HardFlatShader(nn.Module): # pragma: no cover
def to(self, device: Device): def to(self, device: Device):
# Manually move to device modules which are not subclasses of nn.Module # Manually move to device modules which are not subclasses of nn.Module
self.cameras = self.cameras.to(device) if self.cameras is not None:
self.cameras = self.cameras.to(device)
self.materials = self.materials.to(device) self.materials = self.materials.to(device)
self.lights = self.lights.to(device) self.lights = self.lights.to(device)
return self return self
@ -355,7 +360,7 @@ class HardFlatShader(nn.Module): # pragma: no cover
return images return images
class SoftSilhouetteShader(nn.Module): # pragma: no cover class SoftSilhouetteShader(nn.Module):
""" """
Calculate the silhouette by blending the top K faces for each pixel based Calculate the silhouette by blending the top K faces for each pixel based
on the 2d euclidean distance of the center of the pixel to the mesh face. on the 2d euclidean distance of the center of the pixel to the mesh face.

87
tests/test_shader.py Normal file
View File

@ -0,0 +1,87 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import torch
from common_testing import TestCaseMixin
from pytorch3d.renderer.cameras import (
look_at_view_transform,
PerspectiveCameras,
)
from pytorch3d.renderer.mesh.rasterizer import Fragments
from pytorch3d.renderer.mesh.shader import (
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftPhongShader,
)
from pytorch3d.structures.meshes import Meshes
class TestShader(TestCaseMixin, unittest.TestCase):
def test_to(self):
cpu_device = torch.device("cpu")
cuda_device = torch.device("cuda")
R, T = look_at_view_transform()
shader_classes = [
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftPhongShader,
]
for shader_class in shader_classes:
for cameras_class in (None, PerspectiveCameras):
if cameras_class is None:
cameras = None
else:
cameras = PerspectiveCameras(device=cpu_device, R=R, T=T)
cpu_shader = shader_class(device=cpu_device, cameras=cameras)
if cameras is None:
self.assertIsNone(cpu_shader.cameras)
else:
self.assertEqual(cpu_device, cpu_shader.cameras.device)
self.assertEqual(cpu_device, cpu_shader.materials.device)
self.assertEqual(cpu_device, cpu_shader.lights.device)
cuda_shader = cpu_shader.to(cuda_device)
self.assertIs(cpu_shader, cuda_shader)
if cameras is None:
self.assertIsNone(cuda_shader.cameras)
else:
self.assertEqual(cuda_device, cuda_shader.cameras.device)
self.assertEqual(cuda_device, cuda_shader.materials.device)
self.assertEqual(cuda_device, cuda_shader.lights.device)
def test_cameras_check(self):
verts = torch.tensor(
[[-1, -1, 0], [1, -1, 1], [1, 1, 0], [-1, 1, 1]], dtype=torch.float32
)
faces = torch.tensor([[0, 1, 2], [2, 3, 0]], dtype=torch.int64)
meshes = Meshes(verts=[verts], faces=[faces])
pix_to_face = torch.tensor([0, 1], dtype=torch.int64).view(1, 1, 1, 2)
barycentric_coords = torch.tensor(
[[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]], dtype=torch.float32
).view(1, 1, 1, 2, -1)
fragments = Fragments(
pix_to_face=pix_to_face,
bary_coords=barycentric_coords,
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
shader_classes = [
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftPhongShader,
]
for shader_class in shader_classes:
shader = shader_class()
with self.assertRaises(ValueError):
shader(fragments, meshes)