fix test_config_use

Summary: Fixes to reenable test_create_gm_overrides. Followup from D35852367 (47d06c8924) using logic from D36349361 (9e57b994ca).

Reviewed By: shapovalov

Differential Revision: D36371762

fbshipit-source-id: ad5fbbb4b5729fac41980d118f17a2589f7e6aba
This commit is contained in:
Jeremy Reizenstein 2022-05-13 07:15:26 -07:00 committed by Facebook GitHub Bot
parent 2c1901522a
commit b5f3d3ce12
2 changed files with 11 additions and 13 deletions

View File

@ -15,6 +15,7 @@ tqdm_trigger_threshold: 16
n_train_target_views: 1
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
raysampler_class_type: AdaptiveRaySampler
renderer_class_type: LSTMRenderer
image_feature_extractor_enabled: true
view_pooler_enabled: true
@ -50,23 +51,21 @@ sequence_autodecoder_args:
n_instances: 0
init_scale: 1.0
ignore_input: false
raysampler_args:
raysampler_AdaptiveRaySampler_args:
image_width: 400
image_height: 400
scene_center:
- 0.0
- 0.0
- 0.0
scene_extent: 0.0
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
min_depth: 0.1
max_depth: 8.0
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
scene_extent: 8.0
scene_center:
- 0.0
- 0.0
- 0.0
renderer_LSTMRenderer_args:
num_raymarch_steps: 10
init_depth: 17.0

View File

@ -32,9 +32,9 @@ from pytorch3d.implicitron.tools.config import (
if os.environ.get("FB_TEST", False):
from common_testing import get_tests_dir
from .common_resources import provide_lpips_vgg
from .common_resources import provide_resnet34
else:
from common_resources import provide_lpips_vgg # noqa
from common_resources import provide_resnet34
from tests.common_testing import get_tests_dir
DATA_DIR = get_tests_dir() / "implicitron/data"
@ -48,7 +48,6 @@ class TestGenericModel(unittest.TestCase):
self.maxDiff = None
def test_create_gm(self):
provide_lpips_vgg()
args = get_default_args(GenericModel)
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
@ -60,8 +59,8 @@ class TestGenericModel(unittest.TestCase):
self.assertIsNone(gm.view_pooler)
self.assertIsNone(gm.image_feature_extractor)
def _test_create_gm_overrides(self):
provide_lpips_vgg()
def test_create_gm_overrides(self):
provide_resnet34()
args = get_default_args(GenericModel)
args.view_pooler_enabled = True
args.image_feature_extractor_enabled = True