mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 13:50:35 +08:00
Make feature extractor pluggable
Summary: Make ResNetFeatureExtractor be an implementation of FeatureExtractorBase.
Reviewed By: davnov134
Differential Revision: D35433098
fbshipit-source-id: 0664a9166a88e150231cfe2eceba017ae55aed3a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
cd7b885169
commit
9ec9d057cc
@@ -17,7 +17,7 @@ sampling_mode_training: mask_sample
|
||||
sampling_mode_evaluation: full_grid
|
||||
raysampler_class_type: AdaptiveRaySampler
|
||||
renderer_class_type: LSTMRenderer
|
||||
image_feature_extractor_enabled: true
|
||||
image_feature_extractor_class_type: ResNetFeatureExtractor
|
||||
view_pooler_enabled: true
|
||||
implicit_function_class_type: IdrFeatureField
|
||||
loss_weights:
|
||||
@@ -73,7 +73,7 @@ renderer_LSTMRenderer_args:
|
||||
hidden_size: 16
|
||||
n_feature_channels: 256
|
||||
verbose: false
|
||||
image_feature_extractor_args:
|
||||
image_feature_extractor_ResNetFeatureExtractor_args:
|
||||
name: resnet34
|
||||
pretrained: true
|
||||
stages:
|
||||
|
||||
@@ -9,6 +9,9 @@ import unittest
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.models.autodecoder import Autodecoder
|
||||
from pytorch3d.implicitron.models.feature_extractor.resnet_feature_extractor import (
|
||||
ResNetFeatureExtractor,
|
||||
)
|
||||
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
|
||||
IdrFeatureField,
|
||||
@@ -63,7 +66,6 @@ class TestGenericModel(unittest.TestCase):
|
||||
provide_resnet34()
|
||||
args = get_default_args(GenericModel)
|
||||
args.view_pooler_enabled = True
|
||||
args.image_feature_extractor_enabled = True
|
||||
args.view_pooler_args.feature_aggregator_class_type = (
|
||||
"AngleWeightedIdentityFeatureAggregator"
|
||||
)
|
||||
@@ -77,9 +79,13 @@ class TestGenericModel(unittest.TestCase):
|
||||
)
|
||||
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
|
||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||
self.assertIsInstance(gm.image_feature_extractor, ResNetFeatureExtractor)
|
||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||
|
||||
instance_args = OmegaConf.structured(gm)
|
||||
if DEBUG:
|
||||
full_yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
||||
(DATA_DIR / "overrides_full.yaml").write_text(full_yaml)
|
||||
remove_unused_components(instance_args)
|
||||
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
||||
if DEBUG:
|
||||
|
||||
@@ -33,7 +33,7 @@ class TestGenericModel(unittest.TestCase):
|
||||
# Simple test of a forward and backward pass of the default GenericModel.
|
||||
device = torch.device("cuda:1")
|
||||
expand_args_fields(GenericModel)
|
||||
model = GenericModel()
|
||||
model = GenericModel(render_image_height=80, render_image_width=80)
|
||||
model.to(device)
|
||||
self._one_model_test(model, device)
|
||||
|
||||
@@ -149,6 +149,37 @@ class TestGenericModel(unittest.TestCase):
|
||||
)
|
||||
self.assertGreater(train_preds["objective"].item(), 0)
|
||||
|
||||
def test_viewpool(self):
    """Run a training-mode forward pass of GenericModel with view pooling on.

    The model is built from the default args with ``view_pooler_enabled``
    set and ``add_masks`` switched off in the ResNet feature extractor
    args. It is then called with two cameras placed at random azimuths and
    a random RGB target image; the optional inputs ``fg_probability``,
    ``depth_map`` and ``mask_crop`` are passed as ``None``. The test passes
    when the returned ``objective`` is strictly positive.
    """
    device = torch.device("cuda:1")

    # Start from the default configuration and enable view pooling.
    config = get_default_args(GenericModel)
    config.view_pooler_enabled = True
    config.image_feature_extractor_ResNetFeatureExtractor_args.add_masks = False
    model = GenericModel(**config)
    model.to(device)

    # Cameras looking at the origin from random azimuth angles.
    num_views = 2
    azimuths = torch.rand(num_views) * 360
    R, T = look_at_view_transform(azim=azimuths)
    cameras = PerspectiveCameras(R=R, T=T, device=device)

    # Random target image sized to the model's render resolution.
    image_shape = (
        num_views,
        3,
        model.render_image_height,
        model.render_image_width,
    )
    target_image_rgb = torch.rand(image_shape, device=device)

    train_preds = model(
        camera=cameras,
        evaluation_mode=EvaluationMode.TRAINING,
        image_rgb=target_image_rgb,
        sequence_name=["a"] * num_views,
        fg_probability=None,
        depth_map=None,
        mask_crop=None,
    )
    objective = train_preds["objective"].item()
    self.assertGreater(objective, 0)
|
||||
|
||||
|
||||
def _random_input_tensor(
|
||||
N: int,
|
||||
|
||||
Reference in New Issue
Block a user