mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
ViewPooler class
Summary: Implements a ViewPooler that groups ViewSampler and FeatureAggregator. Reviewed By: shapovalov Differential Revision: D35852367 fbshipit-source-id: c1bcaf5a1f826ff94efce53aa5836121ad9c50ec
This commit is contained in:
committed by
Facebook GitHub Bot
parent
bef959c755
commit
47d06c8924
@@ -8,7 +8,6 @@ bg_color:
|
||||
- 0.0
|
||||
- 0.0
|
||||
- 0.0
|
||||
view_pool: false
|
||||
num_passes: 1
|
||||
chunk_size_grid: 4096
|
||||
render_features_dimensions: 3
|
||||
@@ -17,7 +16,8 @@ n_train_target_views: 1
|
||||
sampling_mode_training: mask_sample
|
||||
sampling_mode_evaluation: full_grid
|
||||
renderer_class_type: LSTMRenderer
|
||||
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
|
||||
image_feature_extractor_enabled: true
|
||||
view_pooler_enabled: true
|
||||
implicit_function_class_type: IdrFeatureField
|
||||
loss_weights:
|
||||
loss_rgb_mse: 1.0
|
||||
@@ -91,15 +91,17 @@ image_feature_extractor_args:
|
||||
add_images: true
|
||||
global_average_pool: false
|
||||
feature_rescale: 1.0
|
||||
view_sampler_args:
|
||||
masked_sampling: false
|
||||
sampling_mode: bilinear
|
||||
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
|
||||
exclude_target_view: true
|
||||
exclude_target_view_mask_features: true
|
||||
concatenate_output: true
|
||||
weight_by_ray_angle_gamma: 1.0
|
||||
min_ray_angle_weight: 0.1
|
||||
view_pooler_args:
|
||||
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
|
||||
view_sampler_args:
|
||||
masked_sampling: false
|
||||
sampling_mode: bilinear
|
||||
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
|
||||
exclude_target_view: true
|
||||
exclude_target_view_mask_features: true
|
||||
concatenate_output: true
|
||||
weight_by_ray_angle_gamma: 1.0
|
||||
min_ray_angle_weight: 0.1
|
||||
implicit_function_IdrFeatureField_args:
|
||||
feature_vector_size: 3
|
||||
d_in: 3
|
||||
|
||||
@@ -20,9 +20,8 @@ from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer
|
||||
from pytorch3d.implicitron.models.renderer.multipass_ea import (
|
||||
MultiPassEmissionAbsorptionRenderer,
|
||||
)
|
||||
from pytorch3d.implicitron.models.view_pooling.feature_aggregation import (
|
||||
from pytorch3d.implicitron.models.view_pooler.feature_aggregator import (
|
||||
AngleWeightedIdentityFeatureAggregator,
|
||||
AngleWeightedReductionFeatureAggregator,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
get_default_args,
|
||||
@@ -32,7 +31,10 @@ from pytorch3d.implicitron.tools.config import (
|
||||
|
||||
if os.environ.get("FB_TEST", False):
|
||||
from common_testing import get_tests_dir
|
||||
|
||||
from .common_resources import provide_lpips_vgg
|
||||
else:
|
||||
from common_resources import provide_lpips_vgg # noqa
|
||||
from tests.common_testing import get_tests_dir
|
||||
|
||||
DATA_DIR = get_tests_dir() / "implicitron/data"
|
||||
@@ -46,28 +48,33 @@ class TestGenericModel(unittest.TestCase):
|
||||
self.maxDiff = None
|
||||
|
||||
def test_create_gm(self):
|
||||
provide_lpips_vgg()
|
||||
args = get_default_args(GenericModel)
|
||||
gm = GenericModel(**args)
|
||||
self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
|
||||
self.assertIsInstance(
|
||||
gm.feature_aggregator, AngleWeightedReductionFeatureAggregator
|
||||
)
|
||||
self.assertIsInstance(
|
||||
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
|
||||
)
|
||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||
self.assertFalse(hasattr(gm, "image_feature_extractor"))
|
||||
self.assertIsNone(gm.view_pooler)
|
||||
self.assertIsNone(gm.image_feature_extractor)
|
||||
|
||||
def test_create_gm_overrides(self):
|
||||
def _test_create_gm_overrides(self):
|
||||
provide_lpips_vgg()
|
||||
args = get_default_args(GenericModel)
|
||||
args.feature_aggregator_class_type = "AngleWeightedIdentityFeatureAggregator"
|
||||
args.view_pooler_enabled = True
|
||||
args.image_feature_extractor_enabled = True
|
||||
args.view_pooler_args.feature_aggregator_class_type = (
|
||||
"AngleWeightedIdentityFeatureAggregator"
|
||||
)
|
||||
args.implicit_function_class_type = "IdrFeatureField"
|
||||
args.renderer_class_type = "LSTMRenderer"
|
||||
gm = GenericModel(**args)
|
||||
self.assertIsInstance(gm.renderer, LSTMRenderer)
|
||||
self.assertIsInstance(
|
||||
gm.feature_aggregator, AngleWeightedIdentityFeatureAggregator
|
||||
gm.view_pooler.feature_aggregator,
|
||||
AngleWeightedIdentityFeatureAggregator,
|
||||
)
|
||||
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
|
||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||
|
||||
@@ -56,7 +56,12 @@ class TestGenericModel(unittest.TestCase):
|
||||
cfg = _load_model_config_from_yaml(str(config_file))
|
||||
model = GenericModel(**cfg)
|
||||
model.to(device)
|
||||
self._one_model_test(model, device, eval_test=True)
|
||||
self._one_model_test(
|
||||
model,
|
||||
device,
|
||||
eval_test=True,
|
||||
bw_test=True,
|
||||
)
|
||||
|
||||
def _one_model_test(
|
||||
self,
|
||||
@@ -64,6 +69,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
device,
|
||||
n_train_cameras: int = 5,
|
||||
eval_test: bool = True,
|
||||
bw_test: bool = True,
|
||||
):
|
||||
|
||||
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
|
||||
@@ -86,8 +92,12 @@ class TestGenericModel(unittest.TestCase):
|
||||
**random_args,
|
||||
evaluation_mode=EvaluationMode.TRAINING,
|
||||
)
|
||||
self.assertGreater(train_preds["objective"].item(), 0)
|
||||
train_preds["objective"].backward()
|
||||
self.assertTrue(
|
||||
train_preds["objective"].isfinite().item()
|
||||
) # check finiteness of the objective
|
||||
|
||||
if bw_test:
|
||||
train_preds["objective"].backward()
|
||||
|
||||
if eval_test:
|
||||
model.eval()
|
||||
|
||||
@@ -9,7 +9,7 @@ import unittest
|
||||
|
||||
import pytorch3d as pt3d
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.view_pooling.view_sampling import ViewSampler
|
||||
from pytorch3d.implicitron.models.view_pooler.view_sampler import ViewSampler
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user