Mirror of https://github.com/facebookresearch/pytorch3d.git
Add the OverfitModel
Summary:
Introduces the OverfitModel for NeRF-style training that overfits to a single scene. It is a special case of GenericModel and has been disentangled from it to ease usage.

## General modifications
1. Modularize a minimal GenericModel to introduce OverfitModel.
2. Introduce OverfitModel and ensure through unit tests that it behaves like GenericModel.

## Modularization
The following methods have been extracted from GenericModel to allow modularity with ManyViewModel:
- get_objective is now a call to weighted_sum_losses
- log_loss_weights
- prepare_inputs

The generic helpers have been moved to a utils.py file, and the code has been simplified to introduce OverfitModel. Formerly private methods such as chunk_generator are now public and can be used by ManyViewModel.

Reviewed By: shapovalov

Differential Revision: D43771992

fbshipit-source-id: 6102aeb21c7fdd56aa2ff9cd1dd23fd9fbf26315
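For context, here is a minimal sketch of what weighted_sum_losses could look like. It is an illustrative reimplementation grounded only in the behavior exercised by tests/implicitron/models/test_utils.py below, not the code shipped in pytorch3d.implicitron.models.utils; details such as the exact warning message are assumptions.

    import warnings
    from typing import Dict, Optional

    import torch


    def weighted_sum_losses(
        preds: Dict[str, torch.Tensor], weights: Dict[str, float]
    ) -> Optional[torch.Tensor]:
        """Sketch: combine individual losses into one objective as sum(weight * loss)."""
        # Keep only the loss names that appear in both dictionaries.
        losses_weighted = [
            preds[name] * weight for name, weight in weights.items() if name in preds
        ]
        if len(losses_weighted) == 0:
            # Matches the tested behavior: no overlapping keys -> warn and return None.
            warnings.warn("No losses found for the given loss weights; returning None.")
            return None
        return sum(losses_weighted)


    # e.g. preds {"a": tensor(2), "b": tensor(2)} with weights {"a": 2.0, "b": 0.0} -> tensor(4.)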
Committed by: Facebook GitHub Bot
Parent: 7d8b029aae
Commit: 813e941de5
tests/implicitron/models/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
tests/implicitron/models/test_overfit_model.py (new file, 217 lines)
@@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Any, Dict
from unittest.mock import patch

import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.overfit_model import OverfitModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras

DEVICE = torch.device("cuda:0")


def _generate_fake_inputs(N: int, H: int, W: int) -> Dict[str, Any]:
    R, T = look_at_view_transform(azim=torch.rand(N) * 360)
    return {
        "camera": PerspectiveCameras(R=R, T=T, device=DEVICE),
        "fg_probability": torch.randint(
            high=2, size=(N, 1, H, W), device=DEVICE
        ).float(),
        "depth_map": torch.rand((N, 1, H, W), device=DEVICE) + 0.1,
        "mask_crop": torch.randint(high=2, size=(N, 1, H, W), device=DEVICE).float(),
        "sequence_name": ["sequence"] * N,
        "image_rgb": torch.rand((N, 1, H, W), device=DEVICE),
    }


def mock_safe_multinomial(input: torch.Tensor, num_samples: int) -> torch.Tensor:
    """Return deterministic indices to mock safe_multinomial.

    Args:
        input: tensor of shape [B, n] containing non-negative values;
            rows are interpreted as unnormalized event probabilities
            in categorical distributions.
        num_samples: number of samples to take.

    Returns:
        Tensor of shape [B, num_samples]
    """
    batch_size = input.shape[0]
    return torch.arange(num_samples).repeat(batch_size, 1).to(DEVICE)


class TestOverfitModel(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(42)

    def test_overfit_model_vs_generic_model_with_batch_size_one(self):
        """In this test we compare OverfitModel to GenericModel behavior.

        We use a NeRF setup (2 rendering passes).

        OverfitModel is a specific case of GenericModel. Hence, with the same inputs,
        they should provide the exact same results.
        """
        expand_args_fields(OverfitModel)
        expand_args_fields(GenericModel)
        batch_size, image_height, image_width = 1, 80, 80
        assert batch_size == 1
        overfit_model = OverfitModel(
            render_image_height=image_height,
            render_image_width=image_width,
            coarse_implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
            # To compare the outputs of the two models without randomization,
            # we deactivate stratified_point_sampling_training.
            raysampler_AdaptiveRaySampler_args={
                "stratified_point_sampling_training": False
            },
            global_encoder_class_type="SequenceAutodecoder",
            global_encoder_SequenceAutodecoder_args={
                "autodecoder_args": {
                    "n_instances": 1000,
                    "init_scale": 1.0,
                    "encoding_dim": 64,
                }
            },
        )
        generic_model = GenericModel(
            render_image_height=image_height,
            render_image_width=image_width,
            n_train_target_views=batch_size,
            num_passes=2,
            # To compare the outputs of the two models without randomization,
            # we deactivate stratified_point_sampling_training.
            raysampler_AdaptiveRaySampler_args={
                "stratified_point_sampling_training": False
            },
            global_encoder_class_type="SequenceAutodecoder",
            global_encoder_SequenceAutodecoder_args={
                "autodecoder_args": {
                    "n_instances": 1000,
                    "init_scale": 1.0,
                    "encoding_dim": 64,
                }
            },
        )

        # Check that both models have the same number of parameters
        num_params_mvm = sum(p.numel() for p in overfit_model.parameters())
        num_params_gm = sum(p.numel() for p in generic_model.parameters())
        self.assertEqual(num_params_mvm, num_params_gm)

        # Adapt the state dict keys from GenericModel to OverfitModel naming
        mapping_om_from_gm = {
            key.replace("_implicit_functions.0._fn", "implicit_function").replace(
                "_implicit_functions.1._fn", "coarse_implicit_function"
            ): val
            for key, val in generic_model.state_dict().items()
        }
        # Copy parameters from generic_model to overfit_model
        overfit_model.load_state_dict(mapping_om_from_gm)

        overfit_model.to(DEVICE)
        generic_model.to(DEVICE)
        inputs_ = _generate_fake_inputs(batch_size, image_height, image_width)

        # training forward pass
        overfit_model.train()
        generic_model.train()

        with patch(
            "pytorch3d.renderer.implicit.raysampling._safe_multinomial",
            side_effect=mock_safe_multinomial,
        ):
            train_preds_om = overfit_model(
                **inputs_,
                evaluation_mode=EvaluationMode.TRAINING,
            )
            train_preds_gm = generic_model(
                **inputs_,
                evaluation_mode=EvaluationMode.TRAINING,
            )

        self.assertTrue(len(train_preds_om) == len(train_preds_gm))

        self.assertTrue(train_preds_om["objective"].isfinite().item())
        # With randomization disabled and identical weights,
        # the objectives should be the same.
        self.assertTrue(
            torch.allclose(train_preds_om["objective"], train_preds_gm["objective"])
        )

        # Test that evaluation works
        overfit_model.eval()
        generic_model.eval()
        with torch.no_grad():
            eval_preds_om = overfit_model(
                **inputs_,
                evaluation_mode=EvaluationMode.EVALUATION,
            )
            eval_preds_gm = generic_model(
                **inputs_,
                evaluation_mode=EvaluationMode.EVALUATION,
            )

        self.assertEqual(
            eval_preds_om["images_render"].shape,
            (batch_size, 3, image_height, image_width),
        )
        self.assertTrue(
            torch.allclose(eval_preds_om["objective"], eval_preds_gm["objective"])
        )
        self.assertTrue(
            torch.allclose(
                eval_preds_om["images_render"], eval_preds_gm["images_render"]
            )
        )

    def test_overfit_model_check_share_weights(self):
        model = OverfitModel(share_implicit_function_across_passes=True)
        for p1, p2 in zip(
            model.implicit_function.parameters(),
            model.coarse_implicit_function.parameters(),
        ):
            self.assertEqual(id(p1), id(p2))

        model.to(DEVICE)
        inputs_ = _generate_fake_inputs(2, 80, 80)
        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)

    def test_overfit_model_check_no_share_weights(self):
        model = OverfitModel(
            share_implicit_function_across_passes=False,
            coarse_implicit_function_class_type="NeuralRadianceFieldImplicitFunction",
            coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args={
                "transformer_dim_down_factor": 1.0,
                "n_hidden_neurons_xyz": 256,
                "n_layers_xyz": 8,
                "append_xyz": (5,),
            },
        )
        for p1, p2 in zip(
            model.implicit_function.parameters(),
            model.coarse_implicit_function.parameters(),
        ):
            self.assertNotEqual(id(p1), id(p2))

        model.to(DEVICE)
        inputs_ = _generate_fake_inputs(2, 80, 80)
        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)

    def test_overfit_model_coarse_implicit_function_is_none(self):
        model = OverfitModel(
            share_implicit_function_across_passes=False,
            coarse_implicit_function_NeuralRadianceFieldImplicitFunction_args=None,
        )
        self.assertIsNone(model.coarse_implicit_function)
        model.to(DEVICE)
        inputs_ = _generate_fake_inputs(2, 80, 80)
        model(**inputs_, evaluation_mode=EvaluationMode.TRAINING)
tests/implicitron/models/test_utils.py (new file, 66 lines)
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import unittest

import torch

from pytorch3d.implicitron.models.utils import preprocess_input, weighted_sum_losses


class TestUtils(unittest.TestCase):
    def test_prepare_inputs_wrong_num_dim(self):
        img = torch.randn(3, 3, 3)
        with self.assertRaises(ValueError) as context:
            img, fg_prob, depth_map = preprocess_input(
                img, None, None, True, True, 0.5, (0.0, 0.0, 0.0)
            )
        self.assertEqual(
            "Model received unbatched inputs. "
            + "Perhaps they came from a FrameData which had not been collated.",
            context.exception,
        )

    def test_prepare_inputs_mask_image_true(self):
        batch, channels, height, width = 2, 3, 10, 10
        img = torch.ones(batch, channels, height, width)
        # Create a mask on the lower triangular matrix
        fg_prob = torch.tril(torch.ones(batch, 1, height, width)) * 0.3

        out_img, out_fg_prob, out_depth_map = preprocess_input(
            img, fg_prob, None, True, False, 0.3, (0.0, 0.0, 0.0)
        )

        self.assertTrue(torch.equal(out_img, torch.tril(img)))
        self.assertTrue(torch.equal(out_fg_prob, fg_prob >= 0.3))
        self.assertIsNone(out_depth_map)

    def test_prepare_inputs_mask_depth_true(self):
        batch, channels, height, width = 2, 3, 10, 10
        img = torch.ones(batch, channels, height, width)
        depth_map = torch.randn(batch, channels, height, width)
        # Create a mask on the lower triangular matrix
        fg_prob = torch.tril(torch.ones(batch, 1, height, width)) * 0.3

        out_img, out_fg_prob, out_depth_map = preprocess_input(
            img, fg_prob, depth_map, False, True, 0.3, (0.0, 0.0, 0.0)
        )

        self.assertTrue(torch.equal(out_img, img))
        self.assertTrue(torch.equal(out_fg_prob, fg_prob >= 0.3))
        self.assertTrue(torch.equal(out_depth_map, torch.tril(depth_map)))

    def test_weighted_sum_losses(self):
        preds = {"a": torch.tensor(2), "b": torch.tensor(2)}
        weights = {"a": 2.0, "b": 0.0}
        loss = weighted_sum_losses(preds, weights)
        self.assertEqual(loss, 4.0)

    def test_weighted_sum_losses_raise_warning(self):
        preds = {"a": torch.tensor(2), "b": torch.tensor(2)}
        weights = {"c": 2.0, "d": 2.0}
        self.assertIsNone(weighted_sum_losses(preds, weights))