mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 05:40:34 +08:00
Extracted ImplicitronModelBase and unified API for GenericModel and ModelDBIR
Summary: To avoid model_zoo, GenericModel needs to be pluggable. The creation APIs of GenericModel and ModelDBIR are also aligned for convenience.
Reviewed By: bottler, davnov134
Differential Revision: D35933093
fbshipit-source-id: 8228926528eb41a795fbfbe32304b8019197e2b1
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5c59841863
commit
a6dada399d
@@ -9,7 +9,7 @@ import unittest
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.models.autodecoder import Autodecoder
|
||||
from pytorch3d.implicitron.models.base import GenericModel
|
||||
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||
from pytorch3d.implicitron.models.implicit_function.idr_feature_field import (
|
||||
IdrFeatureField,
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import dataclasses
|
||||
import itertools
|
||||
import math
|
||||
@@ -19,8 +18,13 @@ from pytorch3d.implicitron.dataset.implicitron_dataset import (
|
||||
FrameData,
|
||||
ImplicitronDataset,
|
||||
)
|
||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import eval_batch
|
||||
from pytorch3d.implicitron.models.model_dbir import ModelDBIR
|
||||
from pytorch3d.implicitron.evaluation.evaluate_new_view_synthesis import (
|
||||
eval_batch,
|
||||
)
|
||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
||||
from pytorch3d.implicitron.models.generic_model import GenericModel # noqa
|
||||
from pytorch3d.implicitron.models.model_dbir import ModelDBIR # noqa
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields, registry
|
||||
from pytorch3d.implicitron.tools.metric_utils import calc_psnr, eval_depth
|
||||
from pytorch3d.implicitron.tools.utils import dataclass_to_cuda_
|
||||
|
||||
@@ -43,7 +47,7 @@ class TestEvaluation(unittest.TestCase):
|
||||
category = "skateboard"
|
||||
frame_file = os.path.join(dataset_root, category, "frame_annotations.jgz")
|
||||
sequence_file = os.path.join(dataset_root, category, "sequence_annotations.jgz")
|
||||
self.image_size = 256
|
||||
self.image_size = 64
|
||||
self.dataset = ImplicitronDataset(
|
||||
frame_annotations_file=frame_file,
|
||||
sequence_annotations_file=sequence_file,
|
||||
@@ -53,11 +57,11 @@ class TestEvaluation(unittest.TestCase):
|
||||
box_crop=True,
|
||||
path_manager=path_manager,
|
||||
)
|
||||
self.bg_color = 0.0
|
||||
self.bg_color = (0.0, 0.0, 0.0)
|
||||
|
||||
# init the lpips model for eval
|
||||
provide_lpips_vgg()
|
||||
self.lpips_model = lpips.LPIPS(net="vgg")
|
||||
self.lpips_model = lpips.LPIPS(net="vgg").cuda()
|
||||
|
||||
def test_eval_depth(self):
|
||||
"""
|
||||
@@ -200,30 +204,17 @@ class TestEvaluation(unittest.TestCase):
|
||||
def _one_sequence_test(
|
||||
self,
|
||||
seq_dataset,
|
||||
n_batches=2,
|
||||
min_batch_size=5,
|
||||
max_batch_size=10,
|
||||
model,
|
||||
batch_indices,
|
||||
check_metrics=False,
|
||||
):
|
||||
# form a list of random batches
|
||||
batch_indices = []
|
||||
for _ in range(n_batches):
|
||||
batch_size = torch.randint(
|
||||
low=min_batch_size, high=max_batch_size, size=(1,)
|
||||
)
|
||||
batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
|
||||
|
||||
loader = torch.utils.data.DataLoader(
|
||||
seq_dataset,
|
||||
# batch_size=1,
|
||||
shuffle=False,
|
||||
batch_sampler=batch_indices,
|
||||
collate_fn=FrameData.collate,
|
||||
)
|
||||
|
||||
model = ModelDBIR(image_size=self.image_size, bg_color=self.bg_color)
|
||||
model.cuda()
|
||||
self.lpips_model.cuda()
|
||||
|
||||
for frame_data in loader:
|
||||
self.assertIsNone(frame_data.frame_type)
|
||||
self.assertIsNotNone(frame_data.image_rgb)
|
||||
@@ -233,61 +224,101 @@ class TestEvaluation(unittest.TestCase):
|
||||
*(["train_known"] * (len(frame_data.image_rgb) - 1)),
|
||||
]
|
||||
|
||||
# move frame_data to gpu
|
||||
frame_data = dataclass_to_cuda_(frame_data)
|
||||
preds = model(**dataclasses.asdict(frame_data))
|
||||
|
||||
nvs_prediction = copy.deepcopy(preds["nvs_prediction"])
|
||||
eval_result = eval_batch(
|
||||
frame_data,
|
||||
nvs_prediction,
|
||||
preds["implicitron_render"],
|
||||
bg_color=self.bg_color,
|
||||
lpips_model=self.lpips_model,
|
||||
)
|
||||
|
||||
# Make a terribly bad NVS prediction and check that this is worse
|
||||
# than the DBIR prediction.
|
||||
nvs_prediction_bad = copy.deepcopy(preds["nvs_prediction"])
|
||||
nvs_prediction_bad.depth_render += (
|
||||
torch.randn_like(nvs_prediction.depth_render) * 100.0
|
||||
)
|
||||
nvs_prediction_bad.image_render += (
|
||||
torch.randn_like(nvs_prediction.image_render) * 100.0
|
||||
)
|
||||
nvs_prediction_bad.mask_render = (
|
||||
torch.randn_like(nvs_prediction.mask_render) > 0.0
|
||||
).float()
|
||||
eval_result_bad = eval_batch(
|
||||
frame_data,
|
||||
nvs_prediction_bad,
|
||||
bg_color=self.bg_color,
|
||||
lpips_model=self.lpips_model,
|
||||
)
|
||||
|
||||
lower_better = {
|
||||
"psnr": False,
|
||||
"psnr_fg": False,
|
||||
"depth_abs_fg": True,
|
||||
"iou": False,
|
||||
"rgb_l1": True,
|
||||
"rgb_l1_fg": True,
|
||||
}
|
||||
|
||||
for metric in lower_better.keys():
|
||||
m_better = eval_result[metric]
|
||||
m_worse = eval_result_bad[metric]
|
||||
if m_better != m_better or m_worse != m_worse:
|
||||
continue # metric is missing, i.e. NaN
|
||||
_assert = (
|
||||
self.assertLessEqual
|
||||
if lower_better[metric]
|
||||
else self.assertGreaterEqual
|
||||
if check_metrics:
|
||||
self._check_metrics(
|
||||
frame_data, preds["implicitron_render"], eval_result
|
||||
)
|
||||
_assert(m_better, m_worse)
|
||||
|
||||
def _check_metrics(self, frame_data, implicitron_render, eval_result):
    """
    Sanity-check `eval_result` against a deliberately corrupted render.

    A clone of `implicitron_render` is perturbed with large random noise
    (depth and image) and given a random binary mask, then scored with
    `eval_batch`. For each metric listed below, the value in `eval_result`
    must be no worse than the value obtained for the corrupted render.
    Metrics that evaluate to NaN on either side are skipped.
    """
    corrupted = implicitron_render.clone()

    # The noise is drawn in the order depth -> image -> mask.
    depth_noise = torch.randn_like(corrupted.depth_render) * 100.0
    corrupted.depth_render += depth_noise
    image_noise = torch.randn_like(corrupted.image_render) * 100.0
    corrupted.image_render += image_noise
    random_mask = torch.randn_like(corrupted.mask_render) > 0.0
    corrupted.mask_render = random_mask.float()

    eval_result_corrupted = eval_batch(
        frame_data,
        corrupted,
        bg_color=self.bg_color,
        lpips_model=self.lpips_model,
    )

    # Metric name -> True when a lower value means a better result.
    lower_is_better = {
        "psnr": False,
        "psnr_fg": False,
        "depth_abs_fg": True,
        "iou": False,
        "rgb_l1": True,
        "rgb_l1_fg": True,
    }

    for name, lower_wins in lower_is_better.items():
        good = eval_result[name]
        bad = eval_result_corrupted[name]
        # A value that differs from itself is NaN, i.e. the metric is missing.
        if good != good or bad != bad:
            continue
        if lower_wins:
            self.assertLessEqual(good, bad)
        else:
            self.assertGreaterEqual(good, bad)
|
||||
|
||||
def _get_random_batch_indices(
|
||||
self, seq_dataset, n_batches=2, min_batch_size=5, max_batch_size=10
|
||||
):
|
||||
batch_indices = []
|
||||
for _ in range(n_batches):
|
||||
batch_size = torch.randint(
|
||||
low=min_batch_size, high=max_batch_size, size=(1,)
|
||||
)
|
||||
batch_indices.append(torch.randperm(len(seq_dataset))[:batch_size])
|
||||
|
||||
return batch_indices
|
||||
|
||||
def test_full_eval(self, n_sequences=5):
|
||||
"""Test evaluation."""
|
||||
|
||||
# caching batch indices first to preserve RNG state
|
||||
seq_datasets = {}
|
||||
batch_indices = {}
|
||||
for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
|
||||
idx = list(self.dataset.sequence_indices_in_order(seq))
|
||||
seq_dataset = torch.utils.data.Subset(self.dataset, idx)
|
||||
self._one_sequence_test(seq_dataset)
|
||||
seq_datasets[seq] = seq_dataset
|
||||
batch_indices[seq] = self._get_random_batch_indices(seq_dataset)
|
||||
|
||||
for model_class_type in ["ModelDBIR", "GenericModel"]:
|
||||
ModelClass = registry.get(ImplicitronModelBase, model_class_type)
|
||||
expand_args_fields(ModelClass)
|
||||
model = ModelClass(
|
||||
render_image_width=self.image_size,
|
||||
render_image_height=self.image_size,
|
||||
bg_color=self.bg_color,
|
||||
)
|
||||
model.eval()
|
||||
model.cuda()
|
||||
|
||||
for seq in itertools.islice(self.dataset.sequence_names(), n_sequences):
|
||||
self._one_sequence_test(
|
||||
seq_datasets[seq],
|
||||
model,
|
||||
batch_indices[seq],
|
||||
check_metrics=(model_class_type == "ModelDBIR"),
|
||||
)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.base import GenericModel
|
||||
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
|
||||
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
|
||||
from pytorch3d.renderer.cameras import look_at_view_transform, PerspectiveCameras
|
||||
|
||||
Reference in New Issue
Block a user