diff --git a/tests/implicitron/test_forward_pass.py b/tests/implicitron/test_forward_pass.py index 285186ff..f7faa6a9 100644 --- a/tests/implicitron/test_forward_pass.py +++ b/tests/implicitron/test_forward_pass.py @@ -32,7 +32,7 @@ class TestGenericModel(unittest.TestCase): def test_gm(self): # Simple test of a forward and backward pass of the default GenericModel. - device = torch.device("cuda:1") + device = torch.device("cuda:0") expand_args_fields(GenericModel) model = GenericModel(render_image_height=80, render_image_width=80) model.to(device) @@ -116,7 +116,7 @@ class TestGenericModel(unittest.TestCase): def test_idr(self): # Forward pass of GenericModel with IDR. - device = torch.device("cuda:1") + device = torch.device("cuda:0") args = get_default_args(GenericModel) args.renderer_class_type = "SignedDistanceFunctionRenderer" args.implicit_function_class_type = "IdrFeatureField" @@ -153,7 +153,7 @@ class TestGenericModel(unittest.TestCase): self.assertGreater(train_preds["objective"].item(), 0) def test_viewpool(self): - device = torch.device("cuda:1") + device = torch.device("cuda:0") args = get_default_args(GenericModel) args.view_pooler_enabled = True args.image_feature_extractor_class_type = "ResNetFeatureExtractor"