Extracted ImplicitronModelBase and unified API for GenericModel and ModelDBIR

Summary:
To avoid the need for a model_zoo, we make GenericModel pluggable: GenericModel and ModelDBIR now share the new ImplicitronModelBase base class and can be looked up by name through the registry. I also align their creation APIs for convenience.
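
A minimal sketch of the aligned creation flow, mirroring the updated test below (the image size and background color are just example values):

    from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
    from pytorch3d.implicitron.tools.config import expand_args_fields, registry

    # Any registered ImplicitronModelBase implementation can be looked up by
    # name; "GenericModel" and "ModelDBIR" now accept the same arguments.
    ModelClass = registry.get(ImplicitronModelBase, "GenericModel")
    expand_args_fields(ModelClass)  # expand configurable fields before instantiation
    model = ModelClass(
        render_image_width=64,
        render_image_height=64,
        bg_color=(0.0, 0.0, 0.0),
    )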

Reviewed By: bottler, davnov134

Differential Revision: D35933093

fbshipit-source-id: 8228926528eb41a795fbfbe32304b8019197e2b1
Author: Roman Shapovalov
Date: 2022-05-09 15:23:07 -07:00
Committed by: Facebook GitHub Bot
Parent: 5c59841863
Commit: a6dada399d
11 changed files with 282 additions and 178 deletions

[diff of a changed test file (filename not captured)]

@@ -9,7 +9,7 @@ import unittest
 from omegaconf import OmegaConf
 from pytorch3d.implicitron.models.autodecoder import Autodecoder
-from pytorch3d.implicitron.models.base import GenericModel
+from pytorch3d.implicitron.models.generic_model import GenericModel
 from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
     IdrFeatureField,
 )

[diff of another changed test file (filename not captured)]

@@ -6,7 +6,6 @@
 import contextlib
-import copy
 import dataclasses
 import itertools
 import math
@@ -19,8 +18,13 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
     FrameData,
     ImplicitronDataset,
 )
-from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
-from pytorch3d.implicitron.models.model_dbir import ModelDBIR
+from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
+    eval_batch,
+)
+from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
+from pytorch3d.implicitron.models.generic_model import GenericModel  # noqa
+from pytorch3d.implicitron.models.model_dbir import ModelDBIR  # noqa
 from pytorch3d.implicitron.tools.config import expand_args_fields, registry
 from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
 from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
@@ -43,7 +47,7 @@ class TestEvaluation(unittest.TestCase):
         category = "skateboard"
         frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
-        self.image_size = 256
+        self.image_size = 64
         self.dataset = ImplicitronDataset(
             frame_annotations_file=frame_file,
             sequence_annotations_file=sequence_file,
@@ -53,11 +57,11 @@ class TestEvaluation(unittest.TestCase):
             box_crop=True,
             path_manager=path_manager,
         )
-        self.bg_color = 0.0
+        self.bg_color = (0.0, 0.0, 0.0)
         # init the lpips model for eval
         provide_lpips_vgg()
-        self.lpips_model = lpips.LPIPS(net="vgg")
+        self.lpips_model = lpips.LPIPS(net="vgg").cuda()

     def test_eval_depth(self):
         """
@@ -200,30 +204,17 @@ class TestEvaluation(unittest.TestCase):
     def _one_sequence_test(
         self,
         seq_dataset,
-        n_batches=2,
-        min_batch_size=5,
-        max_batch_size=10,
+        model,
+        batch_indices,
+        check_metrics=False,
     ):
-        # form a list of random batches
-        batch_indices = []
-        for _ in range(n_batches):
-            batch_size = torch.randint(
-                low=min_batch_size, high=max_batch_size, size=(1,)
-            )
-            batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
         loader = torch.utils.data.DataLoader(
             seq_dataset,
-            # batch_size=1,
             shuffle=False,
             batch_sampler=batch_indices,
             collate_fn=FrameData.collate,
         )
-        model = ModelDBIR(image_size=self.image_size, bg_color=self.bg_color)
-        model.cuda()
-        self.lpips_model.cuda()

         for frame_data in loader:
             self.assertIsNone(frame_data.frame_type)
             self.assertIsNotNone(frame_data.image_rgb)
@@ -233,61 +224,101 @@ class TestEvaluation(unittest.TestCase):
                 *(["train_known"] * (len(frame_data.image_rgb) - 1)),
             ]

             # move frame_data to gpu
             frame_data = dataclass_to_cuda_(frame_data)
             preds = model(**dataclasses.asdict(frame_data))

-            nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
             eval_result = eval_batch(
                 frame_data,
-                nvs_prediction,
+                preds["implicitron_render"],
                 bg_color=self.bg_color,
                 lpips_model=self.lpips_model,
             )

-            # Make a terribly bad NVS prediction and check that this is worse
-            # than the DBIR prediction.
-            nvs_prediction_bad = copy.deepcopy(preds["nvs_prediction"])
-            nvs_prediction_bad.depth_render += (
-                torch.randn_like(nvs_prediction.depth_render) * 100.0
-            )
-            nvs_prediction_bad.image_render += (
-                torch.randn_like(nvs_prediction.image_render) * 100.0
-            )
-            nvs_prediction_bad.mask_render = (
-                torch.randn_like(nvs_prediction.mask_render) > 0.0
-            ).float()
-            eval_result_bad = eval_batch(
-                frame_data,
-                nvs_prediction_bad,
-                bg_color=self.bg_color,
-                lpips_model=self.lpips_model,
-            )
-            lower_better = {
-                "psnr": False,
-                "psnr_fg": False,
-                "depth_abs_fg": True,
-                "iou": False,
-                "rgb_l1": True,
-                "rgb_l1_fg": True,
-            }
-            for metric in lower_better.keys():
-                m_better = eval_result[metric]
-                m_worse = eval_result_bad[metric]
-                if m_better != m_better or m_worse != m_worse:
-                    continue  # metric is missing, i.e. NaN
-                _assert = (
-                    self.assertLessEqual
-                    if lower_better[metric]
-                    else self.assertGreaterEqual
-                )
-                _assert(m_better, m_worse)
+            if check_metrics:
+                self._check_metrics(
+                    frame_data, preds["implicitron_render"], eval_result
+                )
+
+    def _check_metrics(self, frame_data, implicitron_render, eval_result):
+        # Make a terribly bad NVS prediction and check that this is worse
+        # than the DBIR prediction.
+        implicitron_render_bad = implicitron_render.clone()
+        implicitron_render_bad.depth_render += (
+            torch.randn_like(implicitron_render_bad.depth_render) * 100.0
+        )
+        implicitron_render_bad.image_render += (
+            torch.randn_like(implicitron_render_bad.image_render) * 100.0
+        )
+        implicitron_render_bad.mask_render = (
+            torch.randn_like(implicitron_render_bad.mask_render) > 0.0
+        ).float()
+        eval_result_bad = eval_batch(
+            frame_data,
+            implicitron_render_bad,
+            bg_color=self.bg_color,
+            lpips_model=self.lpips_model,
+        )
+        lower_better = {
+            "psnr": False,
+            "psnr_fg": False,
+            "depth_abs_fg": True,
+            "iou": False,
+            "rgb_l1": True,
+            "rgb_l1_fg": True,
+        }
+        for metric in lower_better:
+            m_better = eval_result[metric]
+            m_worse = eval_result_bad[metric]
+            if m_better != m_better or m_worse != m_worse:
+                continue  # metric is missing, i.e. NaN
+            _assert = (
+                self.assertLessEqual
+                if lower_better[metric]
+                else self.assertGreaterEqual
+            )
+            _assert(m_better, m_worse)
+
+    def _get_random_batch_indices(
+        self, seq_dataset, n_batches=2, min_batch_size=5, max_batch_size=10
+    ):
+        batch_indices = []
+        for _ in range(n_batches):
+            batch_size = torch.randint(
+                low=min_batch_size, high=max_batch_size, size=(1,)
+            )
+            batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
+        return batch_indices
+
     def test_full_eval(self, n_sequences=5):
         """Test evaluation."""
+        # caching batch indices first to preserve RNG state
+        seq_datasets = {}
+        batch_indices = {}
         for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
             idx = list(self.dataset.sequence_indices_in_order(seq))
             seq_dataset = torch.utils.data.Subset(self.dataset, idx)
-            self._one_sequence_test(seq_dataset)
+            seq_datasets[seq] = seq_dataset
+            batch_indices[seq] = self._get_random_batch_indices(seq_dataset)
+
+        for model_class_type in ["ModelDBIR", "GenericModel"]:
+            ModelClass = registry.get(ImplicitronModelBase, model_class_type)
+            expand_args_fields(ModelClass)
+            model = ModelClass(
+                render_image_width=self.image_size,
+                render_image_height=self.image_size,
+                bg_color=self.bg_color,
+            )
+            model.eval()
+            model.cuda()
+
+            for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
+                self._one_sequence_test(
+                    seq_datasets[seq],
+                    model,
+                    batch_indices[seq],
+                    check_metrics=(model_class_type == "ModelDBIR"),
+                )
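
The registry round-trip in test_full_eval is what "pluggable" buys: a third-party model only needs to subclass ImplicitronModelBase and register itself to be picked up by the same loop. A hypothetical sketch, assuming the @registry.register class decorator that GenericModel and ModelDBIR use:

    from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
    from pytorch3d.implicitron.tools.config import registry

    @registry.register
    class MyDBIRVariant(ImplicitronModelBase):  # hypothetical example class
        def forward(self, **frame_data):
            # A real model would return a dict of predictions including the
            # "implicitron_render" entry that eval_batch consumes above.
            raise NotImplementedError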

[diff of another changed test file (filename not captured)]

@@ -7,7 +7,7 @@
 import unittest

 import torch
-from pytorch3d.implicitron.models.base import GenericModel
+from pytorch3d.implicitron.models.generic_model import GenericModel
 from pytorch3d.implicitron.models.renderer.base import EvaluationMode
 from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
 from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
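
Note on import paths for downstream code, as the diffs above show: GenericModel moved out of pytorch3d.implicitron.models.base, and the shared base class lives in the new base_model module:

    # before this commit:
    # from pytorch3d.implicitron.models.base import GenericModel
    # after:
    from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
    from pytorch3d.implicitron.models.generic_model import GenericModel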