Add blurpool following MIPNerf paper.

Summary:
Add blurpool has defined in [MIP-NeRF](https://arxiv.org/abs/2103.13415).
It has been added has an option for RayPointRefiner.

Reviewed By: shapovalov

Differential Revision: D46356189

fbshipit-source-id: ad841bad86d2b591a68e1cb885d4f781cf26c111
This commit is contained in:
Emilien Garreau
2023-07-06 02:20:53 -07:00
committed by Facebook GitHub Bot
parent ccf860f1db
commit 5910d81b7b
4 changed files with 97 additions and 5 deletions

View File

@@ -5,9 +5,14 @@
# LICENSE file in the root directory of this source tree.
import unittest
from itertools import product
import torch
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
from pytorch3d.implicitron.models.renderer.ray_point_refiner import (
apply_blurpool_on_weights,
RayPointRefiner,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
from tests.common_testing import TestCaseMixin
@@ -17,11 +22,12 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
length = 15
n_pts_per_ray = 10
for add_input_samples in [False, True]:
for add_input_samples, use_blurpool in product([False, True], [False, True]):
ray_point_refiner = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=False,
add_input_samples=add_input_samples,
blurpool_weights=use_blurpool,
)
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
bundle = ImplicitronRayBundle(
@@ -50,6 +56,7 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
n_pts_per_ray=n_pts_per_ray,
random_sampling=True,
add_input_samples=add_input_samples,
blurpool_weights=use_blurpool,
)
refined_random = ray_point_refiner_random(bundle, weights)
lengths_random = refined_random.lengths
@@ -62,3 +69,24 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
self.assertTrue(
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
)
def test_apply_blurpool_on_weights(self):
weights = torch.tensor(
[
[0.5, 0.6, 0.7],
[0.5, 0.3, 0.9],
]
)
expected_weights = 0.5 * torch.tensor(
[
[0.5 + 0.6, 0.6 + 0.7, 0.7 + 0.7],
[0.5 + 0.5, 0.5 + 0.9, 0.9 + 0.9],
]
)
out_weights = apply_blurpool_on_weights(weights)
self.assertTrue(torch.allclose(out_weights, expected_weights))
def test_shapes_apply_blurpool_on_weights(self):
weights = torch.randn((5, 4, 3, 2, 1))
out_weights = apply_blurpool_on_weights(weights)
self.assertEqual((5, 4, 3, 2, 1), out_weights.shape)