SplatterBlender follow-ups

Summary: A few minor additions I didn't fit into the SplatterBlender diffs, as requested by reviewers.

Reviewed By: jcjohnson

Differential Revision: D36682437

fbshipit-source-id: 57af995e766dfd2674b3984a3ba00aef7ca7db80
This commit is contained in:
Krzysztof Chalupka
2022-05-26 13:03:57 -07:00
committed by Facebook GitHub Bot
parent c31bf85a23
commit a42a89a5ba
5 changed files with 36 additions and 53 deletions

View File

@@ -14,6 +14,7 @@ from pytorch3d.renderer.mesh.shader import (
HardGouraudShader,
HardPhongShader,
SoftPhongShader,
SplatterPhongShader,
)
from pytorch3d.structures.meshes import Meshes
@@ -21,20 +22,22 @@ from .common_testing import TestCaseMixin
class TestShader(TestCaseMixin, unittest.TestCase):
def setUp(self):
self.shader_classes = [
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftPhongShader,
SplatterPhongShader,
]
def test_to(self):
cpu_device = torch.device("cpu")
cuda_device = torch.device("cuda:0")
R, T = look_at_view_transform()
shader_classes = [
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftPhongShader,
]
for shader_class in shader_classes:
for shader_class in self.shader_classes:
for cameras_class in (None, PerspectiveCameras):
if cameras_class is None:
cameras = None
@@ -53,8 +56,11 @@ class TestShader(TestCaseMixin, unittest.TestCase):
self.assertIs(cpu_shader, cuda_shader)
if cameras is None:
self.assertIsNone(cuda_shader.cameras)
with self.assertRaisesRegexp(ValueError, "Cameras must be"):
cuda_shader._get_cameras()
else:
self.assertEqual(cuda_device, cuda_shader.cameras.device)
self.assertIsInstance(cuda_shader._get_cameras(), cameras_class)
self.assertEqual(cuda_device, cuda_shader.materials.device)
self.assertEqual(cuda_device, cuda_shader.lights.device)
@@ -75,14 +81,8 @@ class TestShader(TestCaseMixin, unittest.TestCase):
zbuf=torch.ones_like(pix_to_face),
dists=torch.ones_like(pix_to_face),
)
shader_classes = [
HardFlatShader,
HardGouraudShader,
HardPhongShader,
SoftPhongShader,
]
for shader_class in shader_classes:
for shader_class in self.shader_classes:
shader = shader_class()
with self.assertRaises(ValueError):