mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 14:20:38 +08:00
Add blurpool following MIPNerf paper.
Summary: Add blurpool has defined in [MIP-NeRF](https://arxiv.org/abs/2103.13415). It has been added has an option for RayPointRefiner. Reviewed By: shapovalov Differential Revision: D46356189 fbshipit-source-id: ad841bad86d2b591a68e1cb885d4f781cf26c111
This commit is contained in:
committed by
Facebook GitHub Bot
parent
ccf860f1db
commit
5910d81b7b
@@ -5,9 +5,14 @@
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
from itertools import product
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
|
||||
|
||||
from pytorch3d.implicitron.models.renderer.ray_point_refiner import (
|
||||
apply_blurpool_on_weights,
|
||||
RayPointRefiner,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.ray_sampler import ImplicitronRayBundle
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
@@ -17,11 +22,12 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
|
||||
length = 15
|
||||
n_pts_per_ray = 10
|
||||
|
||||
for add_input_samples in [False, True]:
|
||||
for add_input_samples, use_blurpool in product([False, True], [False, True]):
|
||||
ray_point_refiner = RayPointRefiner(
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
random_sampling=False,
|
||||
add_input_samples=add_input_samples,
|
||||
blurpool_weights=use_blurpool,
|
||||
)
|
||||
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
|
||||
bundle = ImplicitronRayBundle(
|
||||
@@ -50,6 +56,7 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
random_sampling=True,
|
||||
add_input_samples=add_input_samples,
|
||||
blurpool_weights=use_blurpool,
|
||||
)
|
||||
refined_random = ray_point_refiner_random(bundle, weights)
|
||||
lengths_random = refined_random.lengths
|
||||
@@ -62,3 +69,24 @@ class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
|
||||
self.assertTrue(
|
||||
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
|
||||
)
|
||||
|
||||
def test_apply_blurpool_on_weights(self):
|
||||
weights = torch.tensor(
|
||||
[
|
||||
[0.5, 0.6, 0.7],
|
||||
[0.5, 0.3, 0.9],
|
||||
]
|
||||
)
|
||||
expected_weights = 0.5 * torch.tensor(
|
||||
[
|
||||
[0.5 + 0.6, 0.6 + 0.7, 0.7 + 0.7],
|
||||
[0.5 + 0.5, 0.5 + 0.9, 0.9 + 0.9],
|
||||
]
|
||||
)
|
||||
out_weights = apply_blurpool_on_weights(weights)
|
||||
self.assertTrue(torch.allclose(out_weights, expected_weights))
|
||||
|
||||
def test_shapes_apply_blurpool_on_weights(self):
|
||||
weights = torch.randn((5, 4, 3, 2, 1))
|
||||
out_weights = apply_blurpool_on_weights(weights)
|
||||
self.assertEqual((5, 4, 3, 2, 1), out_weights.shape)
|
||||
|
||||
Reference in New Issue
Block a user