mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-14 11:26:24 +08:00
implicitron v0 (#1133)
Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
0e377c6850
commit
cdd2142dd5
5
tests/implicitron/__init__.py
Normal file
5
tests/implicitron/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
114
tests/implicitron/common_resources.py
Normal file
114
tests/implicitron/common_resources.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Generator, Tuple
|
||||
from zipfile import ZipFile
|
||||
|
||||
from iopath.common.file_io import PathManager
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def get_skateboard_data(
|
||||
avoid_manifold: bool = False, silence_logs: bool = False
|
||||
) -> Generator[Tuple[str, PathManager], None, None]:
|
||||
"""
|
||||
Context manager for accessing Co3D dataset by tests, at least for
|
||||
the first 5 skateboards. Internally, we want this to exercise the
|
||||
normal way to access the data directly manifold, but on an RE
|
||||
worker this is impossible so we use a workaround.
|
||||
|
||||
Args:
|
||||
avoid_manifold: Use the method used by RE workers even locally.
|
||||
silence_logs: Whether to reduce log output from iopath library.
|
||||
|
||||
Yields:
|
||||
dataset_root: (str) path to dataset root.
|
||||
path_manager: path_manager to access it with.
|
||||
"""
|
||||
path_manager = PathManager()
|
||||
if silence_logs:
|
||||
logging.getLogger("iopath.fb.manifold").setLevel(logging.CRITICAL)
|
||||
logging.getLogger("iopath.common.file_io").setLevel(logging.CRITICAL)
|
||||
|
||||
if not os.environ.get("FB_TEST", False):
|
||||
if os.getenv("FAIR_ENV_CLUSTER", "") == "":
|
||||
raise unittest.SkipTest("Unknown environment. Data not available.")
|
||||
yield "/checkpoint/dnovotny/datasets/co3d/download_aws_22_02_18", path_manager
|
||||
|
||||
elif avoid_manifold or os.environ.get("INSIDE_RE_WORKER", False):
|
||||
from libfb.py.parutil import get_file_path
|
||||
|
||||
par_path = "skateboard_first_5"
|
||||
source = get_file_path(par_path)
|
||||
assert Path(source).is_file()
|
||||
with tempfile.TemporaryDirectory() as dest:
|
||||
with ZipFile(source) as f:
|
||||
f.extractall(dest)
|
||||
yield os.path.join(dest, "extracted"), path_manager
|
||||
else:
|
||||
from iopath.fb.manifold import ManifoldPathHandler
|
||||
|
||||
path_manager.register_handler(ManifoldPathHandler())
|
||||
|
||||
yield "manifold://co3d/tree/extracted", path_manager
|
||||
|
||||
|
||||
def provide_lpips_vgg():
|
||||
"""
|
||||
Ensure the weights files are available for lpips.LPIPS(net="vgg")
|
||||
to be called. Specifically, torchvision's vgg16
|
||||
"""
|
||||
# In OSS, torchvision looks for vgg16 weights in
|
||||
# https://download.pytorch.org/models/vgg16-397923af.pth
|
||||
# Inside fbcode, this is replaced by asking iopath for
|
||||
# manifold://torchvision/tree/models/vgg16-397923af.pth
|
||||
# (the code for this replacement is in
|
||||
# fbcode/pytorch/vision/fb/_internally_replaced_utils.py )
|
||||
#
|
||||
# iopath does this by looking for the file at the cache location
|
||||
# and if it is not there getting it from manifold.
|
||||
# (the code for this is in
|
||||
# fbcode/fair_infra/data/iopath/iopath/fb/manifold.py )
|
||||
#
|
||||
# On the remote execution worker, manifold is inaccessible.
|
||||
# We solve this by making the cached file available before iopath
|
||||
# looks.
|
||||
#
|
||||
# By default the cache location is
|
||||
# ~/.torch/iopath_cache/manifold_cache/tree/models/vgg16-397923af.pth
|
||||
# But we can't write to the home directory on the RE worker.
|
||||
# We define FVCORE_CACHE to change the cache location to
|
||||
# iopath_cache/manifold_cache/tree/models/vgg16-397923af.pth
|
||||
# (Without it, manifold caches in unstable temporary locations on RE.)
|
||||
#
|
||||
# The file we want has been copied from
|
||||
# tree/models/vgg16-397923af.pth in the torchvision bucket
|
||||
# to
|
||||
# tree/testing/vgg16-397923af.pth in the co3d bucket
|
||||
# and the TARGETS file copies it somewhere in the PAR which we
|
||||
# recover with get_file_path.
|
||||
# (It can't copy straight to a nested location, see
|
||||
# https://fb.workplace.com/groups/askbuck/posts/2644615728920359/)
|
||||
# Here we symlink it to the new cache location.
|
||||
if os.environ.get("INSIDE_RE_WORKER") is not None:
|
||||
from libfb.py.parutil import get_file_path
|
||||
|
||||
os.environ["FVCORE_CACHE"] = "iopath_cache"
|
||||
|
||||
par_path = "vgg_weights_for_lpips"
|
||||
source = Path(get_file_path(par_path))
|
||||
assert source.is_file()
|
||||
|
||||
dest = Path("iopath_cache/manifold_cache/tree/models")
|
||||
if not dest.exists():
|
||||
dest.mkdir(parents=True)
|
||||
(dest / "vgg16-397923af.pth").symlink_to(source)
|
||||
122
tests/implicitron/data/overrides.yaml
Normal file
122
tests/implicitron/data/overrides.yaml
Normal file
@@ -0,0 +1,122 @@
|
||||
mask_images: true
|
||||
mask_depths: true
|
||||
render_image_width: 400
|
||||
render_image_height: 400
|
||||
mask_threshold: 0.5
|
||||
output_rasterized_mc: false
|
||||
bg_color:
|
||||
- 0.0
|
||||
- 0.0
|
||||
- 0.0
|
||||
view_pool: false
|
||||
num_passes: 1
|
||||
chunk_size_grid: 4096
|
||||
render_features_dimensions: 3
|
||||
tqdm_trigger_threshold: 16
|
||||
n_train_target_views: 1
|
||||
sampling_mode_training: mask_sample
|
||||
sampling_mode_evaluation: full_grid
|
||||
renderer_class_type: LSTMRenderer
|
||||
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
|
||||
implicit_function_class_type: IdrFeatureField
|
||||
loss_weights:
|
||||
loss_rgb_mse: 1.0
|
||||
loss_prev_stage_rgb_mse: 1.0
|
||||
loss_mask_bce: 0.0
|
||||
loss_prev_stage_mask_bce: 0.0
|
||||
log_vars:
|
||||
- loss_rgb_psnr_fg
|
||||
- loss_rgb_psnr
|
||||
- loss_rgb_mse
|
||||
- loss_rgb_huber
|
||||
- loss_depth_abs
|
||||
- loss_depth_abs_fg
|
||||
- loss_mask_neg_iou
|
||||
- loss_mask_bce
|
||||
- loss_mask_beta_prior
|
||||
- loss_eikonal
|
||||
- loss_density_tv
|
||||
- loss_depth_neg_penalty
|
||||
- loss_autodecoder_norm
|
||||
- loss_prev_stage_rgb_mse
|
||||
- loss_prev_stage_rgb_psnr_fg
|
||||
- loss_prev_stage_rgb_psnr
|
||||
- loss_prev_stage_mask_bce
|
||||
- objective
|
||||
- epoch
|
||||
- sec/it
|
||||
sequence_autodecoder_args:
|
||||
encoding_dim: 0
|
||||
n_instances: 0
|
||||
init_scale: 1.0
|
||||
ignore_input: false
|
||||
raysampler_args:
|
||||
image_width: 400
|
||||
image_height: 400
|
||||
scene_center:
|
||||
- 0.0
|
||||
- 0.0
|
||||
- 0.0
|
||||
scene_extent: 0.0
|
||||
sampling_mode_training: mask_sample
|
||||
sampling_mode_evaluation: full_grid
|
||||
n_pts_per_ray_training: 64
|
||||
n_pts_per_ray_evaluation: 64
|
||||
n_rays_per_image_sampled_from_mask: 1024
|
||||
min_depth: 0.1
|
||||
max_depth: 8.0
|
||||
stratified_point_sampling_training: true
|
||||
stratified_point_sampling_evaluation: false
|
||||
renderer_LSTMRenderer_args:
|
||||
num_raymarch_steps: 10
|
||||
init_depth: 17.0
|
||||
init_depth_noise_std: 0.0005
|
||||
hidden_size: 16
|
||||
n_feature_channels: 256
|
||||
verbose: false
|
||||
image_feature_extractor_args:
|
||||
name: resnet34
|
||||
pretrained: true
|
||||
stages:
|
||||
- 1
|
||||
- 2
|
||||
- 3
|
||||
- 4
|
||||
normalize_image: true
|
||||
image_rescale: 0.16
|
||||
first_max_pool: true
|
||||
proj_dim: 32
|
||||
l2_norm: true
|
||||
add_masks: true
|
||||
add_images: true
|
||||
global_average_pool: false
|
||||
feature_rescale: 1.0
|
||||
view_sampler_args:
|
||||
masked_sampling: false
|
||||
sampling_mode: bilinear
|
||||
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
|
||||
exclude_target_view: true
|
||||
exclude_target_view_mask_features: true
|
||||
concatenate_output: true
|
||||
weight_by_ray_angle_gamma: 1.0
|
||||
min_ray_angle_weight: 0.1
|
||||
implicit_function_IdrFeatureField_args:
|
||||
feature_vector_size: 3
|
||||
d_in: 3
|
||||
d_out: 1
|
||||
dims:
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
- 512
|
||||
geometric_init: true
|
||||
bias: 1.0
|
||||
skip_in: []
|
||||
weight_norm: true
|
||||
n_harmonic_functions_xyz: 0
|
||||
pooled_feature_dim: 0
|
||||
encoding_dim: 0
|
||||
215
tests/implicitron/test_batch_sampler.py
Normal file
215
tests/implicitron/test_batch_sampler.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import unittest
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockFrameAnnotation:
|
||||
frame_number: int
|
||||
frame_timestamp: float = 0.0
|
||||
|
||||
|
||||
class MockDataset:
|
||||
def __init__(self, num_seq, max_frame_gap=1):
|
||||
"""
|
||||
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
|
||||
"""
|
||||
self.seq_annots = {f"seq_{i}": None for i in range(num_seq)}
|
||||
self.seq_to_idx = {
|
||||
f"seq_{i}": list(range(i * 10, i * 10 + 10)) for i in range(num_seq)
|
||||
}
|
||||
|
||||
# frame numbers within sequence: [0, ..., 4, n, ..., n+4]
|
||||
# where n - 4 == max_frame_gap
|
||||
frame_nos = list(range(5)) + list(range(4 + max_frame_gap, 9 + max_frame_gap))
|
||||
self.frame_annots = [
|
||||
{"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
|
||||
]
|
||||
|
||||
def get_frame_numbers_and_timestamps(self, idxs):
|
||||
out = []
|
||||
for idx in idxs:
|
||||
frame_annotation = self.frame_annots[idx]["frame_annotation"]
|
||||
out.append(
|
||||
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
class TestSceneBatchSampler(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.dataset_overfit = MockDataset(1)
|
||||
|
||||
def test_overfit(self):
|
||||
num_batches = 3
|
||||
batch_size = 10
|
||||
sampler = SceneBatchSampler(
|
||||
self.dataset_overfit,
|
||||
batch_size=batch_size,
|
||||
num_batches=num_batches,
|
||||
images_per_seq_options=[10], # will try to sample batch_size anyway
|
||||
)
|
||||
|
||||
self.assertEqual(len(sampler), num_batches)
|
||||
|
||||
it = iter(sampler)
|
||||
for _ in range(num_batches):
|
||||
batch = next(it)
|
||||
self.assertIsNotNone(batch)
|
||||
self.assertEqual(len(batch), batch_size) # true for our examples
|
||||
self.assertTrue(all(idx // 10 == 0 for idx in batch))
|
||||
|
||||
with self.assertRaises(StopIteration):
|
||||
batch = next(it)
|
||||
|
||||
def test_multiseq(self):
|
||||
for ips_options in [[10], [2], [3], [2, 3, 4]]:
|
||||
for sample_consecutive_frames in [True, False]:
|
||||
for consecutive_frames_max_gap in [0, 1, 3]:
|
||||
self._test_multiseq_flavour(
|
||||
ips_options,
|
||||
sample_consecutive_frames,
|
||||
consecutive_frames_max_gap,
|
||||
)
|
||||
|
||||
def test_multiseq_gaps(self):
|
||||
num_batches = 16
|
||||
batch_size = 10
|
||||
dataset_multiseq = MockDataset(5, max_frame_gap=3)
|
||||
for ips_options in [[10], [2], [3], [2, 3, 4]]:
|
||||
debug_info = f" Images per sequence: {ips_options}."
|
||||
|
||||
sampler = SceneBatchSampler(
|
||||
dataset_multiseq,
|
||||
batch_size=batch_size,
|
||||
num_batches=num_batches,
|
||||
images_per_seq_options=ips_options,
|
||||
sample_consecutive_frames=True,
|
||||
consecutive_frames_max_gap=1,
|
||||
)
|
||||
|
||||
self.assertEqual(len(sampler), num_batches, msg=debug_info)
|
||||
|
||||
it = iter(sampler)
|
||||
for _ in range(num_batches):
|
||||
batch = next(it)
|
||||
self.assertIsNotNone(batch, "batch is None in" + debug_info)
|
||||
if max(ips_options) > 5:
|
||||
# true for our examples
|
||||
self.assertEqual(len(batch), 5, msg=debug_info)
|
||||
else:
|
||||
# true for our examples
|
||||
self.assertEqual(len(batch), batch_size, msg=debug_info)
|
||||
|
||||
self._check_frames_are_consecutive(
|
||||
batch, dataset_multiseq.frame_annots, debug_info
|
||||
)
|
||||
|
||||
def _test_multiseq_flavour(
|
||||
self,
|
||||
ips_options,
|
||||
sample_consecutive_frames,
|
||||
consecutive_frames_max_gap,
|
||||
num_batches=16,
|
||||
batch_size=10,
|
||||
):
|
||||
debug_info = (
|
||||
f" Images per sequence: {ips_options}, "
|
||||
f"sample_consecutive_frames: {sample_consecutive_frames}, "
|
||||
f"consecutive_frames_max_gap: {consecutive_frames_max_gap}, "
|
||||
)
|
||||
# in this test, either consecutive_frames_max_gap == max_frame_gap,
|
||||
# or consecutive_frames_max_gap == 0, so segments consist of full sequences
|
||||
frame_gap = consecutive_frames_max_gap if consecutive_frames_max_gap > 0 else 3
|
||||
dataset_multiseq = MockDataset(5, max_frame_gap=frame_gap)
|
||||
sampler = SceneBatchSampler(
|
||||
dataset_multiseq,
|
||||
batch_size=batch_size,
|
||||
num_batches=num_batches,
|
||||
images_per_seq_options=ips_options,
|
||||
sample_consecutive_frames=sample_consecutive_frames,
|
||||
consecutive_frames_max_gap=consecutive_frames_max_gap,
|
||||
)
|
||||
|
||||
self.assertEqual(len(sampler), num_batches, msg=debug_info)
|
||||
|
||||
it = iter(sampler)
|
||||
typical_counts = set()
|
||||
for _ in range(num_batches):
|
||||
batch = next(it)
|
||||
self.assertIsNotNone(batch, "batch is None in" + debug_info)
|
||||
# true for our examples
|
||||
self.assertEqual(len(batch), batch_size, msg=debug_info)
|
||||
# find distribution over sequences
|
||||
counts = _count_by_quotient(batch, 10)
|
||||
freqs = _count_by_quotient(counts.values(), 1)
|
||||
self.assertLessEqual(
|
||||
len(freqs),
|
||||
2,
|
||||
msg="We should have maximum of 2 different "
|
||||
"frequences of sequences in the batch." + debug_info,
|
||||
)
|
||||
if len(freqs) == 2:
|
||||
most_seq_count = max(*freqs.keys())
|
||||
last_seq = min(*freqs.keys())
|
||||
self.assertEqual(
|
||||
freqs[last_seq],
|
||||
1,
|
||||
msg="Only one odd sequence allowed." + debug_info,
|
||||
)
|
||||
else:
|
||||
self.assertEqual(len(freqs), 1)
|
||||
most_seq_count = next(iter(freqs))
|
||||
|
||||
self.assertIn(most_seq_count, ips_options)
|
||||
typical_counts.add(most_seq_count)
|
||||
|
||||
if sample_consecutive_frames:
|
||||
self._check_frames_are_consecutive(
|
||||
batch,
|
||||
dataset_multiseq.frame_annots,
|
||||
debug_info,
|
||||
max_gap=consecutive_frames_max_gap,
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
all(i in typical_counts for i in ips_options),
|
||||
"Some of the frequency options did not occur among "
|
||||
f"the {num_batches} batches (could be just bad luck)." + debug_info,
|
||||
)
|
||||
|
||||
with self.assertRaises(StopIteration):
|
||||
batch = next(it)
|
||||
|
||||
def _check_frames_are_consecutive(self, batch, annots, debug_info, max_gap=1):
|
||||
# make sure that sampled frames are consecutive
|
||||
for i in range(len(batch) - 1):
|
||||
curr_idx, next_idx = batch[i : i + 2]
|
||||
if curr_idx // 10 == next_idx // 10: # same sequence
|
||||
if max_gap > 0:
|
||||
curr_idx, next_idx = [
|
||||
annots[idx]["frame_annotation"].frame_number
|
||||
for idx in (curr_idx, next_idx)
|
||||
]
|
||||
gap = max_gap
|
||||
else:
|
||||
gap = 1 # we'll check that raw dataset indices are consecutive
|
||||
|
||||
self.assertLessEqual(next_idx - curr_idx, gap, msg=debug_info)
|
||||
|
||||
|
||||
def _count_by_quotient(indices, divisor):
|
||||
counter = defaultdict(int)
|
||||
for i in indices:
|
||||
counter[i // divisor] += 1
|
||||
|
||||
return counter
|
||||
177
tests/implicitron/test_circle_fitting.py
Normal file
177
tests/implicitron/test_circle_fitting.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from math import pi
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.circle_fitting import (
|
||||
_signed_area,
|
||||
fit_circle_in_2d,
|
||||
fit_circle_in_3d,
|
||||
)
|
||||
from pytorch3d.transforms import random_rotation
|
||||
|
||||
|
||||
if os.environ.get("FB_TEST", False):
|
||||
from common_testing import TestCaseMixin
|
||||
else:
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestCircleFitting(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self):
|
||||
torch.manual_seed(42)
|
||||
|
||||
def _assertParallel(self, a, b, **kwargs):
|
||||
"""
|
||||
Given a and b of shape (..., 3) each containing 3D vectors,
|
||||
assert that correspnding vectors are parallel. Changed sign is ok.
|
||||
"""
|
||||
self.assertClose(torch.cross(a, b, dim=-1), torch.zeros_like(a), **kwargs)
|
||||
|
||||
def test_simple_3d(self):
|
||||
device = torch.device("cuda:0")
|
||||
for _ in range(7):
|
||||
radius = 10 * torch.rand(1, device=device)[0]
|
||||
center = 10 * torch.rand(3, device=device)
|
||||
rot = random_rotation(device=device)
|
||||
offset = torch.rand(3, device=device)
|
||||
up = torch.rand(3, device=device)
|
||||
self._simple_3d_test(radius, center, rot, offset, up)
|
||||
|
||||
def _simple_3d_test(self, radius, center, rot, offset, up):
|
||||
# angles are increasing so the points move in a well defined direction.
|
||||
angles = torch.cumsum(torch.rand(17, device=rot.device), dim=0)
|
||||
many = torch.stack(
|
||||
[torch.cos(angles), torch.sin(angles), torch.zeros_like(angles)], dim=1
|
||||
)
|
||||
source_points = (many * radius) @ rot + center[None]
|
||||
|
||||
# case with no generation
|
||||
result = fit_circle_in_3d(source_points)
|
||||
self.assertClose(result.radius, radius)
|
||||
self.assertClose(result.center, center)
|
||||
self._assertParallel(result.normal, rot[2], atol=1e-5)
|
||||
self.assertEqual(result.generated_points.shape, (0, 3))
|
||||
|
||||
# Generate 5 points around the circle
|
||||
n_new_points = 5
|
||||
result2 = fit_circle_in_3d(source_points, n_points=n_new_points)
|
||||
self.assertClose(result2.radius, radius)
|
||||
self.assertClose(result2.center, center)
|
||||
self.assertClose(result2.normal, result.normal)
|
||||
self.assertEqual(result2.generated_points.shape, (5, 3))
|
||||
|
||||
observed_points = result2.generated_points
|
||||
self.assertClose(observed_points[0], observed_points[4], atol=1e-4)
|
||||
self.assertClose(observed_points[0], source_points[0], atol=1e-5)
|
||||
observed_normal = torch.cross(
|
||||
observed_points[0] - observed_points[2],
|
||||
observed_points[1] - observed_points[3],
|
||||
dim=-1,
|
||||
)
|
||||
self._assertParallel(observed_normal, result.normal, atol=1e-4)
|
||||
diameters = observed_points[:2] - observed_points[2:4]
|
||||
self.assertClose(
|
||||
torch.norm(diameters, dim=1), diameters.new_full((2,), 2 * radius)
|
||||
)
|
||||
|
||||
# Regenerate the input points
|
||||
result3 = fit_circle_in_3d(source_points, angles=angles - angles[0])
|
||||
self.assertClose(result3.radius, radius)
|
||||
self.assertClose(result3.center, center)
|
||||
self.assertClose(result3.normal, result.normal)
|
||||
self.assertClose(result3.generated_points, source_points, atol=1e-5)
|
||||
|
||||
# Test with offset
|
||||
result4 = fit_circle_in_3d(
|
||||
source_points, angles=angles - angles[0], offset=offset, up=up
|
||||
)
|
||||
self.assertClose(result4.radius, radius)
|
||||
self.assertClose(result4.center, center)
|
||||
self.assertClose(result4.normal, result.normal)
|
||||
observed_offsets = result4.generated_points - source_points
|
||||
|
||||
# observed_offset is constant
|
||||
self.assertClose(
|
||||
observed_offsets.min(0).values, observed_offsets.max(0).values, atol=1e-5
|
||||
)
|
||||
# observed_offset has the right length
|
||||
self.assertClose(observed_offsets[0].norm(), offset.norm())
|
||||
|
||||
self.assertClose(result.normal.norm(), torch.ones(()))
|
||||
# component of observed_offset along normal
|
||||
component = torch.dot(observed_offsets[0], result.normal)
|
||||
self.assertClose(component.abs(), offset[2].abs(), atol=1e-5)
|
||||
agree_normal = torch.dot(result.normal, up) > 0
|
||||
agree_signs = component * offset[2] > 0
|
||||
self.assertEqual(agree_normal, agree_signs)
|
||||
|
||||
def test_simple_2d(self):
|
||||
radius = 7.0
|
||||
center = torch.tensor([9, 2.5])
|
||||
angles = torch.cumsum(torch.rand(17), dim=0)
|
||||
many = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
|
||||
source_points = (many * radius) + center[None]
|
||||
|
||||
result = fit_circle_in_2d(source_points)
|
||||
self.assertClose(result.radius, torch.tensor(radius))
|
||||
self.assertClose(result.center, center)
|
||||
self.assertEqual(result.generated_points.shape, (0, 2))
|
||||
|
||||
# Generate 5 points around the circle
|
||||
n_new_points = 5
|
||||
result2 = fit_circle_in_2d(source_points, n_points=n_new_points)
|
||||
self.assertClose(result2.radius, torch.tensor(radius))
|
||||
self.assertClose(result2.center, center)
|
||||
self.assertEqual(result2.generated_points.shape, (5, 2))
|
||||
|
||||
observed_points = result2.generated_points
|
||||
self.assertClose(observed_points[0], observed_points[4])
|
||||
self.assertClose(observed_points[0], source_points[0], atol=1e-5)
|
||||
diameters = observed_points[:2] - observed_points[2:4]
|
||||
self.assertClose(torch.norm(diameters, dim=1), torch.full((2,), 2 * radius))
|
||||
|
||||
# Regenerate the input points
|
||||
result3 = fit_circle_in_2d(source_points, angles=angles - angles[0])
|
||||
self.assertClose(result3.radius, torch.tensor(radius))
|
||||
self.assertClose(result3.center, center)
|
||||
self.assertClose(result3.generated_points, source_points, atol=1e-5)
|
||||
|
||||
def test_minimum_inputs(self):
|
||||
fit_circle_in_3d(torch.rand(3, 3), n_points=10)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "2 points are not enough to determine a circle"
|
||||
):
|
||||
fit_circle_in_3d(torch.rand(2, 3))
|
||||
|
||||
def test_signed_area(self):
|
||||
n_points = 1001
|
||||
angles = torch.linspace(0, 2 * pi, n_points)
|
||||
radius = 0.85
|
||||
center = torch.rand(2)
|
||||
circle = center + radius * torch.stack(
|
||||
[torch.cos(angles), torch.sin(angles)], dim=1
|
||||
)
|
||||
circle_area = torch.tensor(pi * radius * radius)
|
||||
self.assertClose(_signed_area(circle), circle_area)
|
||||
# clockwise is negative
|
||||
self.assertClose(_signed_area(circle.flip(0)), -circle_area)
|
||||
|
||||
# Semicircles
|
||||
self.assertClose(_signed_area(circle[: (n_points + 1) // 2]), circle_area / 2)
|
||||
self.assertClose(_signed_area(circle[n_points // 2 :]), circle_area / 2)
|
||||
|
||||
# A straight line bounds no area
|
||||
self.assertClose(_signed_area(torch.rand(2, 2)), torch.tensor(0.0))
|
||||
|
||||
# Letter 'L' written anticlockwise.
|
||||
L_shape = [[0, 1], [0, 0], [1, 0]]
|
||||
# Triangle area is 0.5 * b * h.
|
||||
self.assertClose(_signed_area(torch.tensor(L_shape)), torch.tensor(0.5))
|
||||
610
tests/implicitron/test_config.py
Normal file
610
tests/implicitron/test_config.py
Normal file
@@ -0,0 +1,610 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import textwrap
|
||||
import unittest
|
||||
from dataclasses import dataclass, field, is_dataclass
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
Configurable,
|
||||
ReplaceableBase,
|
||||
_is_actually_dataclass,
|
||||
_Registry,
|
||||
expand_args_fields,
|
||||
get_default_args,
|
||||
get_default_args_field,
|
||||
registry,
|
||||
remove_unused_components,
|
||||
run_auto_creation,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Animal(ReplaceableBase):
|
||||
pass
|
||||
|
||||
|
||||
class Fruit(ReplaceableBase):
|
||||
pass
|
||||
|
||||
|
||||
@registry.register
|
||||
class Banana(Fruit):
|
||||
pips: int
|
||||
spots: int
|
||||
bananame: str
|
||||
|
||||
|
||||
@registry.register
|
||||
class Pear(Fruit):
|
||||
n_pips: int = 13
|
||||
|
||||
|
||||
class Pineapple(Fruit):
|
||||
pass
|
||||
|
||||
|
||||
@registry.register
|
||||
class Orange(Fruit):
|
||||
pass
|
||||
|
||||
|
||||
@registry.register
|
||||
class Kiwi(Fruit):
|
||||
pass
|
||||
|
||||
|
||||
@registry.register
|
||||
class LargePear(Pear):
|
||||
pass
|
||||
|
||||
|
||||
class MainTest(Configurable):
|
||||
the_fruit: Fruit
|
||||
n_ids: int
|
||||
n_reps: int = 8
|
||||
the_second_fruit: Fruit
|
||||
|
||||
def create_the_second_fruit(self):
|
||||
expand_args_fields(Pineapple)
|
||||
self.the_second_fruit = Pineapple()
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
|
||||
class TestConfig(unittest.TestCase):
|
||||
def test_is_actually_dataclass(self):
|
||||
@dataclass
|
||||
class A:
|
||||
pass
|
||||
|
||||
self.assertTrue(_is_actually_dataclass(A))
|
||||
self.assertTrue(is_dataclass(A))
|
||||
|
||||
class B(A):
|
||||
a: int
|
||||
|
||||
self.assertFalse(_is_actually_dataclass(B))
|
||||
self.assertTrue(is_dataclass(B))
|
||||
|
||||
def test_simple_replacement(self):
|
||||
struct = get_default_args(MainTest)
|
||||
struct.n_ids = 9780
|
||||
struct.the_fruit_Pear_args.n_pips = 3
|
||||
struct.the_fruit_class_type = "Pear"
|
||||
struct.the_second_fruit_class_type = "Pear"
|
||||
|
||||
main = MainTest(**struct)
|
||||
self.assertIsInstance(main.the_fruit, Pear)
|
||||
self.assertEqual(main.n_reps, 8)
|
||||
self.assertEqual(main.n_ids, 9780)
|
||||
self.assertEqual(main.the_fruit.n_pips, 3)
|
||||
self.assertIsInstance(main.the_second_fruit, Pineapple)
|
||||
|
||||
struct2 = get_default_args(MainTest)
|
||||
self.assertEqual(struct2.the_fruit_Pear_args.n_pips, 13)
|
||||
|
||||
self.assertEqual(
|
||||
MainTest._creation_functions,
|
||||
("create_the_fruit", "create_the_second_fruit"),
|
||||
)
|
||||
|
||||
def test_detect_bases(self):
|
||||
# testing the _base_class_from_class function
|
||||
self.assertIsNone(_Registry._base_class_from_class(ReplaceableBase))
|
||||
self.assertIsNone(_Registry._base_class_from_class(MainTest))
|
||||
self.assertIs(_Registry._base_class_from_class(Fruit), Fruit)
|
||||
self.assertIs(_Registry._base_class_from_class(Pear), Fruit)
|
||||
|
||||
class PricklyPear(Pear):
|
||||
pass
|
||||
|
||||
self.assertIs(_Registry._base_class_from_class(PricklyPear), Fruit)
|
||||
|
||||
def test_registry_entries(self):
|
||||
self.assertIs(registry.get(Fruit, "Banana"), Banana)
|
||||
with self.assertRaisesRegex(ValueError, "Banana has not been registered."):
|
||||
registry.get(Animal, "Banana")
|
||||
with self.assertRaisesRegex(ValueError, "PricklyPear has not been registered."):
|
||||
registry.get(Fruit, "PricklyPear")
|
||||
|
||||
self.assertIs(registry.get(Pear, "Pear"), Pear)
|
||||
self.assertIs(registry.get(Pear, "LargePear"), LargePear)
|
||||
with self.assertRaisesRegex(ValueError, "Banana resolves to"):
|
||||
registry.get(Pear, "Banana")
|
||||
|
||||
all_fruit = set(registry.get_all(Fruit))
|
||||
self.assertIn(Banana, all_fruit)
|
||||
self.assertIn(Pear, all_fruit)
|
||||
self.assertIn(LargePear, all_fruit)
|
||||
self.assertEqual(set(registry.get_all(Pear)), {LargePear})
|
||||
|
||||
@registry.register
|
||||
class Apple(Fruit):
|
||||
pass
|
||||
|
||||
@registry.register
|
||||
class CrabApple(Apple):
|
||||
pass
|
||||
|
||||
self.assertEqual(set(registry.get_all(Apple)), {CrabApple})
|
||||
|
||||
self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Cannot tell what it is."):
|
||||
|
||||
@registry.register
|
||||
class NotAFruit:
|
||||
pass
|
||||
|
||||
def test_recursion(self):
|
||||
class Shape(ReplaceableBase):
|
||||
pass
|
||||
|
||||
@registry.register
|
||||
class Triangle(Shape):
|
||||
a: float = 5.0
|
||||
|
||||
@registry.register
|
||||
class Square(Shape):
|
||||
a: float = 3.0
|
||||
|
||||
@registry.register
|
||||
class LargeShape(Shape):
|
||||
inner: Shape
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
class ShapeContainer(Configurable):
|
||||
shape: Shape
|
||||
|
||||
container = ShapeContainer(**get_default_args(ShapeContainer))
|
||||
# This is because ShapeContainer is missing __post_init__
|
||||
with self.assertRaises(AttributeError):
|
||||
container.shape
|
||||
|
||||
class ShapeContainer2(Configurable):
|
||||
x: Shape
|
||||
x_class_type: str = "LargeShape"
|
||||
|
||||
def __post_init__(self):
|
||||
self.x_LargeShape_args.inner_class_type = "Triangle"
|
||||
run_auto_creation(self)
|
||||
|
||||
container2_args = get_default_args(ShapeContainer2)
|
||||
container2_args.x_LargeShape_args.inner_Triangle_args.a += 10
|
||||
self.assertIn("inner_Square_args", container2_args.x_LargeShape_args)
|
||||
# We do not perform expansion that would result in an infinite recursion,
|
||||
# so this member is not present.
|
||||
self.assertNotIn("inner_LargeShape_args", container2_args.x_LargeShape_args)
|
||||
container2_args.x_LargeShape_args.inner_Square_args.a += 100
|
||||
container2 = ShapeContainer2(**container2_args)
|
||||
self.assertIsInstance(container2.x, LargeShape)
|
||||
self.assertIsInstance(container2.x.inner, Triangle)
|
||||
self.assertEqual(container2.x.inner.a, 15.0)
|
||||
|
||||
def test_simpleclass_member(self):
|
||||
# Members which are not dataclasses are
|
||||
# tolerated. But it would be nice to be able to
|
||||
# configure them.
|
||||
class Foo:
|
||||
def __init__(self, a=1, b=2):
|
||||
self.a, self.b = a, b
|
||||
|
||||
@dataclass()
|
||||
class Bar:
|
||||
aa: int = 9
|
||||
bb: int = 9
|
||||
|
||||
class Container(Configurable):
|
||||
bar: Bar = Bar()
|
||||
# TODO make this work?
|
||||
# foo: Foo = Foo()
|
||||
fruit: Fruit
|
||||
fruit_class_type: str = "Orange"
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
self.assertEqual(get_default_args(Foo), {"a": 1, "b": 2})
|
||||
container_args = get_default_args(Container)
|
||||
container = Container(**container_args)
|
||||
self.assertIsInstance(container.fruit, Orange)
|
||||
# self.assertIsInstance(container.bar, Bar)
|
||||
|
||||
container_defaulted = Container()
|
||||
container_defaulted.fruit_Pear_args.n_pips += 4
|
||||
|
||||
container_args2 = get_default_args(Container)
|
||||
container = Container(**container_args2)
|
||||
self.assertEqual(container.fruit_Pear_args.n_pips, 13)
|
||||
|
||||
def test_inheritance(self):
|
||||
class FruitBowl(ReplaceableBase):
|
||||
main_fruit: Fruit
|
||||
main_fruit_class_type: str = "Orange"
|
||||
|
||||
def __post_init__(self):
|
||||
raise ValueError("This doesn't get called")
|
||||
|
||||
class LargeFruitBowl(FruitBowl):
|
||||
extra_fruit: Fruit
|
||||
extra_fruit_class_type: str = "Kiwi"
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
large_args = get_default_args(LargeFruitBowl)
|
||||
self.assertNotIn("extra_fruit", large_args)
|
||||
self.assertNotIn("main_fruit", large_args)
|
||||
large = LargeFruitBowl(**large_args)
|
||||
self.assertIsInstance(large.main_fruit, Orange)
|
||||
self.assertIsInstance(large.extra_fruit, Kiwi)
|
||||
|
||||
def test_inheritance2(self):
|
||||
# This is a case where a class could contain an instance
|
||||
# of a subclass, which is ignored.
|
||||
class Parent(ReplaceableBase):
|
||||
pass
|
||||
|
||||
class Main(Configurable):
|
||||
parent: Parent
|
||||
# Note - no __post__init__
|
||||
|
||||
@registry.register
|
||||
class Derived(Parent, Main):
|
||||
pass
|
||||
|
||||
args = get_default_args(Main)
|
||||
# Derived has been ignored in processing Main.
|
||||
self.assertCountEqual(args.keys(), ["parent_class_type"])
|
||||
|
||||
main = Main(**args)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "UNDEFAULTED has not been registered."):
|
||||
run_auto_creation(main)
|
||||
|
||||
main.parent_class_type = "Derived"
|
||||
# Illustrates that a dict works fine instead of a DictConfig.
|
||||
main.parent_Derived_args = {}
|
||||
with self.assertRaises(AttributeError):
|
||||
main.parent
|
||||
run_auto_creation(main)
|
||||
self.assertIsInstance(main.parent, Derived)
|
||||
|
||||
def test_redefine(self):
|
||||
class FruitBowl(ReplaceableBase):
|
||||
main_fruit: Fruit
|
||||
main_fruit_class_type: str = "Grape"
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
@registry.register
|
||||
@dataclass
|
||||
class Grape(Fruit):
|
||||
large: bool = False
|
||||
|
||||
def get_color(self):
|
||||
return "red"
|
||||
|
||||
def __post_init__(self):
|
||||
raise ValueError("This doesn't get called")
|
||||
|
||||
bowl_args = get_default_args(FruitBowl)
|
||||
|
||||
@registry.register
|
||||
@dataclass
|
||||
class Grape(Fruit): # noqa: F811
|
||||
large: bool = True
|
||||
|
||||
def get_color(self):
|
||||
return "green"
|
||||
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning, "New implementation of Grape is being chosen."
|
||||
):
|
||||
bowl = FruitBowl(**bowl_args)
|
||||
self.assertIsInstance(bowl.main_fruit, Grape)
|
||||
|
||||
# Redefining the same class won't help with defaults because encoded in args
|
||||
self.assertEqual(bowl.main_fruit.large, False)
|
||||
|
||||
# But the override worked.
|
||||
self.assertEqual(bowl.main_fruit.get_color(), "green")
|
||||
|
||||
# 2. Try redefining without the dataclass modifier
|
||||
# This relies on the fact that default creation processes the class.
|
||||
# (otherwise incomprehensible messages)
|
||||
@registry.register
|
||||
class Grape(Fruit): # noqa: F811
|
||||
large: bool = True
|
||||
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning, "New implementation of Grape is being chosen."
|
||||
):
|
||||
bowl = FruitBowl(**bowl_args)
|
||||
|
||||
# 3. Adding a new class doesn't get picked up, because the first
|
||||
# get_default_args call has frozen FruitBowl. This is intrinsic to
|
||||
# the way dataclass and expand_args_fields work in-place but
|
||||
# expand_args_fields is not pure - it depends on the registry.
|
||||
@registry.register
|
||||
class Fig(Fruit):
|
||||
pass
|
||||
|
||||
bowl_args2 = get_default_args(FruitBowl)
|
||||
self.assertIn("main_fruit_Grape_args", bowl_args2)
|
||||
self.assertNotIn("main_fruit_Fig_args", bowl_args2)
|
||||
|
||||
# TODO Is it possible to make this work?
|
||||
# bowl_args2["main_fruit_Fig_args"] = get_default_args(Fig)
|
||||
# bowl_args2.main_fruit_class_type = "Fig"
|
||||
# bowl2 = FruitBowl(**bowl_args2) <= unexpected argument
|
||||
|
||||
# Note that it is possible to use Fig if you can set
|
||||
# bowl2.main_fruit_Fig_args explicitly (not in bowl_args2)
|
||||
# before run_auto_creation happens. See test_inheritance2
|
||||
# for an example.
|
||||
|
||||
def test_no_replacement(self):
|
||||
# Test of Configurables without ReplaceableBase
|
||||
class A(Configurable):
|
||||
n: int = 9
|
||||
|
||||
class B(Configurable):
|
||||
a: A
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
class C(Configurable):
|
||||
b: B
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
c_args = get_default_args(C)
|
||||
c = C(**c_args)
|
||||
self.assertIsInstance(c.b.a, A)
|
||||
self.assertEqual(c.b.a.n, 9)
|
||||
|
||||
def test_doc(self):
|
||||
# The case in the docstring.
|
||||
class A(ReplaceableBase):
|
||||
k: int = 1
|
||||
|
||||
@registry.register
|
||||
class A1(A):
|
||||
m: int = 3
|
||||
|
||||
@registry.register
|
||||
class A2(A):
|
||||
n: str = "2"
|
||||
|
||||
class B(Configurable):
|
||||
a: A
|
||||
a_class_type: str = "A2"
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
b_args = get_default_args(B)
|
||||
self.assertNotIn("a", b_args)
|
||||
b = B(**b_args)
|
||||
self.assertEqual(b.a.n, "2")
|
||||
|
||||
def test_raw_types(self):
|
||||
@dataclass
|
||||
class MyDataclass:
|
||||
int_field: int = 0
|
||||
none_field: Optional[int] = None
|
||||
float_field: float = 9.3
|
||||
bool_field: bool = True
|
||||
tuple_field: tuple = (3, True, "j")
|
||||
|
||||
class SimpleClass:
|
||||
def __init__(self, tuple_member_=(3, 4)):
|
||||
self.tuple_member = tuple_member_
|
||||
|
||||
def get_tuple(self):
|
||||
return self.tuple_member
|
||||
|
||||
def f(*, a: int = 3, b: str = "kj"):
|
||||
self.assertEqual(a, 3)
|
||||
self.assertEqual(b, "kj")
|
||||
|
||||
class C(Configurable):
|
||||
simple: DictConfig = get_default_args_field(SimpleClass)
|
||||
# simple2: SimpleClass2 = SimpleClass2()
|
||||
mydata: DictConfig = get_default_args_field(MyDataclass)
|
||||
a_tuple: Tuple[float] = (4.0, 3.0)
|
||||
f_args: DictConfig = get_default_args_field(f)
|
||||
|
||||
args = get_default_args(C)
|
||||
c = C(**args)
|
||||
self.assertCountEqual(args.keys(), ["simple", "mydata", "a_tuple", "f_args"])
|
||||
|
||||
mydata = MyDataclass(**c.mydata)
|
||||
simple = SimpleClass(**c.simple)
|
||||
|
||||
# OmegaConf converts tuples to ListConfigs (which act like lists).
|
||||
self.assertEqual(simple.get_tuple(), [3, 4])
|
||||
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
|
||||
self.assertEqual(c.a_tuple, [4.0, 3.0])
|
||||
self.assertTrue(isinstance(c.a_tuple, ListConfig))
|
||||
self.assertEqual(mydata.tuple_field, (3, True, "j"))
|
||||
self.assertTrue(isinstance(mydata.tuple_field, ListConfig))
|
||||
f(**c.f_args)
|
||||
|
||||
def test_irrelevant_bases(self):
|
||||
class NotADataclass:
|
||||
# Like torch.nn.Module, this class contains annotations
|
||||
# but is not designed to be dataclass'd.
|
||||
# This test ensures that such classes, when inherited fron,
|
||||
# are not accidentally expand_args_fields.
|
||||
a: int = 9
|
||||
b: int
|
||||
|
||||
class LeftConfigured(Configurable, NotADataclass):
|
||||
left: int = 1
|
||||
|
||||
class RightConfigured(NotADataclass, Configurable):
|
||||
right: int = 2
|
||||
|
||||
class Outer(Configurable):
|
||||
left: LeftConfigured
|
||||
right: RightConfigured
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
|
||||
outer = Outer(**get_default_args(Outer))
|
||||
self.assertEqual(outer.left.left, 1)
|
||||
self.assertEqual(outer.right.right, 2)
|
||||
with self.assertRaisesRegex(TypeError, "non-default argument"):
|
||||
dataclass(NotADataclass)
|
||||
|
||||
def test_unprocessed(self):
|
||||
# behavior of Configurable classes which need processing in __new__,
|
||||
class Unprocessed(Configurable):
|
||||
a: int = 9
|
||||
|
||||
class UnprocessedReplaceable(ReplaceableBase):
|
||||
a: int = 1
|
||||
|
||||
with self.assertWarnsRegex(UserWarning, "must be processed"):
|
||||
Unprocessed()
|
||||
with self.assertWarnsRegex(UserWarning, "must be processed"):
|
||||
UnprocessedReplaceable()
|
||||
|
||||
def test_enum(self):
|
||||
# Test that enum values are kept, i.e. that OmegaConf's runtime checks
|
||||
# are in use.
|
||||
|
||||
class A(Enum):
|
||||
B1 = "b1"
|
||||
B2 = "b2"
|
||||
|
||||
class C(Configurable):
|
||||
a: A = A.B1
|
||||
|
||||
base = get_default_args(C)
|
||||
replaced = OmegaConf.merge(base, {"a": "B2"})
|
||||
self.assertEqual(replaced.a, A.B2)
|
||||
with self.assertRaises(ValidationError):
|
||||
# You can't use a value which is not one of the
|
||||
# choices, even if it is the str representation
|
||||
# of one of the choices.
|
||||
OmegaConf.merge(base, {"a": "b2"})
|
||||
|
||||
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
|
||||
self.assertEqual(remerged.a, A.B1)
|
||||
|
||||
def test_remove_unused_components(self):
|
||||
struct = get_default_args(MainTest)
|
||||
struct.n_ids = 32
|
||||
struct.the_fruit_class_type = "Pear"
|
||||
struct.the_second_fruit_class_type = "Banana"
|
||||
remove_unused_components(struct)
|
||||
expected_keys = [
|
||||
"n_ids",
|
||||
"n_reps",
|
||||
"the_fruit_Pear_args",
|
||||
"the_fruit_class_type",
|
||||
"the_second_fruit_Banana_args",
|
||||
"the_second_fruit_class_type",
|
||||
]
|
||||
expected_yaml = textwrap.dedent(
|
||||
"""\
|
||||
n_ids: 32
|
||||
n_reps: 8
|
||||
the_fruit_class_type: Pear
|
||||
the_fruit_Pear_args:
|
||||
n_pips: 13
|
||||
the_second_fruit_class_type: Banana
|
||||
the_second_fruit_Banana_args:
|
||||
pips: ???
|
||||
spots: ???
|
||||
bananame: ???
|
||||
"""
|
||||
)
|
||||
self.assertEqual(sorted(struct.keys()), expected_keys)
|
||||
|
||||
# Check that struct is what we expect
|
||||
expected = OmegaConf.create(expected_yaml)
|
||||
self.assertEqual(struct, expected)
|
||||
|
||||
# Check that we get what we expect when writing to yaml.
|
||||
self.assertEqual(OmegaConf.to_yaml(struct, sort_keys=False), expected_yaml)
|
||||
|
||||
main = MainTest(**struct)
|
||||
instance_data = OmegaConf.structured(main)
|
||||
remove_unused_components(instance_data)
|
||||
self.assertEqual(sorted(instance_data.keys()), expected_keys)
|
||||
self.assertEqual(instance_data, expected)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class MockDataclass:
|
||||
field_no_default: int
|
||||
field_primitive_type: int = 42
|
||||
field_reference_type: List[int] = field(default_factory=lambda: [])
|
||||
|
||||
|
||||
class MockClassWithInit: # noqa: B903
|
||||
def __init__(
|
||||
self,
|
||||
field_no_default: int,
|
||||
field_primitive_type: int = 42,
|
||||
field_reference_type: List[int] = [], # noqa: B006
|
||||
):
|
||||
self.field_no_default = field_no_default
|
||||
self.field_primitive_type = field_primitive_type
|
||||
self.field_reference_type = field_reference_type
|
||||
|
||||
|
||||
class TestRawClasses(unittest.TestCase):
|
||||
def test_get_default_args(self):
|
||||
for cls in [MockDataclass, MockClassWithInit]:
|
||||
dataclass_defaults = get_default_args(cls)
|
||||
inst = cls(field_no_default=0)
|
||||
dataclass_defaults.field_no_default = 0
|
||||
for name, val in dataclass_defaults.items():
|
||||
self.assertTrue(hasattr(inst, name))
|
||||
self.assertEqual(val, getattr(inst, name))
|
||||
|
||||
def test_get_default_args_readonly(self):
|
||||
for cls in [MockDataclass, MockClassWithInit]:
|
||||
dataclass_defaults = get_default_args(cls)
|
||||
dataclass_defaults["field_reference_type"].append(13)
|
||||
inst = cls(field_no_default=0)
|
||||
self.assertEqual(inst.field_reference_type, [])
|
||||
81
tests/implicitron/test_config_use.py
Normal file
81
tests/implicitron/test_config_use.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.models.autodecoder import Autodecoder
|
||||
from pytorch3d.implicitron.models.base import GenericModel
|
||||
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
|
||||
IdrFeatureField,
|
||||
)
|
||||
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
|
||||
NeuralRadianceFieldImplicitFunction,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer
|
||||
from pytorch3d.implicitron.models.renderer.multipass_ea import (
|
||||
MultiPassEmissionAbsorptionRenderer,
|
||||
)
|
||||
from pytorch3d.implicitron.models.view_pooling.feature_aggregation import (
|
||||
AngleWeightedIdentityFeatureAggregator,
|
||||
AngleWeightedReductionFeatureAggregator,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
get_default_args,
|
||||
remove_unused_components,
|
||||
)
|
||||
|
||||
|
||||
if os.environ.get("FB_TEST", False):
|
||||
from common_testing import get_tests_dir
|
||||
else:
|
||||
from tests.common_testing import get_tests_dir
|
||||
|
||||
DATA_DIR = get_tests_dir() / "implicitron/data"
|
||||
DEBUG: bool = False
|
||||
|
||||
# Tests the use of the config system in implicitron
|
||||
|
||||
|
||||
class TestGenericModel(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.maxDiff = None
|
||||
|
||||
def test_create_gm(self):
|
||||
args = get_default_args(GenericModel)
|
||||
gm = GenericModel(**args)
|
||||
self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
|
||||
self.assertIsInstance(
|
||||
gm.feature_aggregator, AngleWeightedReductionFeatureAggregator
|
||||
)
|
||||
self.assertIsInstance(
|
||||
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
|
||||
)
|
||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||
self.assertFalse(hasattr(gm, "image_feature_extractor"))
|
||||
|
||||
def test_create_gm_overrides(self):
|
||||
args = get_default_args(GenericModel)
|
||||
args.feature_aggregator_class_type = "AngleWeightedIdentityFeatureAggregator"
|
||||
args.implicit_function_class_type = "IdrFeatureField"
|
||||
args.renderer_class_type = "LSTMRenderer"
|
||||
gm = GenericModel(**args)
|
||||
self.assertIsInstance(gm.renderer, LSTMRenderer)
|
||||
self.assertIsInstance(
|
||||
gm.feature_aggregator, AngleWeightedIdentityFeatureAggregator
|
||||
)
|
||||
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
|
||||
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
|
||||
self.assertFalse(hasattr(gm, "implicit_function"))
|
||||
|
||||
instance_args = OmegaConf.structured(gm)
|
||||
remove_unused_components(instance_args)
|
||||
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
|
||||
if DEBUG:
|
||||
(DATA_DIR / "overrides.yaml_").write_text(yaml)
|
||||
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())
|
||||
191
tests/implicitron/test_dataset_visualize.py
Normal file
191
tests/implicitron/test_dataset_visualize.py
Normal file
@@ -0,0 +1,191 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
|
||||
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
|
||||
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
|
||||
from pytorch3d.vis.plotly_vis import plot_scene
|
||||
from visdom import Visdom
|
||||
|
||||
if os.environ.get("FB_TEST", False):
|
||||
from .common_resources import get_skateboard_data
|
||||
else:
|
||||
from common_resources import get_skateboard_data
|
||||
|
||||
|
||||
class TestDatasetVisualize(unittest.TestCase):
|
||||
def setUp(self):
|
||||
if os.environ.get("INSIDE_RE_WORKER") is not None:
|
||||
raise unittest.SkipTest("Visdom not available")
|
||||
category = "skateboard"
|
||||
stack = contextlib.ExitStack()
|
||||
dataset_root, path_manager = stack.enter_context(get_skateboard_data())
|
||||
self.addCleanup(stack.close)
|
||||
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
|
||||
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
|
||||
self.image_size = 256
|
||||
self.datasets = {
|
||||
"simple": ImplicitronDataset(
|
||||
frame_annotations_file=frame_file,
|
||||
sequence_annotations_file=sequence_file,
|
||||
dataset_root=dataset_root,
|
||||
image_height=self.image_size,
|
||||
image_width=self.image_size,
|
||||
box_crop=True,
|
||||
load_point_clouds=True,
|
||||
path_manager=path_manager,
|
||||
),
|
||||
"nonsquare": ImplicitronDataset(
|
||||
frame_annotations_file=frame_file,
|
||||
sequence_annotations_file=sequence_file,
|
||||
dataset_root=dataset_root,
|
||||
image_height=self.image_size,
|
||||
image_width=self.image_size // 2,
|
||||
box_crop=True,
|
||||
load_point_clouds=True,
|
||||
path_manager=path_manager,
|
||||
),
|
||||
"nocrop": ImplicitronDataset(
|
||||
frame_annotations_file=frame_file,
|
||||
sequence_annotations_file=sequence_file,
|
||||
dataset_root=dataset_root,
|
||||
image_height=self.image_size,
|
||||
image_width=self.image_size // 2,
|
||||
box_crop=False,
|
||||
load_point_clouds=True,
|
||||
path_manager=path_manager,
|
||||
),
|
||||
}
|
||||
self.datasets.update(
|
||||
{
|
||||
k + "_newndc": _change_annotations_to_new_ndc(dataset)
|
||||
for k, dataset in self.datasets.items()
|
||||
}
|
||||
)
|
||||
self.visdom = Visdom()
|
||||
if not self.visdom.check_connection():
|
||||
print("Visdom server not running! Disabling visdom visualizations.")
|
||||
self.visdom = None
|
||||
|
||||
def _render_one_pointcloud(self, point_cloud, cameras, render_size):
|
||||
(_image_render, _, _) = render_point_cloud_pytorch3d(
|
||||
cameras,
|
||||
point_cloud,
|
||||
render_size=render_size,
|
||||
point_radius=1e-2,
|
||||
topk=10,
|
||||
bg_color=0.0,
|
||||
)
|
||||
return _image_render.clamp(0.0, 1.0)
|
||||
|
||||
def test_one(self):
|
||||
"""Test dataset visualization."""
|
||||
for max_frames in (16, -1):
|
||||
for load_dataset_point_cloud in (True, False):
|
||||
for dataset_key in self.datasets:
|
||||
self._gen_and_render_pointcloud(
|
||||
max_frames, load_dataset_point_cloud, dataset_key
|
||||
)
|
||||
|
||||
def _gen_and_render_pointcloud(
|
||||
self, max_frames, load_dataset_point_cloud, dataset_key
|
||||
):
|
||||
dataset = self.datasets[dataset_key]
|
||||
# load the point cloud of the first sequence
|
||||
sequence_show = list(dataset.seq_annots.keys())[0]
|
||||
device = torch.device("cuda:0")
|
||||
|
||||
point_cloud, sequence_frame_data = get_implicitron_sequence_pointcloud(
|
||||
dataset,
|
||||
sequence_name=sequence_show,
|
||||
mask_points=True,
|
||||
max_frames=max_frames,
|
||||
num_workers=10,
|
||||
load_dataset_point_cloud=load_dataset_point_cloud,
|
||||
)
|
||||
|
||||
# render on gpu
|
||||
point_cloud = point_cloud.to(device)
|
||||
cameras = sequence_frame_data.camera.to(device)
|
||||
|
||||
# render the point_cloud from the viewpoint of loaded cameras
|
||||
images_render = torch.cat(
|
||||
[
|
||||
self._render_one_pointcloud(
|
||||
point_cloud,
|
||||
cameras[frame_i],
|
||||
(
|
||||
dataset.image_height,
|
||||
dataset.image_width,
|
||||
),
|
||||
)
|
||||
for frame_i in range(len(cameras))
|
||||
]
|
||||
).cpu()
|
||||
images_gt_and_render = torch.cat(
|
||||
[sequence_frame_data.image_rgb, images_render], dim=3
|
||||
)
|
||||
|
||||
imfile = os.path.join(
|
||||
os.path.split(os.path.abspath(__file__))[0],
|
||||
"test_dataset_visualize"
|
||||
+ f"_max_frames={max_frames}"
|
||||
+ f"_load_pcl={load_dataset_point_cloud}.png",
|
||||
)
|
||||
print(f"Exporting image {imfile}.")
|
||||
torchvision.utils.save_image(images_gt_and_render, imfile, nrow=2)
|
||||
|
||||
if self.visdom is not None:
|
||||
test_name = f"{max_frames}_{load_dataset_point_cloud}_{dataset_key}"
|
||||
self.visdom.images(
|
||||
images_gt_and_render,
|
||||
env="test_dataset_visualize",
|
||||
win=f"pcl_renders_{test_name}",
|
||||
opts={"title": f"pcl_renders_{test_name}"},
|
||||
)
|
||||
plotlyplot = plot_scene(
|
||||
{
|
||||
"scene_batch": {
|
||||
"cameras": cameras,
|
||||
"point_cloud": point_cloud,
|
||||
}
|
||||
},
|
||||
camera_scale=1.0,
|
||||
pointcloud_max_points=10000,
|
||||
pointcloud_marker_size=1.0,
|
||||
)
|
||||
self.visdom.plotlyplot(
|
||||
plotlyplot,
|
||||
env="test_dataset_visualize",
|
||||
win=f"pcl_{test_name}",
|
||||
)
|
||||
|
||||
|
||||
def _change_annotations_to_new_ndc(dataset):
|
||||
dataset = copy.deepcopy(dataset)
|
||||
for frame in dataset.frame_annots:
|
||||
vp = frame["frame_annotation"].viewpoint
|
||||
vp.intrinsics_format = "ndc_isotropic"
|
||||
# this assume the focal length to be equal on x and y (ok for a test)
|
||||
max_flength = max(vp.focal_length)
|
||||
vp.principal_point = (
|
||||
vp.principal_point[0] * max_flength / vp.focal_length[0],
|
||||
vp.principal_point[1] * max_flength / vp.focal_length[1],
|
||||
)
|
||||
vp.focal_length = (
|
||||
max_flength,
|
||||
max_flength,
|
||||
)
|
||||
|
||||
return dataset
|
||||
48
tests/implicitron/test_eval_cameras.py
Normal file
48
tests/implicitron/test_eval_cameras.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.tools.eval_video_trajectory import (
|
||||
generate_eval_video_cameras,
|
||||
)
|
||||
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
|
||||
from pytorch3d.transforms import axis_angle_to_matrix
|
||||
|
||||
|
||||
if os.environ.get("FB_TEST", False):
|
||||
from common_testing import TestCaseMixin
|
||||
else:
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestEvalCameras(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self):
|
||||
torch.manual_seed(42)
|
||||
|
||||
def test_circular(self):
|
||||
n_train_cameras = 10
|
||||
n_test_cameras = 100
|
||||
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
|
||||
amplitude = 0.01
|
||||
R_jiggled = torch.bmm(
|
||||
R, axis_angle_to_matrix(torch.rand(n_train_cameras, 3) * amplitude)
|
||||
)
|
||||
cameras_train = PerspectiveCameras(R=R_jiggled, T=T)
|
||||
cameras_test = generate_eval_video_cameras(
|
||||
cameras_train, trajectory_type="circular_lsq_fit", trajectory_scale=1.0
|
||||
)
|
||||
|
||||
positions_test = cameras_test.get_camera_center()
|
||||
center = positions_test.mean(0)
|
||||
self.assertClose(center, torch.zeros(3), atol=0.1)
|
||||
self.assertClose(
|
||||
(positions_test - center).norm(dim=[1]),
|
||||
torch.ones(n_test_cameras),
|
||||
atol=0.1,
|
||||
)
|
||||
290
tests/implicitron/test_evaluation.py
Normal file
290
tests/implicitron/test_evaluation.py
Normal file
@@ -0,0 +1,290 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import dataclasses
|
||||
import math
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import lpips
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
)
|
||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
|
||||
from pytorch3d.implicitron.models.model_dbir import ModelDBIR
|
||||
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
|
||||
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
|
||||
|
||||
if os.environ.get("FB_TEST", False):
|
||||
from .common_resources import get_skateboard_data, provide_lpips_vgg
|
||||
else:
|
||||
from common_resources import get_skateboard_data, provide_lpips_vgg
|
||||
|
||||
|
||||
class TestEvaluation(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# initialize evaluation dataset/dataloader
|
||||
torch.manual_seed(42)
|
||||
|
||||
stack = contextlib.ExitStack()
|
||||
dataset_root, path_manager = stack.enter_context(get_skateboard_data())
|
||||
self.addCleanup(stack.close)
|
||||
|
||||
category = "skateboard"
|
||||
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
|
||||
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
|
||||
self.image_size = 256
|
||||
self.dataset = ImplicitronDataset(
|
||||
frame_annotations_file=frame_file,
|
||||
sequence_annotations_file=sequence_file,
|
||||
dataset_root=dataset_root,
|
||||
image_height=self.image_size,
|
||||
image_width=self.image_size,
|
||||
box_crop=True,
|
||||
path_manager=path_manager,
|
||||
)
|
||||
self.bg_color = 0.0
|
||||
|
||||
# init the lpips model for eval
|
||||
provide_lpips_vgg()
|
||||
self.lpips_model = lpips.LPIPS(net="vgg")
|
||||
|
||||
def test_eval_depth(self):
|
||||
"""
|
||||
Check that eval_depth correctly masks errors and that, for get_best_scale=True,
|
||||
the error with scaled prediction equals the error without scaling the
|
||||
predicted depth. Finally, test that the error values are as expected
|
||||
for prediction and gt differing by a constant offset.
|
||||
"""
|
||||
gt = (torch.randn(10, 1, 300, 400, device="cuda") * 5.0).clamp(0.0)
|
||||
mask = (torch.rand_like(gt) > 0.5).type_as(gt)
|
||||
|
||||
for diff in 10 ** torch.linspace(-5, 0, 6):
|
||||
for crop in (0, 5):
|
||||
|
||||
pred = gt + (torch.rand_like(gt) - 0.5) * 2 * diff
|
||||
|
||||
# scaled prediction test
|
||||
mse_depth, abs_depth = eval_depth(
|
||||
pred,
|
||||
gt,
|
||||
crop=crop,
|
||||
mask=mask,
|
||||
get_best_scale=True,
|
||||
)
|
||||
mse_depth_scale, abs_depth_scale = eval_depth(
|
||||
pred * 10.0,
|
||||
gt,
|
||||
crop=crop,
|
||||
mask=mask,
|
||||
get_best_scale=True,
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
float(mse_depth.sum()), float(mse_depth_scale.sum()), delta=1e-4
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
float(abs_depth.sum()), float(abs_depth_scale.sum()), delta=1e-4
|
||||
)
|
||||
|
||||
# error masking test
|
||||
pred_masked_err = gt + (torch.rand_like(gt) + diff) * (1 - mask)
|
||||
mse_depth_masked, abs_depth_masked = eval_depth(
|
||||
pred_masked_err,
|
||||
gt,
|
||||
crop=crop,
|
||||
mask=mask,
|
||||
get_best_scale=True,
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
float(mse_depth_masked.sum()), float(0.0), delta=1e-4
|
||||
)
|
||||
self.assertAlmostEqual(
|
||||
float(abs_depth_masked.sum()), float(0.0), delta=1e-4
|
||||
)
|
||||
mse_depth_unmasked, abs_depth_unmasked = eval_depth(
|
||||
pred_masked_err,
|
||||
gt,
|
||||
crop=crop,
|
||||
mask=1 - mask,
|
||||
get_best_scale=True,
|
||||
)
|
||||
self.assertGreater(
|
||||
float(mse_depth_unmasked.sum()),
|
||||
float(diff ** 2),
|
||||
)
|
||||
self.assertGreater(
|
||||
float(abs_depth_unmasked.sum()),
|
||||
float(diff),
|
||||
)
|
||||
|
||||
# tests with constant error
|
||||
pred_fix_diff = gt + diff * mask
|
||||
for _mask_gt in (mask, None):
|
||||
mse_depth_fix_diff, abs_depth_fix_diff = eval_depth(
|
||||
pred_fix_diff,
|
||||
gt,
|
||||
crop=crop,
|
||||
mask=_mask_gt,
|
||||
get_best_scale=False,
|
||||
)
|
||||
if _mask_gt is not None:
|
||||
expected_err_abs = diff
|
||||
expected_err_mse = diff ** 2
|
||||
else:
|
||||
err_mask = (gt > 0.0).float() * mask
|
||||
if crop > 0:
|
||||
err_mask = err_mask[:, :, crop:-crop, crop:-crop]
|
||||
gt_cropped = gt[:, :, crop:-crop, crop:-crop]
|
||||
else:
|
||||
gt_cropped = gt
|
||||
gt_mass = (gt_cropped > 0.0).float().sum(dim=(1, 2, 3))
|
||||
expected_err_abs = (
|
||||
diff * err_mask.sum(dim=(1, 2, 3)) / (gt_mass)
|
||||
)
|
||||
expected_err_mse = diff * expected_err_abs
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
abs_depth_fix_diff,
|
||||
expected_err_abs * torch.ones_like(abs_depth_fix_diff),
|
||||
atol=1e-4,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
mse_depth_fix_diff,
|
||||
expected_err_mse * torch.ones_like(mse_depth_fix_diff),
|
||||
atol=1e-4,
|
||||
)
|
||||
)
|
||||
|
||||
def test_psnr(self):
|
||||
"""
|
||||
Compare against opencv and check that the psnr is above
|
||||
the minimum possible value.
|
||||
"""
|
||||
import cv2
|
||||
|
||||
im1 = torch.rand(100, 3, 256, 256).cuda()
|
||||
im1_uint8 = (im1 * 255).to(torch.uint8)
|
||||
im1_rounded = im1_uint8.float() / 255
|
||||
for max_diff in 10 ** torch.linspace(-5, 0, 6):
|
||||
im2 = im1 + (torch.rand_like(im1) - 0.5) * 2 * max_diff
|
||||
im2 = im2.clamp(0.0, 1.0)
|
||||
im2_uint8 = (im2 * 255).to(torch.uint8)
|
||||
im2_rounded = im2_uint8.float() / 255
|
||||
# check that our psnr matches the output of opencv
|
||||
psnr = calc_psnr(im1_rounded, im2_rounded)
|
||||
# some versions of cv2 can only take uint8 input
|
||||
psnr_cv2 = cv2.PSNR(
|
||||
im1_uint8.cpu().numpy(),
|
||||
im2_uint8.cpu().numpy(),
|
||||
)
|
||||
self.assertAlmostEqual(float(psnr), float(psnr_cv2), delta=1e-4)
|
||||
# check that all PSNRs are bigger than the minimum possible PSNR
|
||||
max_mse = max_diff ** 2
|
||||
min_psnr = 10 * math.log10(1.0 / max_mse)
|
||||
for _im1, _im2 in zip(im1, im2):
|
||||
_psnr = calc_psnr(_im1, _im2)
|
||||
self.assertGreaterEqual(float(_psnr) + 1e-6, min_psnr)
|
||||
|
||||
def _one_sequence_test(
|
||||
self,
|
||||
seq_dataset,
|
||||
n_batches=2,
|
||||
min_batch_size=5,
|
||||
max_batch_size=10,
|
||||
):
|
||||
# form a list of random batches
|
||||
batch_indices = []
|
||||
for _ in range(n_batches):
|
||||
batch_size = torch.randint(
|
||||
low=min_batch_size, high=max_batch_size, size=(1,)
|
||||
)
|
||||
batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
|
||||
|
||||
loader = torch.utils.data.DataLoader(
|
||||
seq_dataset,
|
||||
# batch_size=1,
|
||||
shuffle=False,
|
||||
batch_sampler=batch_indices,
|
||||
collate_fn=FrameData.collate,
|
||||
)
|
||||
|
||||
model = ModelDBIR(image_size=self.image_size, bg_color=self.bg_color)
|
||||
model.cuda()
|
||||
self.lpips_model.cuda()
|
||||
|
||||
for frame_data in loader:
|
||||
self.assertIsNone(frame_data.frame_type)
|
||||
self.assertIsNotNone(frame_data.image_rgb)
|
||||
# override the frame_type
|
||||
frame_data.frame_type = [
|
||||
"train_unseen",
|
||||
*(["train_known"] * (len(frame_data.image_rgb) - 1)),
|
||||
]
|
||||
|
||||
# move frame_data to gpu
|
||||
frame_data = dataclass_to_cuda_(frame_data)
|
||||
preds = model(**dataclasses.asdict(frame_data))
|
||||
|
||||
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
|
||||
eval_result = eval_batch(
|
||||
frame_data,
|
||||
nvs_prediction,
|
||||
bg_color=self.bg_color,
|
||||
lpips_model=self.lpips_model,
|
||||
)
|
||||
|
||||
# Make a terribly bad NVS prediction and check that this is worse
|
||||
# than the DBIR prediction.
|
||||
nvs_prediction_bad = copy.deepcopy(preds["nvs_prediction"])
|
||||
nvs_prediction_bad.depth_render += (
|
||||
torch.randn_like(nvs_prediction.depth_render) * 100.0
|
||||
)
|
||||
nvs_prediction_bad.image_render += (
|
||||
torch.randn_like(nvs_prediction.image_render) * 100.0
|
||||
)
|
||||
nvs_prediction_bad.mask_render = (
|
||||
torch.randn_like(nvs_prediction.mask_render) > 0.0
|
||||
).float()
|
||||
eval_result_bad = eval_batch(
|
||||
frame_data,
|
||||
nvs_prediction_bad,
|
||||
bg_color=self.bg_color,
|
||||
lpips_model=self.lpips_model,
|
||||
)
|
||||
|
||||
lower_better = {
|
||||
"psnr": False,
|
||||
"psnr_fg": False,
|
||||
"depth_abs_fg": True,
|
||||
"iou": False,
|
||||
"rgb_l1": True,
|
||||
"rgb_l1_fg": True,
|
||||
}
|
||||
|
||||
for metric in lower_better.keys():
|
||||
m_better = eval_result[metric]
|
||||
m_worse = eval_result_bad[metric]
|
||||
if m_better != m_better or m_worse != m_worse:
|
||||
continue # metric is missing, i.e. NaN
|
||||
_assert = (
|
||||
self.assertLessEqual
|
||||
if lower_better[metric]
|
||||
else self.assertGreaterEqual
|
||||
)
|
||||
_assert(m_better, m_worse)
|
||||
|
||||
def test_full_eval(self, n_sequences=5):
|
||||
"""Test evaluation."""
|
||||
for _, idx in list(self.dataset.seq_to_idx.items())[:n_sequences]:
|
||||
seq_dataset = torch.utils.data.Subset(self.dataset, idx)
|
||||
self._one_sequence_test(seq_dataset)
|
||||
67
tests/implicitron/test_forward_pass.py
Normal file
67
tests/implicitron/test_forward_pass.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.base import GenericModel
|
||||
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
|
||||
|
||||
|
||||
class TestGenericModel(unittest.TestCase):
|
||||
def test_gm(self):
|
||||
# Simple test of a forward pass of the default GenericModel.
|
||||
device = torch.device("cuda:1")
|
||||
expand_args_fields(GenericModel)
|
||||
model = GenericModel()
|
||||
model.to(device)
|
||||
|
||||
n_train_cameras = 2
|
||||
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
|
||||
cameras = PerspectiveCameras(R=R, T=T, device=device)
|
||||
|
||||
# TODO: make these default to None?
|
||||
defaulted_args = {
|
||||
"fg_probability": None,
|
||||
"depth_map": None,
|
||||
"mask_crop": None,
|
||||
"sequence_name": None,
|
||||
}
|
||||
|
||||
with self.assertWarnsRegex(UserWarning, "No main objective found"):
|
||||
model(
|
||||
camera=cameras,
|
||||
evaluation_mode=EvaluationMode.TRAINING,
|
||||
**defaulted_args,
|
||||
image_rgb=None,
|
||||
)
|
||||
target_image_rgb = torch.rand(
|
||||
(n_train_cameras, 3, model.render_image_height, model.render_image_width),
|
||||
device=device,
|
||||
)
|
||||
train_preds = model(
|
||||
camera=cameras,
|
||||
evaluation_mode=EvaluationMode.TRAINING,
|
||||
image_rgb=target_image_rgb,
|
||||
**defaulted_args,
|
||||
)
|
||||
self.assertGreater(train_preds["objective"].item(), 0)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
# TODO: perhaps this warning should be skipped in eval mode?
|
||||
with self.assertWarnsRegex(UserWarning, "No main objective found"):
|
||||
eval_preds = model(
|
||||
camera=cameras[0],
|
||||
**defaulted_args,
|
||||
image_rgb=None,
|
||||
)
|
||||
self.assertEqual(
|
||||
eval_preds["images_render"].shape,
|
||||
(1, 3, model.render_image_height, model.render_image_width),
|
||||
)
|
||||
63
tests/implicitron/test_ray_point_refiner.py
Normal file
63
tests/implicitron/test_ray_point_refiner.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
|
||||
from pytorch3d.renderer import RayBundle
|
||||
|
||||
|
||||
if os.environ.get("FB_TEST", False):
|
||||
from common_testing import TestCaseMixin
|
||||
else:
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
|
||||
class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
|
||||
def test_simple(self):
|
||||
length = 15
|
||||
n_pts_per_ray = 10
|
||||
|
||||
for add_input_samples in [False, True]:
|
||||
ray_point_refiner = RayPointRefiner(
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
random_sampling=False,
|
||||
add_input_samples=add_input_samples,
|
||||
)
|
||||
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
|
||||
bundle = RayBundle(lengths=lengths, origins=None, directions=None, xys=None)
|
||||
weights = torch.ones(3, 25, length)
|
||||
refined = ray_point_refiner(bundle, weights)
|
||||
|
||||
self.assertIsNone(refined.directions)
|
||||
self.assertIsNone(refined.origins)
|
||||
self.assertIsNone(refined.xys)
|
||||
expected = torch.linspace(0.5, length - 1.5, n_pts_per_ray)
|
||||
expected = expected.expand(3, 25, n_pts_per_ray)
|
||||
if add_input_samples:
|
||||
full_expected = torch.cat((lengths, expected), dim=-1).sort()[0]
|
||||
else:
|
||||
full_expected = expected
|
||||
self.assertClose(refined.lengths, full_expected)
|
||||
|
||||
ray_point_refiner_random = RayPointRefiner(
|
||||
n_pts_per_ray=n_pts_per_ray,
|
||||
random_sampling=True,
|
||||
add_input_samples=add_input_samples,
|
||||
)
|
||||
refined_random = ray_point_refiner_random(bundle, weights)
|
||||
lengths_random = refined_random.lengths
|
||||
self.assertEqual(lengths_random.shape, full_expected.shape)
|
||||
if not add_input_samples:
|
||||
self.assertGreater(lengths_random.min().item(), 0.5)
|
||||
self.assertLess(lengths_random.max().item(), length - 1.5)
|
||||
|
||||
# Check sorted
|
||||
self.assertTrue(
|
||||
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
|
||||
)
|
||||
114
tests/implicitron/test_srn.py
Normal file
114
tests/implicitron/test_srn.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (
|
||||
SRNHyperNetImplicitFunction,
|
||||
SRNImplicitFunction,
|
||||
SRNPixelGenerator,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
|
||||
from pytorch3d.implicitron.tools.config import get_default_args
|
||||
from pytorch3d.renderer import RayBundle
|
||||
|
||||
|
||||
if os.environ.get("FB_TEST", False):
|
||||
from common_testing import TestCaseMixin
|
||||
else:
|
||||
from tests.common_testing import TestCaseMixin
|
||||
|
||||
_BATCH_SIZE: int = 3
|
||||
_N_RAYS: int = 100
|
||||
_N_POINTS_ON_RAY: int = 10
|
||||
|
||||
|
||||
class TestSRN(TestCaseMixin, unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
torch.manual_seed(42)
|
||||
get_default_args(SRNHyperNetImplicitFunction)
|
||||
get_default_args(SRNImplicitFunction)
|
||||
|
||||
def test_pixel_generator(self):
|
||||
SRNPixelGenerator()
|
||||
|
||||
def _get_bundle(self, *, device) -> RayBundle:
|
||||
origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
|
||||
directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
|
||||
lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
|
||||
bundle = RayBundle(
|
||||
lengths=lengths, origins=origins, directions=directions, xys=None
|
||||
)
|
||||
return bundle
|
||||
|
||||
def test_srn_implicit_function(self):
|
||||
implicit_function = SRNImplicitFunction()
|
||||
device = torch.device("cpu")
|
||||
bundle = self._get_bundle(device=device)
|
||||
rays_densities, rays_colors = implicit_function(bundle)
|
||||
out_features = implicit_function.raymarch_function.out_features
|
||||
self.assertEqual(
|
||||
rays_densities.shape,
|
||||
(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
|
||||
)
|
||||
self.assertIsNone(rays_colors)
|
||||
|
||||
def test_srn_hypernet_implicit_function(self):
|
||||
# TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
|
||||
latent_dim_hypernet = 39
|
||||
hypernet_args = {"latent_dim_hypernet": latent_dim_hypernet}
|
||||
device = torch.device("cuda:0")
|
||||
implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hypernet_args)
|
||||
implicit_function.to(device)
|
||||
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
|
||||
bundle = self._get_bundle(device=device)
|
||||
rays_densities, rays_colors = implicit_function(bundle, global_code=global_code)
|
||||
out_features = implicit_function.hypernet.out_features
|
||||
self.assertEqual(
|
||||
rays_densities.shape,
|
||||
(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
|
||||
)
|
||||
self.assertIsNone(rays_colors)
|
||||
|
||||
def test_srn_hypernet_implicit_function_optim(self):
|
||||
# Test optimization loop, requiring that the cache is properly
|
||||
# cleared in new_args_bound
|
||||
latent_dim_hypernet = 39
|
||||
hyper_args = {"latent_dim_hypernet": latent_dim_hypernet}
|
||||
device = torch.device("cuda:0")
|
||||
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
|
||||
bundle = self._get_bundle(device=device)
|
||||
|
||||
implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hyper_args)
|
||||
implicit_function2 = SRNHyperNetImplicitFunction(hypernet_args=hyper_args)
|
||||
implicit_function.to(device)
|
||||
implicit_function2.to(device)
|
||||
|
||||
wrapper = ImplicitFunctionWrapper(implicit_function)
|
||||
optimizer = torch.optim.Adam(implicit_function.parameters())
|
||||
for _step in range(3):
|
||||
optimizer.zero_grad()
|
||||
wrapper.bind_args(global_code=global_code)
|
||||
rays_densities, _rays_colors = wrapper(bundle)
|
||||
wrapper.unbind_args()
|
||||
loss = rays_densities.sum()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
wrapper2 = ImplicitFunctionWrapper(implicit_function)
|
||||
optimizer2 = torch.optim.Adam(implicit_function2.parameters())
|
||||
implicit_function2.load_state_dict(implicit_function.state_dict())
|
||||
optimizer2.load_state_dict(optimizer.state_dict())
|
||||
for _step in range(3):
|
||||
optimizer2.zero_grad()
|
||||
wrapper2.bind_args(global_code=global_code)
|
||||
rays_densities, _rays_colors = wrapper2(bundle)
|
||||
wrapper2.unbind_args()
|
||||
loss = rays_densities.sum()
|
||||
loss.backward()
|
||||
optimizer2.step()
|
||||
93
tests/implicitron/test_types.py
Normal file
93
tests/implicitron/test_types.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import dataclasses
|
||||
import unittest
|
||||
from typing import Dict, List, NamedTuple, Tuple
|
||||
|
||||
from pytorch3d.implicitron.dataset import types
|
||||
from pytorch3d.implicitron.dataset.types import FrameAnnotation
|
||||
|
||||
|
||||
class _NT(NamedTuple):
|
||||
annot: FrameAnnotation
|
||||
|
||||
|
||||
class TestDatasetTypes(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.entry = FrameAnnotation(
|
||||
frame_number=23,
|
||||
sequence_name="1",
|
||||
frame_timestamp=1.2,
|
||||
image=types.ImageAnnotation(path="/tmp/1.jpg", size=(224, 224)),
|
||||
mask=types.MaskAnnotation(path="/tmp/1.png", mass=42.0),
|
||||
viewpoint=types.ViewpointAnnotation(
|
||||
R=(
|
||||
(1, 0, 0),
|
||||
(1, 0, 0),
|
||||
(1, 0, 0),
|
||||
),
|
||||
T=(0, 0, 0),
|
||||
principal_point=(100, 100),
|
||||
focal_length=(200, 200),
|
||||
),
|
||||
)
|
||||
|
||||
def test_asdict_rec(self):
|
||||
first = [dataclasses.asdict(self.entry)]
|
||||
second = types._asdict_rec([self.entry])
|
||||
self.assertEqual(first, second)
|
||||
|
||||
def test_parsing(self):
|
||||
"""Test that we handle collections enclosing dataclasses."""
|
||||
|
||||
dct = dataclasses.asdict(self.entry)
|
||||
|
||||
parsed = types._dataclass_from_dict(dct, FrameAnnotation)
|
||||
self.assertEqual(parsed, self.entry)
|
||||
|
||||
# namedtuple
|
||||
parsed = types._dataclass_from_dict(_NT(dct), _NT)
|
||||
self.assertEqual(parsed.annot, self.entry)
|
||||
|
||||
# tuple
|
||||
parsed = types._dataclass_from_dict((dct,), Tuple[FrameAnnotation])
|
||||
self.assertEqual(parsed, (self.entry,))
|
||||
|
||||
# list
|
||||
parsed = types._dataclass_from_dict(
|
||||
[
|
||||
dct,
|
||||
],
|
||||
List[FrameAnnotation],
|
||||
)
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
[
|
||||
self.entry,
|
||||
],
|
||||
)
|
||||
|
||||
# dict
|
||||
parsed = types._dataclass_from_dict({"k": dct}, Dict[str, FrameAnnotation])
|
||||
self.assertEqual(parsed, {"k": self.entry})
|
||||
|
||||
def test_parsing_vectorized(self):
|
||||
dct = dataclasses.asdict(self.entry)
|
||||
|
||||
self._compare_with_scalar(dct, FrameAnnotation)
|
||||
self._compare_with_scalar(_NT(dct), _NT)
|
||||
self._compare_with_scalar((dct,), Tuple[FrameAnnotation])
|
||||
self._compare_with_scalar([dct], List[FrameAnnotation])
|
||||
self._compare_with_scalar({"k": dct}, Dict[str, FrameAnnotation])
|
||||
|
||||
def _compare_with_scalar(self, obj, typeannot, repeat=3):
|
||||
input = [obj] * 3
|
||||
vect_output = types._dataclass_list_from_dict_list(input, typeannot)
|
||||
self.assertEqual(len(input), repeat)
|
||||
gt = types._dataclass_from_dict(obj, typeannot)
|
||||
self.assertTrue(all(res == gt for res in vect_output))
|
||||
270
tests/implicitron/test_viewsampling.py
Normal file
270
tests/implicitron/test_viewsampling.py
Normal file
@@ -0,0 +1,270 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import unittest
|
||||
|
||||
import pytorch3d as pt3d
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.view_pooling.view_sampling import ViewSampler
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields
|
||||
|
||||
|
||||
class TestViewsampling(unittest.TestCase):
|
||||
def setUp(self):
|
||||
torch.manual_seed(42)
|
||||
expand_args_fields(ViewSampler)
|
||||
|
||||
def _init_view_sampler_problem(self, random_masks):
|
||||
"""
|
||||
Generates a view-sampling problem:
|
||||
- 4 source views, 1st/2nd from the first sequence 'seq1', the rest from 'seq2'
|
||||
- 3 sets of 3D points from sequences 'seq1', 'seq2', 'seq2' respectively.
|
||||
- first 50 points in each batch correctly project to the source views,
|
||||
while the remaining 50 do not land in any projection plane.
|
||||
- each source view is labeled with image feature tensors of shape 7x100x50,
|
||||
where all elements of the n-th tensor are set to `n+1`.
|
||||
- the elements of the source view masks are either set to random binary number
|
||||
(if `random_masks==True`), or all set to 1 (`random_masks==False`).
|
||||
- the source view cameras are uniformly distributed on a unit circle
|
||||
in the x-z plane and look at (0,0,0).
|
||||
"""
|
||||
seq_id_camera = ["seq1", "seq1", "seq2", "seq2"]
|
||||
seq_id_pts = ["seq1", "seq2", "seq2"]
|
||||
pts_batch = 3
|
||||
n_pts = 100
|
||||
n_views = 4
|
||||
fdim = 7
|
||||
H = 100
|
||||
W = 50
|
||||
|
||||
# points that land into the projection planes of all cameras
|
||||
pts_inside = (
|
||||
torch.nn.functional.normalize(
|
||||
torch.randn(pts_batch, n_pts // 2, 3, device="cuda"),
|
||||
dim=-1,
|
||||
)
|
||||
* 0.1
|
||||
)
|
||||
|
||||
# move the outside points far above the scene
|
||||
pts_outside = pts_inside.clone()
|
||||
pts_outside[:, :, 1] += 1e8
|
||||
pts = torch.cat([pts_inside, pts_outside], dim=1)
|
||||
|
||||
R, T = pt3d.renderer.look_at_view_transform(
|
||||
dist=1.0,
|
||||
elev=0.0,
|
||||
azim=torch.linspace(0, 360, n_views + 1)[:n_views],
|
||||
degrees=True,
|
||||
device=pts.device,
|
||||
)
|
||||
focal_length = R.new_ones(n_views, 2)
|
||||
principal_point = R.new_zeros(n_views, 2)
|
||||
camera = pt3d.renderer.PerspectiveCameras(
|
||||
R=R,
|
||||
T=T,
|
||||
focal_length=focal_length,
|
||||
principal_point=principal_point,
|
||||
device=pts.device,
|
||||
)
|
||||
|
||||
feats_map = torch.arange(n_views, device=pts.device, dtype=pts.dtype) + 1
|
||||
feats = {"feats": feats_map[:, None, None, None].repeat(1, fdim, H, W)}
|
||||
|
||||
masks = (
|
||||
torch.rand(n_views, 1, H, W, device=pts.device, dtype=pts.dtype) > 0.5
|
||||
).type_as(R)
|
||||
|
||||
if not random_masks:
|
||||
masks[:] = 1.0
|
||||
|
||||
return pts, camera, feats, masks, seq_id_camera, seq_id_pts
|
||||
|
||||
def test_compare_with_naive(self):
|
||||
"""
|
||||
Compares the outputs of the efficient ViewSampler module with a
|
||||
naive implementation.
|
||||
"""
|
||||
|
||||
(
|
||||
pts,
|
||||
camera,
|
||||
feats,
|
||||
masks,
|
||||
seq_id_camera,
|
||||
seq_id_pts,
|
||||
) = self._init_view_sampler_problem(True)
|
||||
|
||||
for masked_sampling in (True, False):
|
||||
feats_sampled_n, masks_sampled_n = _view_sample_naive(
|
||||
pts,
|
||||
seq_id_pts,
|
||||
camera,
|
||||
seq_id_camera,
|
||||
feats,
|
||||
masks,
|
||||
masked_sampling,
|
||||
)
|
||||
# make sure we generate the constructor for ViewSampler
|
||||
expand_args_fields(ViewSampler)
|
||||
view_sampler = ViewSampler(masked_sampling=masked_sampling)
|
||||
feats_sampled, masks_sampled = view_sampler(
|
||||
pts=pts,
|
||||
seq_id_pts=seq_id_pts,
|
||||
camera=camera,
|
||||
seq_id_camera=seq_id_camera,
|
||||
feats=feats,
|
||||
masks=masks,
|
||||
)
|
||||
for k in feats_sampled.keys():
|
||||
self.assertTrue(torch.allclose(feats_sampled[k], feats_sampled_n[k]))
|
||||
self.assertTrue(torch.allclose(masks_sampled, masks_sampled_n))
|
||||
|
||||
def test_viewsampling(self):
|
||||
"""
|
||||
Generates a viewsampling problem with predictable outcome, and compares
|
||||
the ViewSampler's output to the expected result.
|
||||
"""
|
||||
|
||||
(
|
||||
pts,
|
||||
camera,
|
||||
feats,
|
||||
masks,
|
||||
seq_id_camera,
|
||||
seq_id_pts,
|
||||
) = self._init_view_sampler_problem(False)
|
||||
|
||||
expand_args_fields(ViewSampler)
|
||||
|
||||
for masked_sampling in (True, False):
|
||||
|
||||
view_sampler = ViewSampler(masked_sampling=masked_sampling)
|
||||
|
||||
feats_sampled, masks_sampled = view_sampler(
|
||||
pts=pts,
|
||||
seq_id_pts=seq_id_pts,
|
||||
camera=camera,
|
||||
seq_id_camera=seq_id_camera,
|
||||
feats=feats,
|
||||
masks=masks,
|
||||
)
|
||||
|
||||
n_views = camera.R.shape[0]
|
||||
n_pts = pts.shape[1]
|
||||
feat_dim = feats["feats"].shape[1]
|
||||
pts_batch = pts.shape[0]
|
||||
n_pts_away = n_pts // 2
|
||||
|
||||
for pts_i in range(pts_batch):
|
||||
for view_i in range(n_views):
|
||||
if seq_id_pts[pts_i] != seq_id_camera[view_i]:
|
||||
# points / cameras come from different sequences
|
||||
gt_masks = pts.new_zeros(n_pts, 1)
|
||||
gt_feats = pts.new_zeros(n_pts, feat_dim)
|
||||
else:
|
||||
gt_masks = pts.new_ones(n_pts, 1)
|
||||
gt_feats = pts.new_ones(n_pts, feat_dim) * (view_i + 1)
|
||||
gt_feats[n_pts_away:] = 0.0
|
||||
if masked_sampling:
|
||||
gt_masks[n_pts_away:] = 0.0
|
||||
|
||||
for k in feats_sampled:
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
feats_sampled[k][pts_i, view_i],
|
||||
gt_feats,
|
||||
)
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
masks_sampled[pts_i, view_i],
|
||||
gt_masks,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _view_sample_naive(
|
||||
pts,
|
||||
seq_id_pts,
|
||||
camera,
|
||||
seq_id_camera,
|
||||
feats,
|
||||
masks,
|
||||
masked_sampling,
|
||||
):
|
||||
"""
|
||||
A naive implementation of the forward pass of ViewSampler.
|
||||
Refer to ViewSampler's docstring for description of the arguments.
|
||||
"""
|
||||
|
||||
pts_batch = pts.shape[0]
|
||||
n_views = camera.R.shape[0]
|
||||
n_pts = pts.shape[1]
|
||||
|
||||
feats_sampled = [[[] for _ in range(n_views)] for _ in range(pts_batch)]
|
||||
masks_sampled = [[[] for _ in range(n_views)] for _ in range(pts_batch)]
|
||||
|
||||
for pts_i in range(pts_batch):
|
||||
for view_i in range(n_views):
|
||||
if seq_id_pts[pts_i] != seq_id_camera[view_i]:
|
||||
# points/cameras come from different sequences
|
||||
feats_sampled_ = {
|
||||
k: f.new_zeros(n_pts, f.shape[1]) for k, f in feats.items()
|
||||
}
|
||||
masks_sampled_ = masks.new_zeros(n_pts, 1)
|
||||
else:
|
||||
# same sequence of pts and cameras -> sample
|
||||
feats_sampled_, masks_sampled_ = _sample_one_view_naive(
|
||||
camera[view_i],
|
||||
pts[pts_i],
|
||||
{k: f[view_i] for k, f in feats.items()},
|
||||
masks[view_i],
|
||||
masked_sampling,
|
||||
sampling_mode="bilinear",
|
||||
)
|
||||
feats_sampled[pts_i][view_i] = feats_sampled_
|
||||
masks_sampled[pts_i][view_i] = masks_sampled_
|
||||
|
||||
masks_sampled_cat = torch.stack([torch.stack(m) for m in masks_sampled])
|
||||
feats_sampled_cat = {}
|
||||
for k in feats_sampled[0][0].keys():
|
||||
feats_sampled_cat[k] = torch.stack(
|
||||
[torch.stack([f_[k] for f_ in f]) for f in feats_sampled]
|
||||
)
|
||||
return feats_sampled_cat, masks_sampled_cat
|
||||
|
||||
|
||||
def _sample_one_view_naive(
|
||||
camera,
|
||||
pts,
|
||||
feats,
|
||||
masks,
|
||||
masked_sampling,
|
||||
sampling_mode="bilinear",
|
||||
):
|
||||
"""
|
||||
Sample a single source view.
|
||||
"""
|
||||
proj_ndc = camera.transform_points(pts[None])[None, ..., :-1] # 1 x 1 x n_pts x 2
|
||||
feats_sampled = {
|
||||
k: pt3d.renderer.ndc_grid_sample(f[None], proj_ndc, mode=sampling_mode).permute(
|
||||
0, 3, 1, 2
|
||||
)[0, :, :, 0]
|
||||
for k, f in feats.items()
|
||||
} # n_pts x dim
|
||||
if not masked_sampling:
|
||||
n_pts = pts.shape[0]
|
||||
masks_sampled = proj_ndc.new_ones(n_pts, 1)
|
||||
else:
|
||||
masks_sampled = pt3d.renderer.ndc_grid_sample(
|
||||
masks[None],
|
||||
proj_ndc,
|
||||
mode=sampling_mode,
|
||||
align_corners=False,
|
||||
)[0, 0, 0, :][:, None]
|
||||
return feats_sampled, masks_sampled
|
||||
Reference in New Issue
Block a user