mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-02 10:15:59 +08:00
ViewPooler class
Summary: Implements a ViewPooler that groups ViewSampler and FeatureAggregator. Reviewed By: shapovalov Differential Revision: D35852367 fbshipit-source-id: c1bcaf5a1f826ff94efce53aa5836121ad9c50ec
This commit is contained in:
committed by
Facebook GitHub Bot
parent
bef959c755
commit
47d06c8924
@@ -56,7 +56,12 @@ class TestGenericModel(unittest.TestCase):
|
||||
cfg = _load_model_config_from_yaml(str(config_file))
|
||||
model = GenericModel(**cfg)
|
||||
model.to(device)
|
||||
self._one_model_test(model, device, eval_test=True)
|
||||
self._one_model_test(
|
||||
model,
|
||||
device,
|
||||
eval_test=True,
|
||||
bw_test=True,
|
||||
)
|
||||
|
||||
def _one_model_test(
|
||||
self,
|
||||
@@ -64,6 +69,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
device,
|
||||
n_train_cameras: int = 5,
|
||||
eval_test: bool = True,
|
||||
bw_test: bool = True,
|
||||
):
|
||||
|
||||
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
|
||||
@@ -86,8 +92,12 @@ class TestGenericModel(unittest.TestCase):
|
||||
**random_args,
|
||||
evaluation_mode=EvaluationMode.TRAINING,
|
||||
)
|
||||
self.assertGreater(train_preds["objective"].item(), 0)
|
||||
train_preds["objective"].backward()
|
||||
self.assertTrue(
|
||||
train_preds["objective"].isfinite().item()
|
||||
) # check finiteness of the objective
|
||||
|
||||
if bw_test:
|
||||
train_preds["objective"].backward()
|
||||
|
||||
if eval_test:
|
||||
model.eval()
|
||||
|
||||
Reference in New Issue
Block a user