implicitron v0 (#1133)

Co-authored-by: Jeremy Francis Reizenstein <bottler@users.noreply.github.com>
This commit is contained in:
Jeremy Reizenstein
2022-03-21 20:20:10 +00:00
committed by GitHub
parent 0e377c6850
commit cdd2142dd5
90 changed files with 17075 additions and 0 deletions

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import logging
import os
import tempfile
import unittest
from pathlib import Path
from typing import Generator, Tuple
from zipfile import ZipFile
from iopath.common.file_io import PathManager
@contextlib.contextmanager
def get_skateboard_data(
avoid_manifold: bool = False, silence_logs: bool = False
) -> Generator[Tuple[str, PathManager], None, None]:
"""
Context manager for tests to access the Co3D dataset, at least for
the first 5 skateboards. Internally, we want this to exercise the
normal way to access the data directly from manifold, but on an RE
worker this is impossible, so we use a workaround.
Args:
avoid_manifold: Use the method used by RE workers even locally.
silence_logs: Whether to reduce log output from iopath library.
Yields:
dataset_root: (str) path to dataset root.
path_manager: path_manager to access it with.
"""
path_manager = PathManager()
if silence_logs:
logging.getLogger("iopath.fb.manifold").setLevel(logging.CRITICAL)
logging.getLogger("iopath.common.file_io").setLevel(logging.CRITICAL)
if not os.environ.get("FB_TEST", False):
if os.getenv("FAIR_ENV_CLUSTER", "") == "":
raise unittest.SkipTest("Unknown environment. Data not available.")
yield "/checkpoint/dnovotny/datasets/co3d/download_aws_22_02_18", path_manager
elif avoid_manifold or os.environ.get("INSIDE_RE_WORKER", False):
from libfb.py.parutil import get_file_path
par_path = "skateboard_first_5"
source = get_file_path(par_path)
assert Path(source).is_file()
with tempfile.TemporaryDirectory() as dest:
with ZipFile(source) as f:
f.extractall(dest)
yield os.path.join(dest, "extracted"), path_manager
else:
from iopath.fb.manifold import ManifoldPathHandler
path_manager.register_handler(ManifoldPathHandler())
yield "manifold://co3d/tree/extracted", path_manager
def provide_lpips_vgg():
"""
Ensure the weights files are available for lpips.LPIPS(net="vgg")
to be called. Specifically, torchvision's vgg16
"""
# In OSS, torchvision looks for vgg16 weights in
# https://download.pytorch.org/models/vgg16-397923af.pth
# Inside fbcode, this is replaced by asking iopath for
# manifold://torchvision/tree/models/vgg16-397923af.pth
# (the code for this replacement is in
# fbcode/pytorch/vision/fb/_internally_replaced_utils.py )
#
# iopath does this by looking for the file at the cache location
# and if it is not there getting it from manifold.
# (the code for this is in
# fbcode/fair_infra/data/iopath/iopath/fb/manifold.py )
#
# On the remote execution worker, manifold is inaccessible.
# We solve this by making the cached file available before iopath
# looks.
#
# By default the cache location is
# ~/.torch/iopath_cache/manifold_cache/tree/models/vgg16-397923af.pth
# But we can't write to the home directory on the RE worker.
# We define FVCORE_CACHE to change the cache location to
# iopath_cache/manifold_cache/tree/models/vgg16-397923af.pth
# (Without it, manifold caches in unstable temporary locations on RE.)
#
# The file we want has been copied from
# tree/models/vgg16-397923af.pth in the torchvision bucket
# to
# tree/testing/vgg16-397923af.pth in the co3d bucket
# and the TARGETS file copies it somewhere in the PAR which we
# recover with get_file_path.
# (It can't copy straight to a nested location, see
# https://fb.workplace.com/groups/askbuck/posts/2644615728920359/)
# Here we symlink it to the new cache location.
if os.environ.get("INSIDE_RE_WORKER") is not None:
from libfb.py.parutil import get_file_path
os.environ["FVCORE_CACHE"] = "iopath_cache"
par_path = "vgg_weights_for_lpips"
source = Path(get_file_path(par_path))
assert source.is_file()
dest = Path("iopath_cache/manifold_cache/tree/models")
if not dest.exists():
dest.mkdir(parents=True)
(dest / "vgg16-397923af.pth").symlink_to(source)

View File

@@ -0,0 +1,122 @@
mask_images: true
mask_depths: true
render_image_width: 400
render_image_height: 400
mask_threshold: 0.5
output_rasterized_mc: false
bg_color:
- 0.0
- 0.0
- 0.0
view_pool: false
num_passes: 1
chunk_size_grid: 4096
render_features_dimensions: 3
tqdm_trigger_threshold: 16
n_train_target_views: 1
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
renderer_class_type: LSTMRenderer
feature_aggregator_class_type: AngleWeightedIdentityFeatureAggregator
implicit_function_class_type: IdrFeatureField
loss_weights:
loss_rgb_mse: 1.0
loss_prev_stage_rgb_mse: 1.0
loss_mask_bce: 0.0
loss_prev_stage_mask_bce: 0.0
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
sequence_autodecoder_args:
encoding_dim: 0
n_instances: 0
init_scale: 1.0
ignore_input: false
raysampler_args:
image_width: 400
image_height: 400
scene_center:
- 0.0
- 0.0
- 0.0
scene_extent: 0.0
sampling_mode_training: mask_sample
sampling_mode_evaluation: full_grid
n_pts_per_ray_training: 64
n_pts_per_ray_evaluation: 64
n_rays_per_image_sampled_from_mask: 1024
min_depth: 0.1
max_depth: 8.0
stratified_point_sampling_training: true
stratified_point_sampling_evaluation: false
renderer_LSTMRenderer_args:
num_raymarch_steps: 10
init_depth: 17.0
init_depth_noise_std: 0.0005
hidden_size: 16
n_feature_channels: 256
verbose: false
image_feature_extractor_args:
name: resnet34
pretrained: true
stages:
- 1
- 2
- 3
- 4
normalize_image: true
image_rescale: 0.16
first_max_pool: true
proj_dim: 32
l2_norm: true
add_masks: true
add_images: true
global_average_pool: false
feature_rescale: 1.0
view_sampler_args:
masked_sampling: false
sampling_mode: bilinear
feature_aggregator_AngleWeightedIdentityFeatureAggregator_args:
exclude_target_view: true
exclude_target_view_mask_features: true
concatenate_output: true
weight_by_ray_angle_gamma: 1.0
min_ray_angle_weight: 0.1
implicit_function_IdrFeatureField_args:
feature_vector_size: 3
d_in: 3
d_out: 1
dims:
- 512
- 512
- 512
- 512
- 512
- 512
- 512
- 512
geometric_init: true
bias: 1.0
skip_in: []
weight_norm: true
n_harmonic_functions_xyz: 0
pooled_feature_dim: 0
encoding_dim: 0
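The YAML above appears to be the expected-output fixture that the config tests in this commit compare against (its exact path is not shown in this diff view; the tests read DATA_DIR / "overrides.yaml"). A hedged sketch of loading and inspecting it with OmegaConf, filename assumed from the tests:
from omegaconf import OmegaConf
cfg = OmegaConf.load("overrides.yaml")
assert cfg.renderer_class_type == "LSTMRenderer"
assert cfg.implicit_function_IdrFeatureField_args.d_out == 1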

View File

@@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
from collections import defaultdict
from dataclasses import dataclass
from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
@dataclass
class MockFrameAnnotation:
frame_number: int
frame_timestamp: float = 0.0
class MockDataset:
def __init__(self, num_seq, max_frame_gap=1):
"""
Makes a gap of max_frame_gap frame numbers in the middle of each sequence
"""
self.seq_annots = {f"seq_{i}": None for i in range(num_seq)}
self.seq_to_idx = {
f"seq_{i}": list(range(i * 10, i * 10 + 10)) for i in range(num_seq)
}
# frame numbers within sequence: [0, ..., 4, n, ..., n+4]
# where n - 4 == max_frame_gap
frame_nos = list(range(5)) + list(range(4 + max_frame_gap, 9 + max_frame_gap))
self.frame_annots = [
{"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
]
def get_frame_numbers_and_timestamps(self, idxs):
out = []
for idx in idxs:
frame_annotation = self.frame_annots[idx]["frame_annotation"]
out.append(
(frame_annotation.frame_number, frame_annotation.frame_timestamp)
)
return out
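# A worked instance of the frame numbering above: with max_frame_gap=3,
# frame_nos == [0, 1, 2, 3, 4, 7, 8, 9, 10, 11], so frame numbers jump
# from 4 to 7, leaving a gap of 3 in the middle of each 10-frame sequence.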
class TestSceneBatchSampler(unittest.TestCase):
def setUp(self):
self.dataset_overfit = MockDataset(1)
def test_overfit(self):
num_batches = 3
batch_size = 10
sampler = SceneBatchSampler(
self.dataset_overfit,
batch_size=batch_size,
num_batches=num_batches,
images_per_seq_options=[10], # will try to sample batch_size anyway
)
self.assertEqual(len(sampler), num_batches)
it = iter(sampler)
for _ in range(num_batches):
batch = next(it)
self.assertIsNotNone(batch)
self.assertEqual(len(batch), batch_size) # true for our examples
self.assertTrue(all(idx // 10 == 0 for idx in batch))
with self.assertRaises(StopIteration):
batch = next(it)
def test_multiseq(self):
for ips_options in [[10], [2], [3], [2, 3, 4]]:
for sample_consecutive_frames in [True, False]:
for consecutive_frames_max_gap in [0, 1, 3]:
self._test_multiseq_flavour(
ips_options,
sample_consecutive_frames,
consecutive_frames_max_gap,
)
def test_multiseq_gaps(self):
num_batches = 16
batch_size = 10
dataset_multiseq = MockDataset(5, max_frame_gap=3)
for ips_options in [[10], [2], [3], [2, 3, 4]]:
debug_info = f" Images per sequence: {ips_options}."
sampler = SceneBatchSampler(
dataset_multiseq,
batch_size=batch_size,
num_batches=num_batches,
images_per_seq_options=ips_options,
sample_consecutive_frames=True,
consecutive_frames_max_gap=1,
)
self.assertEqual(len(sampler), num_batches, msg=debug_info)
it = iter(sampler)
for _ in range(num_batches):
batch = next(it)
self.assertIsNotNone(batch, "batch is None in" + debug_info)
if max(ips_options) > 5:
# true for our examples
self.assertEqual(len(batch), 5, msg=debug_info)
else:
# true for our examples
self.assertEqual(len(batch), batch_size, msg=debug_info)
self._check_frames_are_consecutive(
batch, dataset_multiseq.frame_annots, debug_info
)
def _test_multiseq_flavour(
self,
ips_options,
sample_consecutive_frames,
consecutive_frames_max_gap,
num_batches=16,
batch_size=10,
):
debug_info = (
f" Images per sequence: {ips_options}, "
f"sample_consecutive_frames: {sample_consecutive_frames}, "
f"consecutive_frames_max_gap: {consecutive_frames_max_gap}, "
)
# in this test, either consecutive_frames_max_gap == max_frame_gap,
# or consecutive_frames_max_gap == 0, so segments consist of full sequences
frame_gap = consecutive_frames_max_gap if consecutive_frames_max_gap > 0 else 3
dataset_multiseq = MockDataset(5, max_frame_gap=frame_gap)
sampler = SceneBatchSampler(
dataset_multiseq,
batch_size=batch_size,
num_batches=num_batches,
images_per_seq_options=ips_options,
sample_consecutive_frames=sample_consecutive_frames,
consecutive_frames_max_gap=consecutive_frames_max_gap,
)
self.assertEqual(len(sampler), num_batches, msg=debug_info)
it = iter(sampler)
typical_counts = set()
for _ in range(num_batches):
batch = next(it)
self.assertIsNotNone(batch, "batch is None in" + debug_info)
# true for our examples
self.assertEqual(len(batch), batch_size, msg=debug_info)
# find distribution over sequences
counts = _count_by_quotient(batch, 10)
freqs = _count_by_quotient(counts.values(), 1)
self.assertLessEqual(
len(freqs),
2,
msg="We should have maximum of 2 different "
"frequences of sequences in the batch." + debug_info,
)
if len(freqs) == 2:
most_seq_count = max(*freqs.keys())
last_seq = min(*freqs.keys())
self.assertEqual(
freqs[last_seq],
1,
msg="Only one odd sequence allowed." + debug_info,
)
else:
self.assertEqual(len(freqs), 1)
most_seq_count = next(iter(freqs))
self.assertIn(most_seq_count, ips_options)
typical_counts.add(most_seq_count)
if sample_consecutive_frames:
self._check_frames_are_consecutive(
batch,
dataset_multiseq.frame_annots,
debug_info,
max_gap=consecutive_frames_max_gap,
)
self.assertTrue(
all(i in typical_counts for i in ips_options),
"Some of the frequency options did not occur among "
f"the {num_batches} batches (could be just bad luck)." + debug_info,
)
with self.assertRaises(StopIteration):
batch = next(it)
def _check_frames_are_consecutive(self, batch, annots, debug_info, max_gap=1):
# make sure that sampled frames are consecutive
for i in range(len(batch) - 1):
curr_idx, next_idx = batch[i : i + 2]
if curr_idx // 10 == next_idx // 10: # same sequence
if max_gap > 0:
curr_idx, next_idx = [
annots[idx]["frame_annotation"].frame_number
for idx in (curr_idx, next_idx)
]
gap = max_gap
else:
gap = 1 # we'll check that raw dataset indices are consecutive
self.assertLessEqual(next_idx - curr_idx, gap, msg=debug_info)
def _count_by_quotient(indices, divisor):
counter = defaultdict(int)
for i in indices:
counter[i // divisor] += 1
return counter
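A quick worked example of the helper above (indices are dataset indices, so index // 10 identifies the sequence):
counts = _count_by_quotient([0, 3, 12, 14], 10)  # {0: 2, 1: 2}
freqs = _count_by_quotient(counts.values(), 1)   # {2: 2}: both sequences occur twice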

View File

@@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from math import pi
import torch
from pytorch3d.implicitron.tools.circle_fitting import (
_signed_area,
fit_circle_in_2d,
fit_circle_in_3d,
)
from pytorch3d.transforms import random_rotation
if os.environ.get("FB_TEST", False):
from common_testing import TestCaseMixin
else:
from tests.common_testing import TestCaseMixin
class TestCircleFitting(TestCaseMixin, unittest.TestCase):
def setUp(self):
torch.manual_seed(42)
def _assertParallel(self, a, b, **kwargs):
"""
Given a and b of shape (..., 3) each containing 3D vectors,
assert that corresponding vectors are parallel. A flipped sign is ok.
"""
self.assertClose(torch.cross(a, b, dim=-1), torch.zeros_like(a), **kwargs)
def test_simple_3d(self):
device = torch.device("cuda:0")
for _ in range(7):
radius = 10 * torch.rand(1, device=device)[0]
center = 10 * torch.rand(3, device=device)
rot = random_rotation(device=device)
offset = torch.rand(3, device=device)
up = torch.rand(3, device=device)
self._simple_3d_test(radius, center, rot, offset, up)
def _simple_3d_test(self, radius, center, rot, offset, up):
# angles are increasing so the points move in a well-defined direction.
angles = torch.cumsum(torch.rand(17, device=rot.device), dim=0)
many = torch.stack(
[torch.cos(angles), torch.sin(angles), torch.zeros_like(angles)], dim=1
)
source_points = (many * radius) @ rot + center[None]
# case with no generation
result = fit_circle_in_3d(source_points)
self.assertClose(result.radius, radius)
self.assertClose(result.center, center)
self._assertParallel(result.normal, rot[2], atol=1e-5)
self.assertEqual(result.generated_points.shape, (0, 3))
# Generate 5 points around the circle
n_new_points = 5
result2 = fit_circle_in_3d(source_points, n_points=n_new_points)
self.assertClose(result2.radius, radius)
self.assertClose(result2.center, center)
self.assertClose(result2.normal, result.normal)
self.assertEqual(result2.generated_points.shape, (5, 3))
observed_points = result2.generated_points
self.assertClose(observed_points[0], observed_points[4], atol=1e-4)
self.assertClose(observed_points[0], source_points[0], atol=1e-5)
observed_normal = torch.cross(
observed_points[0] - observed_points[2],
observed_points[1] - observed_points[3],
dim=-1,
)
self._assertParallel(observed_normal, result.normal, atol=1e-4)
diameters = observed_points[:2] - observed_points[2:4]
self.assertClose(
torch.norm(diameters, dim=1), diameters.new_full((2,), 2 * radius)
)
# Regenerate the input points
result3 = fit_circle_in_3d(source_points, angles=angles - angles[0])
self.assertClose(result3.radius, radius)
self.assertClose(result3.center, center)
self.assertClose(result3.normal, result.normal)
self.assertClose(result3.generated_points, source_points, atol=1e-5)
# Test with offset
result4 = fit_circle_in_3d(
source_points, angles=angles - angles[0], offset=offset, up=up
)
self.assertClose(result4.radius, radius)
self.assertClose(result4.center, center)
self.assertClose(result4.normal, result.normal)
observed_offsets = result4.generated_points - source_points
# observed_offset is constant
self.assertClose(
observed_offsets.min(0).values, observed_offsets.max(0).values, atol=1e-5
)
# observed_offset has the right length
self.assertClose(observed_offsets[0].norm(), offset.norm())
self.assertClose(result.normal.norm(), torch.ones(()))
# component of observed_offset along normal
component = torch.dot(observed_offsets[0], result.normal)
self.assertClose(component.abs(), offset[2].abs(), atol=1e-5)
agree_normal = torch.dot(result.normal, up) > 0
agree_signs = component * offset[2] > 0
self.assertEqual(agree_normal, agree_signs)
def test_simple_2d(self):
radius = 7.0
center = torch.tensor([9, 2.5])
angles = torch.cumsum(torch.rand(17), dim=0)
many = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
source_points = (many * radius) + center[None]
result = fit_circle_in_2d(source_points)
self.assertClose(result.radius, torch.tensor(radius))
self.assertClose(result.center, center)
self.assertEqual(result.generated_points.shape, (0, 2))
# Generate 5 points around the circle
n_new_points = 5
result2 = fit_circle_in_2d(source_points, n_points=n_new_points)
self.assertClose(result2.radius, torch.tensor(radius))
self.assertClose(result2.center, center)
self.assertEqual(result2.generated_points.shape, (5, 2))
observed_points = result2.generated_points
self.assertClose(observed_points[0], observed_points[4])
self.assertClose(observed_points[0], source_points[0], atol=1e-5)
diameters = observed_points[:2] - observed_points[2:4]
self.assertClose(torch.norm(diameters, dim=1), torch.full((2,), 2 * radius))
# Regenerate the input points
result3 = fit_circle_in_2d(source_points, angles=angles - angles[0])
self.assertClose(result3.radius, torch.tensor(radius))
self.assertClose(result3.center, center)
self.assertClose(result3.generated_points, source_points, atol=1e-5)
def test_minimum_inputs(self):
fit_circle_in_3d(torch.rand(3, 3), n_points=10)
with self.assertRaisesRegex(
ValueError, "2 points are not enough to determine a circle"
):
fit_circle_in_3d(torch.rand(2, 3))
def test_signed_area(self):
n_points = 1001
angles = torch.linspace(0, 2 * pi, n_points)
radius = 0.85
center = torch.rand(2)
circle = center + radius * torch.stack(
[torch.cos(angles), torch.sin(angles)], dim=1
)
circle_area = torch.tensor(pi * radius * radius)
self.assertClose(_signed_area(circle), circle_area)
# clockwise is negative
self.assertClose(_signed_area(circle.flip(0)), -circle_area)
# Semicircles
self.assertClose(_signed_area(circle[: (n_points + 1) // 2]), circle_area / 2)
self.assertClose(_signed_area(circle[n_points // 2 :]), circle_area / 2)
# A straight line bounds no area
self.assertClose(_signed_area(torch.rand(2, 2)), torch.tensor(0.0))
# Letter 'L' written anticlockwise.
L_shape = [[0, 1], [0, 0], [1, 0]]
# Triangle area is 0.5 * b * h.
self.assertClose(_signed_area(torch.tensor(L_shape)), torch.tensor(0.5))
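_signed_area is presumably the shoelace formula applied to a polyline that is implicitly closed back to its start; a minimal reference sketch under that assumption, independent of the library implementation:
import torch
def signed_area_reference(points: torch.Tensor) -> torch.Tensor:
    # points: (N, 2). Half the sum of cross products of consecutive edge
    # vectors taken relative to the first vertex; positive for
    # anticlockwise winding (the closing edge contributes zero).
    v = points - points[:1]
    cross = v[:-1, 0] * v[1:, 1] - v[:-1, 1] * v[1:, 0]
    return 0.5 * cross.sum()
On the anticlockwise 'L' above this gives 0.5, matching the triangle-area check.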

View File

@@ -0,0 +1,610 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import textwrap
import unittest
from dataclasses import dataclass, field, is_dataclass
from enum import Enum
from typing import List, Optional, Tuple
from omegaconf import DictConfig, ListConfig, OmegaConf, ValidationError
from pytorch3d.implicitron.tools.config import (
Configurable,
ReplaceableBase,
_is_actually_dataclass,
_Registry,
expand_args_fields,
get_default_args,
get_default_args_field,
registry,
remove_unused_components,
run_auto_creation,
)
@dataclass
class Animal(ReplaceableBase):
pass
class Fruit(ReplaceableBase):
pass
@registry.register
class Banana(Fruit):
pips: int
spots: int
bananame: str
@registry.register
class Pear(Fruit):
n_pips: int = 13
class Pineapple(Fruit):
pass
@registry.register
class Orange(Fruit):
pass
@registry.register
class Kiwi(Fruit):
pass
@registry.register
class LargePear(Pear):
pass
class MainTest(Configurable):
the_fruit: Fruit
n_ids: int
n_reps: int = 8
the_second_fruit: Fruit
def create_the_second_fruit(self):
expand_args_fields(Pineapple)
self.the_second_fruit = Pineapple()
def __post_init__(self):
run_auto_creation(self)
class TestConfig(unittest.TestCase):
def test_is_actually_dataclass(self):
@dataclass
class A:
pass
self.assertTrue(_is_actually_dataclass(A))
self.assertTrue(is_dataclass(A))
class B(A):
a: int
self.assertFalse(_is_actually_dataclass(B))
self.assertTrue(is_dataclass(B))
def test_simple_replacement(self):
struct = get_default_args(MainTest)
struct.n_ids = 9780
struct.the_fruit_Pear_args.n_pips = 3
struct.the_fruit_class_type = "Pear"
struct.the_second_fruit_class_type = "Pear"
main = MainTest(**struct)
self.assertIsInstance(main.the_fruit, Pear)
self.assertEqual(main.n_reps, 8)
self.assertEqual(main.n_ids, 9780)
self.assertEqual(main.the_fruit.n_pips, 3)
self.assertIsInstance(main.the_second_fruit, Pineapple)
struct2 = get_default_args(MainTest)
self.assertEqual(struct2.the_fruit_Pear_args.n_pips, 13)
self.assertEqual(
MainTest._creation_functions,
("create_the_fruit", "create_the_second_fruit"),
)
def test_detect_bases(self):
# testing the _base_class_from_class function
self.assertIsNone(_Registry._base_class_from_class(ReplaceableBase))
self.assertIsNone(_Registry._base_class_from_class(MainTest))
self.assertIs(_Registry._base_class_from_class(Fruit), Fruit)
self.assertIs(_Registry._base_class_from_class(Pear), Fruit)
class PricklyPear(Pear):
pass
self.assertIs(_Registry._base_class_from_class(PricklyPear), Fruit)
def test_registry_entries(self):
self.assertIs(registry.get(Fruit, "Banana"), Banana)
with self.assertRaisesRegex(ValueError, "Banana has not been registered."):
registry.get(Animal, "Banana")
with self.assertRaisesRegex(ValueError, "PricklyPear has not been registered."):
registry.get(Fruit, "PricklyPear")
self.assertIs(registry.get(Pear, "Pear"), Pear)
self.assertIs(registry.get(Pear, "LargePear"), LargePear)
with self.assertRaisesRegex(ValueError, "Banana resolves to"):
registry.get(Pear, "Banana")
all_fruit = set(registry.get_all(Fruit))
self.assertIn(Banana, all_fruit)
self.assertIn(Pear, all_fruit)
self.assertIn(LargePear, all_fruit)
self.assertEqual(set(registry.get_all(Pear)), {LargePear})
@registry.register
class Apple(Fruit):
pass
@registry.register
class CrabApple(Apple):
pass
self.assertEqual(set(registry.get_all(Apple)), {CrabApple})
self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)
with self.assertRaisesRegex(ValueError, "Cannot tell what it is."):
@registry.register
class NotAFruit:
pass
def test_recursion(self):
class Shape(ReplaceableBase):
pass
@registry.register
class Triangle(Shape):
a: float = 5.0
@registry.register
class Square(Shape):
a: float = 3.0
@registry.register
class LargeShape(Shape):
inner: Shape
def __post_init__(self):
run_auto_creation(self)
class ShapeContainer(Configurable):
shape: Shape
container = ShapeContainer(**get_default_args(ShapeContainer))
# This is because ShapeContainer is missing __post_init__
with self.assertRaises(AttributeError):
container.shape
class ShapeContainer2(Configurable):
x: Shape
x_class_type: str = "LargeShape"
def __post_init__(self):
self.x_LargeShape_args.inner_class_type = "Triangle"
run_auto_creation(self)
container2_args = get_default_args(ShapeContainer2)
container2_args.x_LargeShape_args.inner_Triangle_args.a += 10
self.assertIn("inner_Square_args", container2_args.x_LargeShape_args)
# We do not perform expansion that would result in an infinite recursion,
# so this member is not present.
self.assertNotIn("inner_LargeShape_args", container2_args.x_LargeShape_args)
container2_args.x_LargeShape_args.inner_Square_args.a += 100
container2 = ShapeContainer2(**container2_args)
self.assertIsInstance(container2.x, LargeShape)
self.assertIsInstance(container2.x.inner, Triangle)
self.assertEqual(container2.x.inner.a, 15.0)
def test_simpleclass_member(self):
# Members which are not dataclasses are
# tolerated. But it would be nice to be able to
# configure them.
class Foo:
def __init__(self, a=1, b=2):
self.a, self.b = a, b
@dataclass()
class Bar:
aa: int = 9
bb: int = 9
class Container(Configurable):
bar: Bar = Bar()
# TODO make this work?
# foo: Foo = Foo()
fruit: Fruit
fruit_class_type: str = "Orange"
def __post_init__(self):
run_auto_creation(self)
self.assertEqual(get_default_args(Foo), {"a": 1, "b": 2})
container_args = get_default_args(Container)
container = Container(**container_args)
self.assertIsInstance(container.fruit, Orange)
# self.assertIsInstance(container.bar, Bar)
container_defaulted = Container()
container_defaulted.fruit_Pear_args.n_pips += 4
container_args2 = get_default_args(Container)
container = Container(**container_args2)
self.assertEqual(container.fruit_Pear_args.n_pips, 13)
def test_inheritance(self):
class FruitBowl(ReplaceableBase):
main_fruit: Fruit
main_fruit_class_type: str = "Orange"
def __post_init__(self):
raise ValueError("This doesn't get called")
class LargeFruitBowl(FruitBowl):
extra_fruit: Fruit
extra_fruit_class_type: str = "Kiwi"
def __post_init__(self):
run_auto_creation(self)
large_args = get_default_args(LargeFruitBowl)
self.assertNotIn("extra_fruit", large_args)
self.assertNotIn("main_fruit", large_args)
large = LargeFruitBowl(**large_args)
self.assertIsInstance(large.main_fruit, Orange)
self.assertIsInstance(large.extra_fruit, Kiwi)
def test_inheritance2(self):
# This is a case where a class could contain an instance
# of a subclass, which is ignored.
class Parent(ReplaceableBase):
pass
class Main(Configurable):
parent: Parent
# Note - no __post_init__
@registry.register
class Derived(Parent, Main):
pass
args = get_default_args(Main)
# Derived has been ignored in processing Main.
self.assertCountEqual(args.keys(), ["parent_class_type"])
main = Main(**args)
with self.assertRaisesRegex(ValueError, "UNDEFAULTED has not been registered."):
run_auto_creation(main)
main.parent_class_type = "Derived"
# Illustrates that a dict works fine instead of a DictConfig.
main.parent_Derived_args = {}
with self.assertRaises(AttributeError):
main.parent
run_auto_creation(main)
self.assertIsInstance(main.parent, Derived)
def test_redefine(self):
class FruitBowl(ReplaceableBase):
main_fruit: Fruit
main_fruit_class_type: str = "Grape"
def __post_init__(self):
run_auto_creation(self)
@registry.register
@dataclass
class Grape(Fruit):
large: bool = False
def get_color(self):
return "red"
def __post_init__(self):
raise ValueError("This doesn't get called")
bowl_args = get_default_args(FruitBowl)
@registry.register
@dataclass
class Grape(Fruit): # noqa: F811
large: bool = True
def get_color(self):
return "green"
with self.assertWarnsRegex(
UserWarning, "New implementation of Grape is being chosen."
):
bowl = FruitBowl(**bowl_args)
self.assertIsInstance(bowl.main_fruit, Grape)
# Redefining the same class won't change the defaults, because they
# are already encoded in bowl_args.
self.assertEqual(bowl.main_fruit.large, False)
# But the override worked.
self.assertEqual(bowl.main_fruit.get_color(), "green")
# 2. Try redefining without the dataclass modifier
# This relies on the fact that default creation processes the class
# (otherwise we would get incomprehensible error messages).
@registry.register
class Grape(Fruit): # noqa: F811
large: bool = True
with self.assertWarnsRegex(
UserWarning, "New implementation of Grape is being chosen."
):
bowl = FruitBowl(**bowl_args)
# 3. Adding a new class doesn't get picked up, because the first
# get_default_args call has frozen FruitBowl. This is intrinsic to
# the way dataclass and expand_args_fields work in-place but
# expand_args_fields is not pure - it depends on the registry.
@registry.register
class Fig(Fruit):
pass
bowl_args2 = get_default_args(FruitBowl)
self.assertIn("main_fruit_Grape_args", bowl_args2)
self.assertNotIn("main_fruit_Fig_args", bowl_args2)
# TODO Is it possible to make this work?
# bowl_args2["main_fruit_Fig_args"] = get_default_args(Fig)
# bowl_args2.main_fruit_class_type = "Fig"
# bowl2 = FruitBowl(**bowl_args2) <= unexpected argument
# Note that it is possible to use Fig if you can set
# bowl2.main_fruit_Fig_args explicitly (not in bowl_args2)
# before run_auto_creation happens. See test_inheritance2
# for an example.
def test_no_replacement(self):
# Test of Configurables without ReplaceableBase
class A(Configurable):
n: int = 9
class B(Configurable):
a: A
def __post_init__(self):
run_auto_creation(self)
class C(Configurable):
b: B
def __post_init__(self):
run_auto_creation(self)
c_args = get_default_args(C)
c = C(**c_args)
self.assertIsInstance(c.b.a, A)
self.assertEqual(c.b.a.n, 9)
def test_doc(self):
# The case in the docstring.
class A(ReplaceableBase):
k: int = 1
@registry.register
class A1(A):
m: int = 3
@registry.register
class A2(A):
n: str = "2"
class B(Configurable):
a: A
a_class_type: str = "A2"
def __post_init__(self):
run_auto_creation(self)
b_args = get_default_args(B)
self.assertNotIn("a", b_args)
b = B(**b_args)
self.assertEqual(b.a.n, "2")
def test_raw_types(self):
@dataclass
class MyDataclass:
int_field: int = 0
none_field: Optional[int] = None
float_field: float = 9.3
bool_field: bool = True
tuple_field: tuple = (3, True, "j")
class SimpleClass:
def __init__(self, tuple_member_=(3, 4)):
self.tuple_member = tuple_member_
def get_tuple(self):
return self.tuple_member
def f(*, a: int = 3, b: str = "kj"):
self.assertEqual(a, 3)
self.assertEqual(b, "kj")
class C(Configurable):
simple: DictConfig = get_default_args_field(SimpleClass)
# simple2: SimpleClass2 = SimpleClass2()
mydata: DictConfig = get_default_args_field(MyDataclass)
a_tuple: Tuple[float] = (4.0, 3.0)
f_args: DictConfig = get_default_args_field(f)
args = get_default_args(C)
c = C(**args)
self.assertCountEqual(args.keys(), ["simple", "mydata", "a_tuple", "f_args"])
mydata = MyDataclass(**c.mydata)
simple = SimpleClass(**c.simple)
# OmegaConf converts tuples to ListConfigs (which act like lists).
self.assertEqual(simple.get_tuple(), [3, 4])
self.assertTrue(isinstance(simple.get_tuple(), ListConfig))
self.assertEqual(c.a_tuple, [4.0, 3.0])
self.assertTrue(isinstance(c.a_tuple, ListConfig))
self.assertEqual(mydata.tuple_field, (3, True, "j"))
self.assertTrue(isinstance(mydata.tuple_field, ListConfig))
f(**c.f_args)
def test_irrelevant_bases(self):
class NotADataclass:
# Like torch.nn.Module, this class contains annotations
# but is not designed to be dataclass'd.
# This test ensures that such classes, when inherited from,
# are not accidentally processed by expand_args_fields.
a: int = 9
b: int
class LeftConfigured(Configurable, NotADataclass):
left: int = 1
class RightConfigured(NotADataclass, Configurable):
right: int = 2
class Outer(Configurable):
left: LeftConfigured
right: RightConfigured
def __post_init__(self):
run_auto_creation(self)
outer = Outer(**get_default_args(Outer))
self.assertEqual(outer.left.left, 1)
self.assertEqual(outer.right.right, 2)
with self.assertRaisesRegex(TypeError, "non-default argument"):
dataclass(NotADataclass)
def test_unprocessed(self):
# Behavior of Configurable classes which need processing in __new__.
class Unprocessed(Configurable):
a: int = 9
class UnprocessedReplaceable(ReplaceableBase):
a: int = 1
with self.assertWarnsRegex(UserWarning, "must be processed"):
Unprocessed()
with self.assertWarnsRegex(UserWarning, "must be processed"):
UnprocessedReplaceable()
def test_enum(self):
# Test that enum values are kept, i.e. that OmegaConf's runtime checks
# are in use.
class A(Enum):
B1 = "b1"
B2 = "b2"
class C(Configurable):
a: A = A.B1
base = get_default_args(C)
replaced = OmegaConf.merge(base, {"a": "B2"})
self.assertEqual(replaced.a, A.B2)
with self.assertRaises(ValidationError):
# You can't use a value which is not one of the
# choices, even if it is the str representation
# of one of the choices.
OmegaConf.merge(base, {"a": "b2"})
remerged = OmegaConf.merge(base, OmegaConf.create(OmegaConf.to_yaml(base)))
self.assertEqual(remerged.a, A.B1)
def test_remove_unused_components(self):
struct = get_default_args(MainTest)
struct.n_ids = 32
struct.the_fruit_class_type = "Pear"
struct.the_second_fruit_class_type = "Banana"
remove_unused_components(struct)
expected_keys = [
"n_ids",
"n_reps",
"the_fruit_Pear_args",
"the_fruit_class_type",
"the_second_fruit_Banana_args",
"the_second_fruit_class_type",
]
expected_yaml = textwrap.dedent(
"""\
n_ids: 32
n_reps: 8
the_fruit_class_type: Pear
the_fruit_Pear_args:
n_pips: 13
the_second_fruit_class_type: Banana
the_second_fruit_Banana_args:
pips: ???
spots: ???
bananame: ???
"""
)
self.assertEqual(sorted(struct.keys()), expected_keys)
# Check that struct is what we expect
expected = OmegaConf.create(expected_yaml)
self.assertEqual(struct, expected)
# Check that we get what we expect when writing to yaml.
self.assertEqual(OmegaConf.to_yaml(struct, sort_keys=False), expected_yaml)
main = MainTest(**struct)
instance_data = OmegaConf.structured(main)
remove_unused_components(instance_data)
self.assertEqual(sorted(instance_data.keys()), expected_keys)
self.assertEqual(instance_data, expected)
@dataclass(eq=False)
class MockDataclass:
field_no_default: int
field_primitive_type: int = 42
field_reference_type: List[int] = field(default_factory=lambda: [])
class MockClassWithInit: # noqa: B903
def __init__(
self,
field_no_default: int,
field_primitive_type: int = 42,
field_reference_type: List[int] = [], # noqa: B006
):
self.field_no_default = field_no_default
self.field_primitive_type = field_primitive_type
self.field_reference_type = field_reference_type
class TestRawClasses(unittest.TestCase):
def test_get_default_args(self):
for cls in [MockDataclass, MockClassWithInit]:
dataclass_defaults = get_default_args(cls)
inst = cls(field_no_default=0)
dataclass_defaults.field_no_default = 0
for name, val in dataclass_defaults.items():
self.assertTrue(hasattr(inst, name))
self.assertEqual(val, getattr(inst, name))
def test_get_default_args_readonly(self):
for cls in [MockDataclass, MockClassWithInit]:
dataclass_defaults = get_default_args(cls)
dataclass_defaults["field_reference_type"].append(13)
inst = cls(field_no_default=0)
self.assertEqual(inst.field_reference_type, [])
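The pattern these tests exercise can be summarized in one minimal sketch (mirroring test_simple_replacement and test_doc above; fresh class names to avoid clashing with the tests):
from pytorch3d.implicitron.tools.config import (
    Configurable,
    ReplaceableBase,
    get_default_args,
    registry,
    run_auto_creation,
)
class Fruit2(ReplaceableBase):
    pass
@registry.register
class Pear2(Fruit2):
    n_pips: int = 13
class Bowl(Configurable):
    fruit: Fruit2
    fruit_class_type: str = "Pear2"
    def __post_init__(self):
        run_auto_creation(self)
args = get_default_args(Bowl)   # a DictConfig containing fruit_Pear2_args etc.
args.fruit_Pear2_args.n_pips = 3
bowl = Bowl(**args)             # bowl.fruit is a Pear2 with n_pips == 3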

View File

@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
from omegaconf import OmegaConf
from pytorch3d.implicitron.models.autodecoder import Autodecoder
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
IdrFeatureField,
)
from pytorch3d.implicitron.models.implicit_function.neural_radiance_field import (
NeuralRadianceFieldImplicitFunction,
)
from pytorch3d.implicitron.models.renderer.lstm_renderer import LSTMRenderer
from pytorch3d.implicitron.models.renderer.multipass_ea import (
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.view_pooling.feature_aggregation import (
AngleWeightedIdentityFeatureAggregator,
AngleWeightedReductionFeatureAggregator,
)
from pytorch3d.implicitron.tools.config import (
get_default_args,
remove_unused_components,
)
if os.environ.get("FB_TEST", False):
from common_testing import get_tests_dir
else:
from tests.common_testing import get_tests_dir
DATA_DIR = get_tests_dir() / "implicitron/data"
DEBUG: bool = False
# Tests the use of the config system in implicitron
class TestGenericModel(unittest.TestCase):
def setUp(self):
self.maxDiff = None
def test_create_gm(self):
args = get_default_args(GenericModel)
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, MultiPassEmissionAbsorptionRenderer)
self.assertIsInstance(
gm.feature_aggregator, AngleWeightedReductionFeatureAggregator
)
self.assertIsInstance(
gm._implicit_functions[0]._fn, NeuralRadianceFieldImplicitFunction
)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertFalse(hasattr(gm, "implicit_function"))
self.assertFalse(hasattr(gm, "image_feature_extractor"))
def test_create_gm_overrides(self):
args = get_default_args(GenericModel)
args.feature_aggregator_class_type = "AngleWeightedIdentityFeatureAggregator"
args.implicit_function_class_type = "IdrFeatureField"
args.renderer_class_type = "LSTMRenderer"
gm = GenericModel(**args)
self.assertIsInstance(gm.renderer, LSTMRenderer)
self.assertIsInstance(
gm.feature_aggregator, AngleWeightedIdentityFeatureAggregator
)
self.assertIsInstance(gm._implicit_functions[0]._fn, IdrFeatureField)
self.assertIsInstance(gm.sequence_autodecoder, Autodecoder)
self.assertFalse(hasattr(gm, "implicit_function"))
instance_args = OmegaConf.structured(gm)
remove_unused_components(instance_args)
yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
if DEBUG:
(DATA_DIR / "overrides.yaml_").write_text(yaml)
self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())

View File

@@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import copy
import os
import unittest
import torch
import torchvision
from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
from pytorch3d.implicitron.dataset.visualize import get_implicitron_sequence_pointcloud
from pytorch3d.implicitron.tools.point_cloud_utils import render_point_cloud_pytorch3d
from pytorch3d.vis.plotly_vis import plot_scene
from visdom import Visdom
if os.environ.get("FB_TEST", False):
from .common_resources import get_skateboard_data
else:
from common_resources import get_skateboard_data
class TestDatasetVisualize(unittest.TestCase):
def setUp(self):
if os.environ.get("INSIDE_RE_WORKER") is not None:
raise unittest.SkipTest("Visdom not available")
category = "skateboard"
stack = contextlib.ExitStack()
dataset_root, path_manager = stack.enter_context(get_skateboard_data())
self.addCleanup(stack.close)
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
self.image_size = 256
self.datasets = {
"simple": ImplicitronDataset(
frame_annotations_file=frame_file,
sequence_annotations_file=sequence_file,
dataset_root=dataset_root,
image_height=self.image_size,
image_width=self.image_size,
box_crop=True,
load_point_clouds=True,
path_manager=path_manager,
),
"nonsquare": ImplicitronDataset(
frame_annotations_file=frame_file,
sequence_annotations_file=sequence_file,
dataset_root=dataset_root,
image_height=self.image_size,
image_width=self.image_size // 2,
box_crop=True,
load_point_clouds=True,
path_manager=path_manager,
),
"nocrop": ImplicitronDataset(
frame_annotations_file=frame_file,
sequence_annotations_file=sequence_file,
dataset_root=dataset_root,
image_height=self.image_size,
image_width=self.image_size // 2,
box_crop=False,
load_point_clouds=True,
path_manager=path_manager,
),
}
self.datasets.update(
{
k + "_newndc": _change_annotations_to_new_ndc(dataset)
for k, dataset in self.datasets.items()
}
)
self.visdom = Visdom()
if not self.visdom.check_connection():
print("Visdom server not running! Disabling visdom visualizations.")
self.visdom = None
def _render_one_pointcloud(self, point_cloud, cameras, render_size):
(_image_render, _, _) = render_point_cloud_pytorch3d(
cameras,
point_cloud,
render_size=render_size,
point_radius=1e-2,
topk=10,
bg_color=0.0,
)
return _image_render.clamp(0.0, 1.0)
def test_one(self):
"""Test dataset visualization."""
for max_frames in (16, -1):
for load_dataset_point_cloud in (True, False):
for dataset_key in self.datasets:
self._gen_and_render_pointcloud(
max_frames, load_dataset_point_cloud, dataset_key
)
def _gen_and_render_pointcloud(
self, max_frames, load_dataset_point_cloud, dataset_key
):
dataset = self.datasets[dataset_key]
# load the point cloud of the first sequence
sequence_show = list(dataset.seq_annots.keys())[0]
device = torch.device("cuda:0")
point_cloud, sequence_frame_data = get_implicitron_sequence_pointcloud(
dataset,
sequence_name=sequence_show,
mask_points=True,
max_frames=max_frames,
num_workers=10,
load_dataset_point_cloud=load_dataset_point_cloud,
)
# render on gpu
point_cloud = point_cloud.to(device)
cameras = sequence_frame_data.camera.to(device)
# render the point_cloud from the viewpoint of loaded cameras
images_render = torch.cat(
[
self._render_one_pointcloud(
point_cloud,
cameras[frame_i],
(
dataset.image_height,
dataset.image_width,
),
)
for frame_i in range(len(cameras))
]
).cpu()
images_gt_and_render = torch.cat(
[sequence_frame_data.image_rgb, images_render], dim=3
)
imfile = os.path.join(
os.path.split(os.path.abspath(__file__))[0],
"test_dataset_visualize"
+ f"_max_frames={max_frames}"
+ f"_load_pcl={load_dataset_point_cloud}.png",
)
print(f"Exporting image {imfile}.")
torchvision.utils.save_image(images_gt_and_render, imfile, nrow=2)
if self.visdom is not None:
test_name = f"{max_frames}_{load_dataset_point_cloud}_{dataset_key}"
self.visdom.images(
images_gt_and_render,
env="test_dataset_visualize",
win=f"pcl_renders_{test_name}",
opts={"title": f"pcl_renders_{test_name}"},
)
plotlyplot = plot_scene(
{
"scene_batch": {
"cameras": cameras,
"point_cloud": point_cloud,
}
},
camera_scale=1.0,
pointcloud_max_points=10000,
pointcloud_marker_size=1.0,
)
self.visdom.plotlyplot(
plotlyplot,
env="test_dataset_visualize",
win=f"pcl_{test_name}",
)
def _change_annotations_to_new_ndc(dataset):
dataset = copy.deepcopy(dataset)
for frame in dataset.frame_annots:
vp = frame["frame_annotation"].viewpoint
vp.intrinsics_format = "ndc_isotropic"
# this assumes the focal lengths on x and y are equal (ok for a test)
max_flength = max(vp.focal_length)
vp.principal_point = (
vp.principal_point[0] * max_flength / vp.focal_length[0],
vp.principal_point[1] * max_flength / vp.focal_length[1],
)
vp.focal_length = (
max_flength,
max_flength,
)
return dataset
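A worked instance of the rescaling in _change_annotations_to_new_ndc (values hypothetical):
focal_length = (2.0, 1.0)
principal_point = (0.1, 0.2)
max_flength = max(focal_length)  # 2.0
new_principal_point = (0.1 * 2.0 / 2.0, 0.2 * 2.0 / 1.0)  # (0.1, 0.4)
new_focal_length = (max_flength, max_flength)  # (2.0, 2.0)
As the comment notes, this is only exact when the two focal lengths are equal, which holds for the test data.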

View File

@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
import torch
from pytorch3d.implicitron.tools.eval_video_trajectory import (
generate_eval_video_cameras,
)
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
from pytorch3d.transforms import axis_angle_to_matrix
if os.environ.get("FB_TEST", False):
from common_testing import TestCaseMixin
else:
from tests.common_testing import TestCaseMixin
class TestEvalCameras(TestCaseMixin, unittest.TestCase):
def setUp(self):
torch.manual_seed(42)
def test_circular(self):
n_train_cameras = 10
n_test_cameras = 100
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
amplitude = 0.01
R_jiggled = torch.bmm(
R, axis_angle_to_matrix(torch.rand(n_train_cameras, 3) * amplitude)
)
cameras_train = PerspectiveCameras(R=R_jiggled, T=T)
cameras_test = generate_eval_video_cameras(
cameras_train, trajectory_type="circular_lsq_fit", trajectory_scale=1.0
)
positions_test = cameras_test.get_camera_center()
center = positions_test.mean(0)
self.assertClose(center, torch.zeros(3), atol=0.1)
self.assertClose(
(positions_test - center).norm(dim=[1]),
torch.ones(n_test_cameras),
atol=0.1,
)

View File

@@ -0,0 +1,290 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import contextlib
import copy
import dataclasses
import math
import os
import unittest
import lpips
import torch
from pytorch3d.implicitron.dataset.implicitron_dataset import (
FrameData,
ImplicitronDataset,
)
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
from pytorch3d.implicitron.models.model_dbir import ModelDBIR
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
if os.environ.get("FB_TEST", False):
from .common_resources import get_skateboard_data, provide_lpips_vgg
else:
from common_resources import get_skateboard_data, provide_lpips_vgg
class TestEvaluation(unittest.TestCase):
def setUp(self):
# initialize evaluation dataset/dataloader
torch.manual_seed(42)
stack = contextlib.ExitStack()
dataset_root, path_manager = stack.enter_context(get_skateboard_data())
self.addCleanup(stack.close)
category = "skateboard"
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
self.image_size = 256
self.dataset = ImplicitronDataset(
frame_annotations_file=frame_file,
sequence_annotations_file=sequence_file,
dataset_root=dataset_root,
image_height=self.image_size,
image_width=self.image_size,
box_crop=True,
path_manager=path_manager,
)
self.bg_color = 0.0
# init the lpips model for eval
provide_lpips_vgg()
self.lpips_model = lpips.LPIPS(net="vgg")
def test_eval_depth(self):
"""
Check that eval_depth correctly masks errors and that, for get_best_scale=True,
the error with scaled prediction equals the error without scaling the
predicted depth. Finally, test that the error values are as expected
for prediction and gt differing by a constant offset.
"""
gt = (torch.randn(10, 1, 300, 400, device="cuda") * 5.0).clamp(0.0)
mask = (torch.rand_like(gt) > 0.5).type_as(gt)
for diff in 10 ** torch.linspace(-5, 0, 6):
for crop in (0, 5):
pred = gt + (torch.rand_like(gt) - 0.5) * 2 * diff
# scaled prediction test
mse_depth, abs_depth = eval_depth(
pred,
gt,
crop=crop,
mask=mask,
get_best_scale=True,
)
mse_depth_scale, abs_depth_scale = eval_depth(
pred * 10.0,
gt,
crop=crop,
mask=mask,
get_best_scale=True,
)
self.assertAlmostEqual(
float(mse_depth.sum()), float(mse_depth_scale.sum()), delta=1e-4
)
self.assertAlmostEqual(
float(abs_depth.sum()), float(abs_depth_scale.sum()), delta=1e-4
)
# error masking test
pred_masked_err = gt + (torch.rand_like(gt) + diff) * (1 - mask)
mse_depth_masked, abs_depth_masked = eval_depth(
pred_masked_err,
gt,
crop=crop,
mask=mask,
get_best_scale=True,
)
self.assertAlmostEqual(
float(mse_depth_masked.sum()), float(0.0), delta=1e-4
)
self.assertAlmostEqual(
float(abs_depth_masked.sum()), float(0.0), delta=1e-4
)
mse_depth_unmasked, abs_depth_unmasked = eval_depth(
pred_masked_err,
gt,
crop=crop,
mask=1 - mask,
get_best_scale=True,
)
self.assertGreater(
float(mse_depth_unmasked.sum()),
float(diff ** 2),
)
self.assertGreater(
float(abs_depth_unmasked.sum()),
float(diff),
)
# tests with constant error
pred_fix_diff = gt + diff * mask
for _mask_gt in (mask, None):
mse_depth_fix_diff, abs_depth_fix_diff = eval_depth(
pred_fix_diff,
gt,
crop=crop,
mask=_mask_gt,
get_best_scale=False,
)
if _mask_gt is not None:
expected_err_abs = diff
expected_err_mse = diff ** 2
else:
err_mask = (gt > 0.0).float() * mask
if crop > 0:
err_mask = err_mask[:, :, crop:-crop, crop:-crop]
gt_cropped = gt[:, :, crop:-crop, crop:-crop]
else:
gt_cropped = gt
gt_mass = (gt_cropped > 0.0).float().sum(dim=(1, 2, 3))
expected_err_abs = (
diff * err_mask.sum(dim=(1, 2, 3)) / (gt_mass)
)
expected_err_mse = diff * expected_err_abs
self.assertTrue(
torch.allclose(
abs_depth_fix_diff,
expected_err_abs * torch.ones_like(abs_depth_fix_diff),
atol=1e-4,
)
)
self.assertTrue(
torch.allclose(
mse_depth_fix_diff,
expected_err_mse * torch.ones_like(mse_depth_fix_diff),
atol=1e-4,
)
)
def test_psnr(self):
"""
Compare against OpenCV and check that the PSNR is above
the minimum possible value.
"""
import cv2
im1 = torch.rand(100, 3, 256, 256).cuda()
im1_uint8 = (im1 * 255).to(torch.uint8)
im1_rounded = im1_uint8.float() / 255
for max_diff in 10 ** torch.linspace(-5, 0, 6):
im2 = im1 + (torch.rand_like(im1) - 0.5) * 2 * max_diff
im2 = im2.clamp(0.0, 1.0)
im2_uint8 = (im2 * 255).to(torch.uint8)
im2_rounded = im2_uint8.float() / 255
# check that our psnr matches the output of opencv
psnr = calc_psnr(im1_rounded, im2_rounded)
# some versions of cv2 can only take uint8 input
psnr_cv2 = cv2.PSNR(
im1_uint8.cpu().numpy(),
im2_uint8.cpu().numpy(),
)
self.assertAlmostEqual(float(psnr), float(psnr_cv2), delta=1e-4)
# check that all PSNRs are bigger than the minimum possible PSNR
max_mse = max_diff ** 2
min_psnr = 10 * math.log10(1.0 / max_mse)
for _im1, _im2 in zip(im1, im2):
_psnr = calc_psnr(_im1, _im2)
self.assertGreaterEqual(float(_psnr) + 1e-6, min_psnr)
def _one_sequence_test(
self,
seq_dataset,
n_batches=2,
min_batch_size=5,
max_batch_size=10,
):
# form a list of random batches
batch_indices = []
for _ in range(n_batches):
batch_size = torch.randint(
low=min_batch_size, high=max_batch_size, size=(1,)
)
batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
loader = torch.utils.data.DataLoader(
seq_dataset,
# batch_size=1,
shuffle=False,
batch_sampler=batch_indices,
collate_fn=FrameData.collate,
)
model = ModelDBIR(image_size=self.image_size, bg_color=self.bg_color)
model.cuda()
self.lpips_model.cuda()
for frame_data in loader:
self.assertIsNone(frame_data.frame_type)
self.assertIsNotNone(frame_data.image_rgb)
# override the frame_type
frame_data.frame_type = [
"train_unseen",
*(["train_known"] * (len(frame_data.image_rgb) - 1)),
]
# move frame_data to gpu
frame_data = dataclass_to_cuda_(frame_data)
preds = model(**dataclasses.asdict(frame_data))
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
eval_result = eval_batch(
frame_data,
nvs_prediction,
bg_color=self.bg_color,
lpips_model=self.lpips_model,
)
# Make a terribly bad NVS prediction and check that this is worse
# than the DBIR prediction.
nvs_prediction_bad = copy.deepcopy(preds["nvs_prediction"])
nvs_prediction_bad.depth_render += (
torch.randn_like(nvs_prediction.depth_render) * 100.0
)
nvs_prediction_bad.image_render += (
torch.randn_like(nvs_prediction.image_render) * 100.0
)
nvs_prediction_bad.mask_render = (
torch.randn_like(nvs_prediction.mask_render) > 0.0
).float()
eval_result_bad = eval_batch(
frame_data,
nvs_prediction_bad,
bg_color=self.bg_color,
lpips_model=self.lpips_model,
)
lower_better = {
"psnr": False,
"psnr_fg": False,
"depth_abs_fg": True,
"iou": False,
"rgb_l1": True,
"rgb_l1_fg": True,
}
for metric in lower_better.keys():
m_better = eval_result[metric]
m_worse = eval_result_bad[metric]
if m_better != m_better or m_worse != m_worse:
continue # metric is missing, i.e. NaN
_assert = (
self.assertLessEqual
if lower_better[metric]
else self.assertGreaterEqual
)
_assert(m_better, m_worse)
def test_full_eval(self, n_sequences=5):
"""Test evaluation."""
for _, idx in list(self.dataset.seq_to_idx.items())[:n_sequences]:
seq_dataset = torch.utils.data.Subset(self.dataset, idx)
self._one_sequence_test(seq_dataset)
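The bound in test_psnr follows from PSNR = 10 * log10(MAX^2 / MSE) with MAX = 1: if per-pixel errors are at most max_diff, then MSE <= max_diff ** 2, so PSNR >= 10 * log10(1 / max_mse). A minimal reference implementation consistent with that bound (calc_psnr itself may differ in details):
import torch
def psnr_reference(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Peak signal-to-noise ratio for images with values in [0, 1].
    mse = ((x - y) ** 2).mean()
    return 10.0 * torch.log10(1.0 / mse)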

View File

@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import torch
from pytorch3d.implicitron.models.base import GenericModel
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields
from pytorch3d.renderer.cameras import PerspectiveCameras, look_at_view_transform
class TestGenericModel(unittest.TestCase):
def test_gm(self):
# Simple test of a forward pass of the default GenericModel.
device = torch.device("cuda:1")
expand_args_fields(GenericModel)
model = GenericModel()
model.to(device)
n_train_cameras = 2
R, T = look_at_view_transform(azim=torch.rand(n_train_cameras) * 360)
cameras = PerspectiveCameras(R=R, T=T, device=device)
# TODO: make these default to None?
defaulted_args = {
"fg_probability": None,
"depth_map": None,
"mask_crop": None,
"sequence_name": None,
}
with self.assertWarnsRegex(UserWarning, "No main objective found"):
model(
camera=cameras,
evaluation_mode=EvaluationMode.TRAINING,
**defaulted_args,
image_rgb=None,
)
target_image_rgb = torch.rand(
(n_train_cameras, 3, model.render_image_height, model.render_image_width),
device=device,
)
train_preds = model(
camera=cameras,
evaluation_mode=EvaluationMode.TRAINING,
image_rgb=target_image_rgb,
**defaulted_args,
)
self.assertGreater(train_preds["objective"].item(), 0)
model.eval()
with torch.no_grad():
# TODO: perhaps this warning should be skipped in eval mode?
with self.assertWarnsRegex(UserWarning, "No main objective found"):
eval_preds = model(
camera=cameras[0],
**defaulted_args,
image_rgb=None,
)
self.assertEqual(
eval_preds["images_render"].shape,
(1, 3, model.render_image_height, model.render_image_width),
)

View File

@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
import torch
from pytorch3d.implicitron.models.renderer.ray_point_refiner import RayPointRefiner
from pytorch3d.renderer import RayBundle
if os.environ.get("FB_TEST", False):
from common_testing import TestCaseMixin
else:
from tests.common_testing import TestCaseMixin
class TestRayPointRefiner(TestCaseMixin, unittest.TestCase):
def test_simple(self):
length = 15
n_pts_per_ray = 10
for add_input_samples in [False, True]:
ray_point_refiner = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=False,
add_input_samples=add_input_samples,
)
lengths = torch.arange(length, dtype=torch.float32).expand(3, 25, length)
bundle = RayBundle(lengths=lengths, origins=None, directions=None, xys=None)
weights = torch.ones(3, 25, length)
refined = ray_point_refiner(bundle, weights)
self.assertIsNone(refined.directions)
self.assertIsNone(refined.origins)
self.assertIsNone(refined.xys)
expected = torch.linspace(0.5, length - 1.5, n_pts_per_ray)
expected = expected.expand(3, 25, n_pts_per_ray)
if add_input_samples:
full_expected = torch.cat((lengths, expected), dim=-1).sort()[0]
else:
full_expected = expected
self.assertClose(refined.lengths, full_expected)
ray_point_refiner_random = RayPointRefiner(
n_pts_per_ray=n_pts_per_ray,
random_sampling=True,
add_input_samples=add_input_samples,
)
refined_random = ray_point_refiner_random(bundle, weights)
lengths_random = refined_random.lengths
self.assertEqual(lengths_random.shape, full_expected.shape)
if not add_input_samples:
self.assertGreater(lengths_random.min().item(), 0.5)
self.assertLess(lengths_random.max().item(), length - 1.5)
# Check sorted
self.assertTrue(
(lengths_random[..., 1:] - lengths_random[..., :-1] > 0).all()
)

View File

@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import os
import unittest
import torch
from pytorch3d.implicitron.models.implicit_function.scene_representation_networks import (
SRNHyperNetImplicitFunction,
SRNImplicitFunction,
SRNPixelGenerator,
)
from pytorch3d.implicitron.models.renderer.base import ImplicitFunctionWrapper
from pytorch3d.implicitron.tools.config import get_default_args
from pytorch3d.renderer import RayBundle
if os.environ.get("FB_TEST", False):
from common_testing import TestCaseMixin
else:
from tests.common_testing import TestCaseMixin
_BATCH_SIZE: int = 3
_N_RAYS: int = 100
_N_POINTS_ON_RAY: int = 10
class TestSRN(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
torch.manual_seed(42)
get_default_args(SRNHyperNetImplicitFunction)
get_default_args(SRNImplicitFunction)
def test_pixel_generator(self):
SRNPixelGenerator()
def _get_bundle(self, *, device) -> RayBundle:
origins = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
directions = torch.rand(_BATCH_SIZE, _N_RAYS, 3, device=device)
lengths = torch.rand(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, device=device)
bundle = RayBundle(
lengths=lengths, origins=origins, directions=directions, xys=None
)
return bundle
def test_srn_implicit_function(self):
implicit_function = SRNImplicitFunction()
device = torch.device("cpu")
bundle = self._get_bundle(device=device)
rays_densities, rays_colors = implicit_function(bundle)
out_features = implicit_function.raymarch_function.out_features
self.assertEqual(
rays_densities.shape,
(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
)
self.assertIsNone(rays_colors)
def test_srn_hypernet_implicit_function(self):
# TODO investigate: If latent_dim_hypernet=0, why does this crash and dump core?
latent_dim_hypernet = 39
hypernet_args = {"latent_dim_hypernet": latent_dim_hypernet}
device = torch.device("cuda:0")
implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hypernet_args)
implicit_function.to(device)
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
bundle = self._get_bundle(device=device)
rays_densities, rays_colors = implicit_function(bundle, global_code=global_code)
out_features = implicit_function.hypernet.out_features
self.assertEqual(
rays_densities.shape,
(_BATCH_SIZE, _N_RAYS, _N_POINTS_ON_RAY, out_features),
)
self.assertIsNone(rays_colors)

    def test_srn_hypernet_implicit_function_optim(self):
        # Test an optimization loop, which requires that the cache is
        # properly cleared in new_args_bound.
latent_dim_hypernet = 39
hyper_args = {"latent_dim_hypernet": latent_dim_hypernet}
device = torch.device("cuda:0")
global_code = torch.rand(_BATCH_SIZE, latent_dim_hypernet, device=device)
bundle = self._get_bundle(device=device)
implicit_function = SRNHyperNetImplicitFunction(hypernet_args=hyper_args)
implicit_function2 = SRNHyperNetImplicitFunction(hypernet_args=hyper_args)
implicit_function.to(device)
implicit_function2.to(device)
wrapper = ImplicitFunctionWrapper(implicit_function)
optimizer = torch.optim.Adam(implicit_function.parameters())
for _step in range(3):
optimizer.zero_grad()
wrapper.bind_args(global_code=global_code)
rays_densities, _rays_colors = wrapper(bundle)
wrapper.unbind_args()
loss = rays_densities.sum()
loss.backward()
optimizer.step()

        wrapper2 = ImplicitFunctionWrapper(implicit_function2)
optimizer2 = torch.optim.Adam(implicit_function2.parameters())
implicit_function2.load_state_dict(implicit_function.state_dict())
optimizer2.load_state_dict(optimizer.state_dict())
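        # Loading the state dict must clear any cached hypernet outputs so
        # that the loop below sees the newly loaded weights.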
for _step in range(3):
optimizer2.zero_grad()
wrapper2.bind_args(global_code=global_code)
rays_densities, _rays_colors = wrapper2(bundle)
wrapper2.unbind_args()
loss = rays_densities.sum()
loss.backward()
optimizer2.step()

View File

@@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
import unittest
from typing import Dict, List, NamedTuple, Tuple

from pytorch3d.implicitron.dataset import types
from pytorch3d.implicitron.dataset.types import FrameAnnotation


class _NT(NamedTuple):
    annot: FrameAnnotation


class TestDatasetTypes(unittest.TestCase):
def setUp(self):
self.entry = FrameAnnotation(
frame_number=23,
sequence_name="1",
frame_timestamp=1.2,
image=types.ImageAnnotation(path="/tmp/1.jpg", size=(224, 224)),
mask=types.MaskAnnotation(path="/tmp/1.png", mass=42.0),
viewpoint=types.ViewpointAnnotation(
R=(
(1, 0, 0),
(1, 0, 0),
(1, 0, 0),
),
T=(0, 0, 0),
principal_point=(100, 100),
focal_length=(200, 200),
),
)

    def test_asdict_rec(self):
first = [dataclasses.asdict(self.entry)]
second = types._asdict_rec([self.entry])
self.assertEqual(first, second)

    def test_parsing(self):
"""Test that we handle collections enclosing dataclasses."""
dct = dataclasses.asdict(self.entry)
parsed = types._dataclass_from_dict(dct, FrameAnnotation)
self.assertEqual(parsed, self.entry)
# namedtuple
parsed = types._dataclass_from_dict(_NT(dct), _NT)
self.assertEqual(parsed.annot, self.entry)
# tuple
parsed = types._dataclass_from_dict((dct,), Tuple[FrameAnnotation])
self.assertEqual(parsed, (self.entry,))
# list
parsed = types._dataclass_from_dict(
[
dct,
],
List[FrameAnnotation],
)
self.assertEqual(
parsed,
[
self.entry,
],
)
# dict
parsed = types._dataclass_from_dict({"k": dct}, Dict[str, FrameAnnotation])
self.assertEqual(parsed, {"k": self.entry})

    def test_parsing_vectorized(self):
dct = dataclasses.asdict(self.entry)
self._compare_with_scalar(dct, FrameAnnotation)
self._compare_with_scalar(_NT(dct), _NT)
self._compare_with_scalar((dct,), Tuple[FrameAnnotation])
self._compare_with_scalar([dct], List[FrameAnnotation])
self._compare_with_scalar({"k": dct}, Dict[str, FrameAnnotation])

    def _compare_with_scalar(self, obj, typeannot, repeat=3):
        # The vectorized parser must agree elementwise with the scalar one.
        dct_list = [obj] * repeat
        vect_output = types._dataclass_list_from_dict_list(dct_list, typeannot)
        self.assertEqual(len(vect_output), repeat)
        gt = types._dataclass_from_dict(obj, typeannot)
        self.assertTrue(all(res == gt for res in vect_output))

View File

@@ -0,0 +1,270 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import pytorch3d as pt3d
import torch
from pytorch3d.implicitron.models.view_pooling.view_sampling import ViewSampler
from pytorch3d.implicitron.tools.config import expand_args_fields


class TestViewsampling(unittest.TestCase):
def setUp(self):
torch.manual_seed(42)
expand_args_fields(ViewSampler)

    def _init_view_sampler_problem(self, random_masks):
"""
Generates a view-sampling problem:
- 4 source views, 1st/2nd from the first sequence 'seq1', the rest from 'seq2'
- 3 sets of 3D points from sequences 'seq1', 'seq2', 'seq2' respectively.
- first 50 points in each batch correctly project to the source views,
while the remaining 50 do not land in any projection plane.
- each source view is labeled with image feature tensors of shape 7x100x50,
where all elements of the n-th tensor are set to `n+1`.
- the elements of the source view masks are either set to random binary number
(if `random_masks==True`), or all set to 1 (`random_masks==False`).
- the source view cameras are uniformly distributed on a unit circle
in the x-z plane and look at (0,0,0).
"""
seq_id_camera = ["seq1", "seq1", "seq2", "seq2"]
seq_id_pts = ["seq1", "seq2", "seq2"]
pts_batch = 3
n_pts = 100
n_views = 4
fdim = 7
H = 100
W = 50
# points that land into the projection planes of all cameras
pts_inside = (
torch.nn.functional.normalize(
torch.randn(pts_batch, n_pts // 2, 3, device="cuda"),
dim=-1,
)
* 0.1
)
# move the outside points far above the scene
pts_outside = pts_inside.clone()
pts_outside[:, :, 1] += 1e8
pts = torch.cat([pts_inside, pts_outside], dim=1)
R, T = pt3d.renderer.look_at_view_transform(
dist=1.0,
elev=0.0,
azim=torch.linspace(0, 360, n_views + 1)[:n_views],
degrees=True,
device=pts.device,
)
focal_length = R.new_ones(n_views, 2)
principal_point = R.new_zeros(n_views, 2)
camera = pt3d.renderer.PerspectiveCameras(
R=R,
T=T,
focal_length=focal_length,
principal_point=principal_point,
device=pts.device,
)
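        # All feature elements of the n-th source view are set to n + 1.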
feats_map = torch.arange(n_views, device=pts.device, dtype=pts.dtype) + 1
feats = {"feats": feats_map[:, None, None, None].repeat(1, fdim, H, W)}
masks = (
torch.rand(n_views, 1, H, W, device=pts.device, dtype=pts.dtype) > 0.5
).type_as(R)
if not random_masks:
masks[:] = 1.0
return pts, camera, feats, masks, seq_id_camera, seq_id_pts

    def test_compare_with_naive(self):
"""
Compares the outputs of the efficient ViewSampler module with a
naive implementation.
"""
(
pts,
camera,
feats,
masks,
seq_id_camera,
seq_id_pts,
) = self._init_view_sampler_problem(True)
for masked_sampling in (True, False):
feats_sampled_n, masks_sampled_n = _view_sample_naive(
pts,
seq_id_pts,
camera,
seq_id_camera,
feats,
masks,
masked_sampling,
)
# make sure we generate the constructor for ViewSampler
expand_args_fields(ViewSampler)
view_sampler = ViewSampler(masked_sampling=masked_sampling)
feats_sampled, masks_sampled = view_sampler(
pts=pts,
seq_id_pts=seq_id_pts,
camera=camera,
seq_id_camera=seq_id_camera,
feats=feats,
masks=masks,
)
for k in feats_sampled.keys():
self.assertTrue(torch.allclose(feats_sampled[k], feats_sampled_n[k]))
self.assertTrue(torch.allclose(masks_sampled, masks_sampled_n))

    def test_viewsampling(self):
"""
Generates a viewsampling problem with predictable outcome, and compares
the ViewSampler's output to the expected result.
"""
(
pts,
camera,
feats,
masks,
seq_id_camera,
seq_id_pts,
) = self._init_view_sampler_problem(False)
expand_args_fields(ViewSampler)
for masked_sampling in (True, False):
view_sampler = ViewSampler(masked_sampling=masked_sampling)
feats_sampled, masks_sampled = view_sampler(
pts=pts,
seq_id_pts=seq_id_pts,
camera=camera,
seq_id_camera=seq_id_camera,
feats=feats,
masks=masks,
)
n_views = camera.R.shape[0]
n_pts = pts.shape[1]
feat_dim = feats["feats"].shape[1]
pts_batch = pts.shape[0]
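            # The second half of each point batch was moved far outside the scene.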
n_pts_away = n_pts // 2
for pts_i in range(pts_batch):
for view_i in range(n_views):
if seq_id_pts[pts_i] != seq_id_camera[view_i]:
# points / cameras come from different sequences
gt_masks = pts.new_zeros(n_pts, 1)
gt_feats = pts.new_zeros(n_pts, feat_dim)
else:
gt_masks = pts.new_ones(n_pts, 1)
gt_feats = pts.new_ones(n_pts, feat_dim) * (view_i + 1)
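                        # Points projecting outside the image sample the zero padding.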
gt_feats[n_pts_away:] = 0.0
if masked_sampling:
gt_masks[n_pts_away:] = 0.0
for k in feats_sampled:
self.assertTrue(
torch.allclose(
feats_sampled[k][pts_i, view_i],
gt_feats,
)
)
self.assertTrue(
torch.allclose(
masks_sampled[pts_i, view_i],
gt_masks,
)
)


def _view_sample_naive(
pts,
seq_id_pts,
camera,
seq_id_camera,
feats,
masks,
masked_sampling,
):
"""
A naive implementation of the forward pass of ViewSampler.
    Refer to ViewSampler's docstring for a description of the arguments.
"""
pts_batch = pts.shape[0]
n_views = camera.R.shape[0]
n_pts = pts.shape[1]
feats_sampled = [[[] for _ in range(n_views)] for _ in range(pts_batch)]
masks_sampled = [[[] for _ in range(n_views)] for _ in range(pts_batch)]
for pts_i in range(pts_batch):
for view_i in range(n_views):
if seq_id_pts[pts_i] != seq_id_camera[view_i]:
# points/cameras come from different sequences
feats_sampled_ = {
k: f.new_zeros(n_pts, f.shape[1]) for k, f in feats.items()
}
masks_sampled_ = masks.new_zeros(n_pts, 1)
else:
# same sequence of pts and cameras -> sample
feats_sampled_, masks_sampled_ = _sample_one_view_naive(
camera[view_i],
pts[pts_i],
{k: f[view_i] for k, f in feats.items()},
masks[view_i],
masked_sampling,
sampling_mode="bilinear",
)
feats_sampled[pts_i][view_i] = feats_sampled_
masks_sampled[pts_i][view_i] = masks_sampled_
masks_sampled_cat = torch.stack([torch.stack(m) for m in masks_sampled])
feats_sampled_cat = {}
for k in feats_sampled[0][0].keys():
feats_sampled_cat[k] = torch.stack(
[torch.stack([f_[k] for f_ in f]) for f in feats_sampled]
)
return feats_sampled_cat, masks_sampled_cat


def _sample_one_view_naive(
camera,
pts,
feats,
masks,
masked_sampling,
sampling_mode="bilinear",
):
"""
Sample a single source view.
"""
proj_ndc = camera.transform_points(pts[None])[None, ..., :-1] # 1 x 1 x n_pts x 2
feats_sampled = {
k: pt3d.renderer.ndc_grid_sample(f[None], proj_ndc, mode=sampling_mode).permute(
0, 3, 1, 2
)[0, :, :, 0]
for k, f in feats.items()
} # n_pts x dim
if not masked_sampling:
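        # Without masked sampling, every projected point counts as fully visible.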
n_pts = pts.shape[0]
masks_sampled = proj_ndc.new_ones(n_pts, 1)
else:
masks_sampled = pt3d.renderer.ndc_grid_sample(
masks[None],
proj_ndc,
mode=sampling_mode,
align_corners=False,
)[0, 0, 0, :][:, None]
return feats_sampled, masks_sampled