diff --git a/tests/test_splatter_blend.py b/tests/test_splatter_blend.py index caaa5d5f..e90bc152 100644 --- a/tests/test_splatter_blend.py +++ b/tests/test_splatter_blend.py @@ -7,7 +7,7 @@ import unittest import torch -from common_testing import TestCaseMixin +from pytorch3d.common.compat import meshgrid_ij from pytorch3d.renderer.cameras import FoVPerspectiveCameras from pytorch3d.renderer.splatter_blend import ( _compute_occlusion_layers, @@ -20,6 +20,8 @@ from pytorch3d.renderer.splatter_blend import ( _prepare_pixels_and_colors, ) +from .common_testing import TestCaseMixin + offsets = torch.tensor( [ [-1, -1], @@ -248,15 +250,13 @@ class TestComputeSplattingColorsAndWeights(TestCaseMixin, unittest.TestCase): def setUp(self): self.N, self.H, self.W, self.K = 2, 3, 4, 5 self.pixel_coords_screen = ( - torch.tile( - torch.stack( - torch.meshgrid( - torch.arange(self.H), torch.arange(self.W), indexing="ij" - ), - dim=-1, - ).reshape(1, self.H, self.W, 1, 2), - (self.N, 1, 1, self.K, 1), - ).float() + torch.stack( + meshgrid_ij(torch.arange(self.H), torch.arange(self.W)), + dim=-1, + ) + .reshape(1, self.H, self.W, 1, 2) + .expand(self.N, self.H, self.W, self.K, 2) + .float() + 0.5 ) self.colors = torch.ones((self.N, self.H, self.W, self.K, 4))