mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
1GPU for implicitron tests
Reviewed By: shapovalov Differential Revision: D38794764 fbshipit-source-id: 140c8a935d760bab8569d903cc52ac3dd73cd553
This commit is contained in:
parent
7623457686
commit
d42e0d3d86
@ -32,7 +32,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
|
||||
def test_gm(self):
|
||||
# Simple test of a forward and backward pass of the default GenericModel.
|
||||
device = torch.device("cuda:1")
|
||||
device = torch.device("cuda:0")
|
||||
expand_args_fields(GenericModel)
|
||||
model = GenericModel(render_image_height=80, render_image_width=80)
|
||||
model.to(device)
|
||||
@ -116,7 +116,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
|
||||
def test_idr(self):
|
||||
# Forward pass of GenericModel with IDR.
|
||||
device = torch.device("cuda:1")
|
||||
device = torch.device("cuda:0")
|
||||
args = get_default_args(GenericModel)
|
||||
args.renderer_class_type = "SignedDistanceFunctionRenderer"
|
||||
args.implicit_function_class_type = "IdrFeatureField"
|
||||
@ -153,7 +153,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
self.assertGreater(train_preds["objective"].item(), 0)
|
||||
|
||||
def test_viewpool(self):
|
||||
device = torch.device("cuda:1")
|
||||
device = torch.device("cuda:0")
|
||||
args = get_default_args(GenericModel)
|
||||
args.view_pooler_enabled = True
|
||||
args.image_feature_extractor_class_type = "ResNetFeatureExtractor"
|
||||
|
Loading…
x
Reference in New Issue
Block a user