mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
volume cropping
Summary: TensoRF at step 2000 does volume croping and resizing. At those steps it calculates part of the voxel grid which has density big enough to have objects and resizes the grid to fit that object. Change is done on 3 levels: - implicit function subscribes to epochs and at specific epochs finds the bounding box of the object and calls resizing of the color and density voxel grids to fit it - VoxelGrid module calls cropping of the underlaying voxel grid and resizing to fit previous size it also adjusts its extends and translation to match wanted size - Each voxel grid has its own way of cropping the underlaying data Reviewed By: kjchalup Differential Revision: D39854548 fbshipit-source-id: 5435b6e599aef1eaab980f5421d3369ee4829c50
This commit is contained in:
committed by
Facebook GitHub Bot
parent
0b5def5257
commit
f55d37f07d
@@ -9,7 +9,7 @@ import unittest
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from pytorch3d.implicitron.models.implicit_function.utils import (
|
||||
interpolate_line,
|
||||
@@ -19,11 +19,12 @@ from pytorch3d.implicitron.models.implicit_function.utils import (
|
||||
from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
|
||||
CPFactorizedVoxelGrid,
|
||||
FullResolutionVoxelGrid,
|
||||
FullResolutionVoxelGridValues,
|
||||
VMFactorizedVoxelGrid,
|
||||
VoxelGridModule,
|
||||
)
|
||||
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
|
||||
@@ -693,6 +694,140 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
|
||||
grid.voxel_grid.get_resolution(epoch=epoch),
|
||||
)
|
||||
|
||||
def _get_min_max_tuple(
|
||||
self, n=4, denominator_base=2, max_exponent=6, add_edge_cases=True
|
||||
):
|
||||
if add_edge_cases:
|
||||
n -= 2
|
||||
|
||||
def get_pair():
|
||||
def get_one():
|
||||
sign = -1 if torch.rand((1,)) < 0.5 else 1
|
||||
exponent = int(torch.randint(1, max_exponent, (1,)))
|
||||
denominator = denominator_base**exponent
|
||||
numerator = int(torch.randint(1, denominator, (1,)))
|
||||
return sign * numerator / denominator * 1.0
|
||||
|
||||
while True:
|
||||
a, b = get_one(), get_one()
|
||||
if a < b:
|
||||
return a, b
|
||||
|
||||
for _ in range(n):
|
||||
a, b, c = get_pair(), get_pair(), get_pair()
|
||||
yield torch.tensor((a[0], b[0], c[0])), torch.tensor((a[1], b[1], c[1]))
|
||||
if add_edge_cases:
|
||||
yield torch.tensor((-1.0, -1.0, -1.0)), torch.tensor((1.0, 1.0, 1.0))
|
||||
yield torch.tensor([0.0, 0.0, 0.0]), torch.tensor([1.0, 1.0, 1.0])
|
||||
|
||||
def test_cropping_voxel_grids(self, n_times=1):
|
||||
"""
|
||||
If the grid is 1d and we crop at A and B
|
||||
---------A---------B---
|
||||
and choose point p between them
|
||||
---------A-----p---B---
|
||||
it can be represented as
|
||||
p = A + (B-A) * p_c
|
||||
where p_c is local coordinate of p in cropped grid. So we now just see
|
||||
if grid evaluated at p and cropped grid evaluated at p_c agree.
|
||||
"""
|
||||
for points_min, points_max in self._get_min_max_tuple(n=10):
|
||||
n_grids, n_features, n_components = torch.randint(1, 3, (3,)).tolist()
|
||||
n_grids = 1
|
||||
n_components *= 3
|
||||
resolution_changes = {0: (128 + 1, 128 + 1, 128 + 1)}
|
||||
for cls, kwargs in (
|
||||
(
|
||||
FullResolutionVoxelGrid,
|
||||
{
|
||||
"n_features": n_features,
|
||||
"resolution_changes": resolution_changes,
|
||||
},
|
||||
),
|
||||
(
|
||||
CPFactorizedVoxelGrid,
|
||||
{
|
||||
"n_features": n_features,
|
||||
"resolution_changes": resolution_changes,
|
||||
"n_components": n_components,
|
||||
},
|
||||
),
|
||||
(
|
||||
VMFactorizedVoxelGrid,
|
||||
{
|
||||
"n_features": n_features,
|
||||
"resolution_changes": resolution_changes,
|
||||
"n_components": n_components,
|
||||
},
|
||||
),
|
||||
):
|
||||
with self.subTest(
|
||||
cls.__name__ + f" points {points_min} and {points_max}"
|
||||
):
|
||||
grid = cls(**kwargs)
|
||||
shapes = grid.get_shapes(epoch=0)
|
||||
params = {
|
||||
name: torch.normal(
|
||||
mean=torch.zeros((n_grids, *shape)),
|
||||
std=1,
|
||||
)
|
||||
for name, shape in shapes.items()
|
||||
}
|
||||
grid_values = grid.values_type(**params)
|
||||
|
||||
grid_values_cropped = grid.crop_local(
|
||||
points_min, points_max, grid_values
|
||||
)
|
||||
|
||||
points_local_cropped = torch.rand((1, n_times, 3))
|
||||
points_local = (
|
||||
points_min[None, None]
|
||||
+ (points_max - points_min)[None, None] * points_local_cropped
|
||||
)
|
||||
points_local_cropped = (points_local_cropped - 0.5) * 2
|
||||
|
||||
pred = grid.evaluate_local(points_local, grid_values)
|
||||
pred_cropped = grid.evaluate_local(
|
||||
points_local_cropped, grid_values_cropped
|
||||
)
|
||||
|
||||
assert torch.allclose(pred, pred_cropped, rtol=1e-4, atol=1e-4), (
|
||||
pred,
|
||||
pred_cropped,
|
||||
points_local,
|
||||
points_local_cropped,
|
||||
)
|
||||
|
||||
def test_cropping_voxel_grid_module(self, n_times=1):
|
||||
for points_min, points_max in self._get_min_max_tuple(n=5, max_exponent=5):
|
||||
extents = torch.ones((3,)) * 2
|
||||
translation = torch.ones((3,)) * 0.2
|
||||
points_min += translation
|
||||
points_max += translation
|
||||
|
||||
default_cfg = get_default_args(VoxelGridModule)
|
||||
custom_cfg = DictConfig(
|
||||
{
|
||||
"extents": tuple(float(e) for e in extents),
|
||||
"translation": tuple(float(t) for t in translation),
|
||||
"voxel_grid_FullResolutionVoxelGrid_args": {
|
||||
"resolution_changes": {0: (128 + 1, 128 + 1, 128 + 1)}
|
||||
},
|
||||
}
|
||||
)
|
||||
cfg = OmegaConf.merge(default_cfg, custom_cfg)
|
||||
grid = VoxelGridModule(**cfg)
|
||||
|
||||
points = (torch.rand(3) * (points_max - points_min) + points_min)[None]
|
||||
result = grid(points)
|
||||
grid.crop_self(points_min, points_max)
|
||||
result_cropped = grid(points)
|
||||
|
||||
assert torch.allclose(result, result_cropped, rtol=0.001, atol=0.001), (
|
||||
result,
|
||||
result_cropped,
|
||||
)
|
||||
|
||||
def test_loading_state_dict(self):
|
||||
"""
|
||||
Test loading state dict after rescaling.
|
||||
|
||||
Reference in New Issue
Block a user