diff --git a/tests/test_render_multigpu.py b/tests/test_render_multigpu.py index 40ff9f1d..a2bf7289 100644 --- a/tests/test_render_multigpu.py +++ b/tests/test_render_multigpu.py @@ -159,8 +159,8 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase): verts = ico_sphere(3).verts_padded() texs = verts.new_ones(verts.shape) model = Model() + model.to(GPU_LIST[0]) model = nn.DataParallel(model, device_ids=GPU_LIST) - model.to(f"cuda:{model.device_ids[0]}") # Test a few iterations for _ in range(100):