Mods and bugfixes for LLFF and Blender repros

Summary:
LLFF (and most/all non-synth datasets) will have no background/foreground distinction. Add support for data with no fg mask.

Also, we had a bug in stats loading, like this:
  * Load stats
  * One of the stats has a history of length 0
  * That's fine, e.g. maybe it's fg_error but the dataset has no notion of fg/bg. So leave it as len 0
  * Check whether all the stats have the same history length as an arbitrarily chosen "reference-stat"
  * Ooops the reference-stat happened to be the stat with length 0
  * assert (legit_stat_len == reference_stat_len (=0)) ---> failed assert

Also some minor fixes (from Jeremy's other diff) to support LLFF

Reviewed By: davnov134

Differential Revision: D38475272

fbshipit-source-id: 5b35ac86d1d5239759f537621f41a3aa4eb3bd68
This commit is contained in:
Krzysztof Chalupka
2022-08-09 15:04:44 -07:00
committed by Facebook GitHub Bot
parent 624bc5a274
commit c83ec3555d
11 changed files with 51 additions and 25 deletions

View File

@@ -32,17 +32,21 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
and test datasets, and this many random training frames are added to
each test batch. If not set, test batches each contain just a single
testing frame.
downscale_factor: determines image sizes.
"""
downscale_factor: int = 4
def _load_data(self) -> None:
path_manager = self.path_manager_factory.get()
images, poses, _ = load_llff_data(
self.base_dir, factor=8, path_manager=path_manager
self.base_dir, factor=self.downscale_factor, path_manager=path_manager
)
hwf = poses[0, :3, -1]
poses = poses[:, :3, :4]
i_test = np.arange(images.shape[0])[::8]
llffhold = 8
i_test = np.arange(images.shape[0])[::llffhold]
i_test_index = set(i_test.tolist())
i_train = np.array(
[i for i in np.arange(images.shape[0]) if i not in i_test_index]

View File

@@ -27,6 +27,7 @@ from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
@expand_args_fields
class SingleSceneDataset(DatasetBase, Configurable):
"""
A dataset from images from a single scene.
@@ -110,7 +111,6 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
def _get_dataset(
self, split_idx: int, frame_type: str, set_eval_batches: bool = False
) -> SingleSceneDataset:
expand_args_fields(SingleSceneDataset)
# pyre-ignore[16]
split = self.i_split[split_idx]
frame_types = [frame_type] * len(split)

View File

@@ -245,13 +245,20 @@ def eval_batch(
if frame_data.mask_crop is None:
warnings.warn("mask_crop is None, assuming the whole image is valid.")
if frame_data.fg_probability is None:
warnings.warn("fg_probability is None, assuming the whole image is fg.")
# threshold the masks to make ground truth binary masks
# pyre-ignore [58]
mask_fg = frame_data.fg_probability >= mask_thr
mask_fg = (
frame_data.fg_probability >= mask_thr
if frame_data.fg_probability is not None
# pyre-ignore [16]
else torch.ones_like(frame_data.image_rgb[:, :1, ...]).bool()
)
mask_crop = (
frame_data.mask_crop
if frame_data.mask_crop is not None
# pyre-ignore [6]
else torch.ones_like(mask_fg)
)
@@ -259,7 +266,6 @@ def eval_batch(
# pyre-fixme[6]: Expected `Tensor` for 1st param but got
# `Optional[torch.Tensor]`.
frame_data.image_rgb,
# pyre-ignore [6]
mask_fg,
bg_color=bg_color,
)
@@ -275,7 +281,6 @@ def eval_batch(
# pyre-fixme[6]: Expected `Tensor` for 4th param but got
# `Optional[torch.Tensor]`.
depth_map=frame_data.depth_map,
# pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
depth_mask=frame_data.depth_mask[:1],
visdom_env=visualize_visdom_env,
)
@@ -284,7 +289,7 @@ def eval_batch(
results["iou"] = iou(
cloned_render["mask_render"],
mask_fg, # pyre-ignore [6]
mask_fg,
mask=mask_crop,
)

View File

@@ -13,8 +13,8 @@ from typing import Any, Dict, List, Optional, Tuple
import lpips
import torch
import tqdm
import tqdm
from pytorch3d.implicitron.dataset import utils as ds_utils
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate

View File

@@ -198,7 +198,6 @@ class Stats(object):
if verbose:
print(f"Adding {add_log_var}")
self.log_vars.append(add_log_var)
# self.synchronize_logged_vars(self.log_vars, verbose=verbose)
def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
@@ -230,7 +229,6 @@ class Stats(object):
elapsed = time.time() - time_start
time_per_it = float(elapsed) / float(it + 1)
val = time_per_it
# self.stats[stat_set]['sec/it'].update(time_per_it,epoch=epoch,n=1)
else:
if stat in preds:
try:
@@ -441,7 +439,6 @@ class Stats(object):
self.log_vars = log_vars # !!!
for stat_set in stat_sets:
reference_stat = list(self.stats[stat_set].keys())[0]
for stat in log_vars:
if stat not in self.stats[stat_set]:
if verbose:
@@ -468,12 +465,11 @@ class Stats(object):
lastep = self.epoch + 1
for ep in range(lastep):
self.stats[stat_set][stat].update(default_val, n=1, epoch=ep)
epoch_self = self.stats[stat_set][reference_stat].get_epoch()
epoch_generated = self.stats[stat_set][stat].get_epoch()
assert (
epoch_self == epoch_generated
epoch_generated == self.epoch + 1
), "bad epoch of synchronized log_var! %d vs %d" % (
epoch_self,
self.epoch + 1,
epoch_generated,
)