volume cropping

Summary:
TensoRF at step 2000 does volume cropping and resizing.
At those steps it calculates the part of the voxel grid whose density is high enough to contain objects and resizes the grid to fit that object.
Change is done on 3 levels:
- The implicit function subscribes to epochs and, at specific epochs, finds the bounding box of the object and calls resizing of the color and density voxel grids to fit it
- The VoxelGrid module calls cropping of the underlying voxel grid and resizes it to fit its previous size; it also adjusts its extents and translation to match the wanted size
- Each voxel grid has its own way of cropping the underlying data

Reviewed By: kjchalup

Differential Revision: D39854548

fbshipit-source-id: 5435b6e599aef1eaab980f5421d3369ee4829c50
This commit is contained in:
Darijan Gudelj
2022-10-12 08:31:51 -07:00
committed by Facebook GitHub Bot
parent 0b5def5257
commit f55d37f07d
2 changed files with 377 additions and 11 deletions

View File

@@ -9,7 +9,7 @@ import unittest
from typing import Optional, Tuple
import torch
from omegaconf import DictConfig
from omegaconf import DictConfig, OmegaConf
from pytorch3d.implicitron.models.implicit_function.utils import (
interpolate_line,
@@ -19,11 +19,12 @@ from pytorch3d.implicitron.models.implicit_function.utils import (
from pytorch3d.implicitron.models.implicit_function.voxel_grid import (
CPFactorizedVoxelGrid,
FullResolutionVoxelGrid,
FullResolutionVoxelGridValues,
VMFactorizedVoxelGrid,
VoxelGridModule,
)
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
from tests.common_testing import TestCaseMixin
@@ -693,6 +694,140 @@ class TestVoxelGrids(TestCaseMixin, unittest.TestCase):
grid.voxel_grid.get_resolution(epoch=epoch),
)
def _get_min_max_tuple(
    self, n=4, denominator_base=2, max_exponent=6, add_edge_cases=True
):
    """
    Generate (min, max) pairs of 3-element tensors to use as crop bounds.

    Every random coordinate is a signed fraction `numerator / denominator`
    with `denominator = denominator_base**exponent`, `exponent` drawn from
    [1, max_exponent) and `numerator` drawn from [1, denominator). Within a
    pair, each component of the first tensor is strictly smaller than the
    matching component of the second.

    Args:
        n: total number of pairs to yield; when `add_edge_cases` is set,
            the two fixed pairs count towards this total.
        denominator_base: base of the power used as denominator.
        max_exponent: exclusive upper bound of the sampled exponent.
        add_edge_cases: if True, finish with the fixed pairs
            (-1, -1, -1)..(1, 1, 1) and (0, 0, 0)..(1, 1, 1).

    Yields:
        Tuples `(points_min, points_max)` of tensors of shape (3,).
    """
    # The two fixed edge cases are part of the requested total.
    n_random = n - 2 if add_edge_cases else n

    def random_fraction():
        # Draw order (sign, exponent, numerator) matters for reproducibility
        # under a fixed seed.
        if torch.rand((1,)) < 0.5:
            sign = -1
        else:
            sign = 1
        exponent = int(torch.randint(1, max_exponent, (1,)))
        denominator = denominator_base**exponent
        numerator = int(torch.randint(1, denominator, (1,)))
        return sign * numerator / denominator * 1.0

    def ordered_pair():
        # Rejection sampling: redraw until the pair is strictly increasing.
        low, high = random_fraction(), random_fraction()
        while not low < high:
            low, high = random_fraction(), random_fraction()
        return low, high

    for _ in range(n_random):
        per_axis = [ordered_pair() for _ in range(3)]
        lows, highs = zip(*per_axis)
        yield torch.tensor(lows), torch.tensor(highs)

    if add_edge_cases:
        yield torch.tensor((-1.0, -1.0, -1.0)), torch.tensor((1.0, 1.0, 1.0))
        yield torch.tensor([0.0, 0.0, 0.0]), torch.tensor([1.0, 1.0, 1.0])
def test_cropping_voxel_grids(self, n_times=1):
    """
    Check that cropping a voxel grid preserves the values it represents.

    Picture a 1d grid cropped at A and B
        ---------A---------B---
    with a point p somewhere between them
        ---------A-----p---B---
    Then p can be written as
        p = A + (B-A) * p_c
    where p_c is the coordinate of p local to the cropped grid. The test
    evaluates the original grid at p and the cropped grid at p_c and
    requires both results to agree.

    Args:
        n_times: number of random points evaluated per grid.
    """
    for points_min, points_max in self._get_min_max_tuple(n=10):
        # Three values are drawn, but the number of grids is pinned to 1.
        _, n_features, n_components = torch.randint(1, 3, (3,)).tolist()
        n_grids = 1
        n_components *= 3
        resolution_changes = {0: (128 + 1, 128 + 1, 128 + 1)}

        full_kwargs = {
            "n_features": n_features,
            "resolution_changes": resolution_changes,
        }
        factorized_kwargs = {**full_kwargs, "n_components": n_components}
        grid_setups = (
            (FullResolutionVoxelGrid, full_kwargs),
            (CPFactorizedVoxelGrid, factorized_kwargs),
            (VMFactorizedVoxelGrid, factorized_kwargs),
        )

        for cls, kwargs in grid_setups:
            with self.subTest(
                f"{cls.__name__} points {points_min} and {points_max}"
            ):
                grid = cls(**kwargs)

                # Fill every parameter of the grid with standard normal noise.
                params = {}
                for name, shape in grid.get_shapes(epoch=0).items():
                    params[name] = torch.normal(
                        mean=torch.zeros((n_grids, *shape)),
                        std=1,
                    )
                grid_values = grid.values_type(**params)
                grid_values_cropped = grid.crop_local(
                    points_min, points_max, grid_values
                )

                # p_c sampled uniformly in [0, 1) along each axis.
                unit_coords = torch.rand((1, n_times, 3))
                crop_size = (points_max - points_min)[None, None]
                points_local = points_min[None, None] + crop_size * unit_coords
                # Rescale [0, 1) to [-1, 1) before querying the cropped grid
                # (presumably the local coordinate range of a grid - confirm).
                points_local_cropped = (unit_coords - 0.5) * 2

                pred = grid.evaluate_local(points_local, grid_values)
                pred_cropped = grid.evaluate_local(
                    points_local_cropped, grid_values_cropped
                )

                assert torch.allclose(pred, pred_cropped, rtol=1e-4, atol=1e-4), (
                    pred,
                    pred_cropped,
                    points_local,
                    points_local_cropped,
                )
def test_cropping_voxel_grid_module(self, n_times=1):
    """
    Check that `VoxelGridModule.crop_self` keeps the module's output intact.

    For several crop boxes, a module with extents 2 and translation 0.2 is
    built, queried at a random point inside the box, cropped to the box and
    queried again at the same point; both results must be close.

    Args:
        n_times: not used by this test.
    """
    for points_min, points_max in self._get_min_max_tuple(n=5, max_exponent=5):
        extents = torch.ones((3,)) * 2
        translation = torch.ones((3,)) * 0.2
        # Shift the crop box by the same translation the module is given.
        points_min += translation
        points_max += translation

        overrides = {
            "extents": tuple(extents.tolist()),
            "translation": tuple(translation.tolist()),
            "voxel_grid_FullResolutionVoxelGrid_args": {
                "resolution_changes": {0: (128 + 1, 128 + 1, 128 + 1)}
            },
        }
        cfg = OmegaConf.merge(
            get_default_args(VoxelGridModule), DictConfig(overrides)
        )
        grid = VoxelGridModule(**cfg)

        # One random point inside the crop box, with a leading axis added.
        box_size = points_max - points_min
        points = (torch.rand(3) * box_size + points_min)[None]

        result = grid(points)
        grid.crop_self(points_min, points_max)
        result_cropped = grid(points)

        assert torch.allclose(result, result_cropped, rtol=0.001, atol=0.001), (
            result,
            result_cropped,
        )
def test_loading_state_dict(self):
"""
Test loading state dict after rescaling.