Mods and bugfixes for LLFF and Blender repros

Summary:
LLFF (and most, if not all, non-synthetic datasets) has no background/foreground distinction. Add support for data with no fg mask.
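
A minimal sketch of the fallback this adds in eval_batch (the helper name and tensor shapes are illustrative, not the actual API; the real change is in the evaluation diff below):

    import warnings
    import torch

    def fg_mask_or_full(fg_probability, image_rgb, mask_thr=0.5):
        # No fg/bg annotation available: treat every pixel as foreground.
        if fg_probability is None:
            warnings.warn("fg_probability is None, assuming the whole image is fg.")
            return torch.ones_like(image_rgb[:, :1, ...]).bool()
        # Otherwise threshold the soft fg probability into a binary mask.
        return fg_probability >= mask_thr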

Also, we had a bug in stats loading, like this (a toy reproduction follows the list):
  * Load stats.
  * One of the stats has a history of length 0.
  * That's fine, e.g. maybe it's fg_error but the dataset has no notion of fg/bg, so leave it at length 0.
  * Check whether all the stats have the same history length as an arbitrarily chosen "reference stat".
  * Oops, the reference stat happened to be the stat with length 0.
  * assert legit_stat_len == reference_stat_len (= 0) ---> failed assert
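
A self-contained toy reproduction of the failure and the spirit of the fix (illustrative names and values, not the real Stats API):

    # Per-stat history holding one value per epoch; fg_error is legitimately
    # empty because this dataset has no fg/bg annotation.
    num_epochs = 3  # plays the role of self.epoch + 1
    histories = {
        "fg_error": [],
        "loss_rgb_mse": [0.9, 0.7, 0.5],
    }

    # Old check: every stat was compared against an arbitrarily chosen
    # "reference" stat; here that happens to be the empty one, so the
    # comparison below fails for every real stat.
    reference_len = len(histories["fg_error"])  # 0
    # assert len(histories["loss_rgb_mse"]) == reference_len  # AssertionError

    # Fixed check (in spirit): compare each synchronized stat against the
    # known epoch count rather than against another stat's history.
    for name, history in histories.items():
        if history:
            assert len(history) == num_epochs, f"bad epoch count for {name}"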

Also some minor fixes (from Jeremy's other diff) to support LLFF

Reviewed By: davnov134

Differential Revision: D38475272

fbshipit-source-id: 5b35ac86d1d5239759f537621f41a3aa4eb3bd68
Author: Krzysztof Chalupka
Date: 2022-08-09 15:04:44 -07:00
Committed by: Facebook GitHub Bot
parent 624bc5a274
commit c83ec3555d
11 changed files with 51 additions and 25 deletions


@@ -1,7 +1,7 @@
 defaults:
 - repro_singleseq_base
 - _self_
-exp_dir: "./data/nerf_blender_publ/${oc.env:BLENDER_SINGLESEQ_CLASS}"
+exp_dir: "./data/nerf_blender_repro/${oc.env:BLENDER_SINGLESEQ_CLASS}"
 data_source_ImplicitronDataSource_args:
   data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     dataset_length_train: 100
@@ -16,17 +16,18 @@ data_source_ImplicitronDataSource_args:
 model_factory_ImplicitronModelFactory_args:
   model_GenericModel_args:
-    raysampler_AdaptiveRaySampler_args:
+    mask_images: false
+    raysampler_class_type: NearFarRaySampler
+    raysampler_NearFarRaySampler_args:
       n_rays_per_image_sampled_from_mask: 4096
-      scene_extent: 2.0
+      min_depth: 2
+      max_depth: 6
     renderer_MultiPassEmissionAbsorptionRenderer_args:
       density_noise_std_train: 1.0
       n_pts_per_ray_fine_training: 128
       n_pts_per_ray_fine_evaluation: 128
       raymarcher_EmissionAbsorptionRaymarcher_args:
-        blend_output: true
+        blend_output: false
+        bg_color:
+        - 1.0
     loss_weights:
       loss_rgb_mse: 1.0
       loss_prev_stage_rgb_mse: 1.0
@@ -35,11 +36,11 @@ model_factory_ImplicitronModelFactory_args:
       loss_autodecoder_norm: 0.00
 optimizer_factory_ImplicitronOptimizerFactory_args:
-  exponential_lr_step_size: 3001
+  exponential_lr_step_size: 2500
   lr_policy: Exponential
 training_loop_ImplicitronTrainingLoop_args:
-  max_epochs: 3001
+  max_epochs: 2000
   metric_print_interval: 100
   store_checkpoints_purge: 3
   test_when_finished: true


@@ -249,6 +249,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
         stats = Stats(
             # log_vars should be a list, but OmegaConf might load them as ListConfig
             list(log_vars),
+            plot_file=os.path.join(exp_dir, "train_stats.pdf"),
             visdom_env=visdom_env_charts,
             verbose=False,
             visdom_server=self.visdom_server,


@@ -95,6 +95,7 @@ data_source_ImplicitronDataSource_args:
     n_known_frames_for_test: null
     path_manager_factory_PathManagerFactory_args:
       silence_logs: true
+    downscale_factor: 4
   dataset_map_provider_RenderedMeshDatasetMapProvider_args:
     num_views: 40
     data_file: null


@@ -162,7 +162,7 @@ class TestExperiment(unittest.TestCase):
 class TestNerfRepro(unittest.TestCase):
-    @unittest.skip("This runs full NeRF training on Blender data.")
+    @unittest.skip("This test runs full blender training.")
     def test_nerf_blender(self):
         # Train vanilla NERF.
         # Set env vars BLENDER_DATASET_ROOT and BLENDER_SINGLESEQ_CLASS first!
@@ -174,6 +174,22 @@ class TestNerfRepro(unittest.TestCase):
             experiment.dump_cfg(cfg)
             experiment_runner.run()
 
+    @unittest.skip("This test runs full llff training.")
+    def test_nerf_llff(self):
+        # Train vanilla NERF.
+        # Set env vars LLFF_DATASET_ROOT and LLFF_SINGLESEQ_CLASS first!
+        LLFF_SINGLESEQ_CLASS = os.environ["LLFF_SINGLESEQ_CLASS"]
+        if not interactive_testing_requested():
+            return
+        with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
+            cfg = compose(
+                config_name=f"repro_singleseq_nerf_llff_{LLFF_SINGLESEQ_CLASS}",
+                overrides=[],
+            )
+            experiment_runner = experiment.Experiment(**cfg)
+            experiment.dump_cfg(cfg)
+            experiment_runner.run()
+
     @unittest.skip("This test checks resuming of the NeRF training.")
     def test_nerf_blender_resume(self):
         # Train one train batch of NeRF, then resume for one more batch.


@@ -32,17 +32,21 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
             and test datasets, and this many random training frames are added to
             each test batch. If not set, test batches each contain just a single
             testing frame.
+        downscale_factor: determines image sizes.
     """
 
+    downscale_factor: int = 4
+
     def _load_data(self) -> None:
         path_manager = self.path_manager_factory.get()
         images, poses, _ = load_llff_data(
-            self.base_dir, factor=8, path_manager=path_manager
+            self.base_dir, factor=self.downscale_factor, path_manager=path_manager
         )
         hwf = poses[0, :3, -1]
         poses = poses[:, :3, :4]
 
-        i_test = np.arange(images.shape[0])[::8]
+        llffhold = 8
+        i_test = np.arange(images.shape[0])[::llffhold]
         i_test_index = set(i_test.tolist())
         i_train = np.array(
             [i for i in np.arange(images.shape[0]) if i not in i_test_index]


@@ -27,6 +27,7 @@ from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
 _SINGLE_SEQUENCE_NAME: str = "one_sequence"
 
 
+@expand_args_fields
 class SingleSceneDataset(DatasetBase, Configurable):
     """
     A dataset from images from a single scene.
@@ -110,7 +111,6 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
     def _get_dataset(
         self, split_idx: int, frame_type: str, set_eval_batches: bool = False
     ) -> SingleSceneDataset:
-        expand_args_fields(SingleSceneDataset)
         # pyre-ignore[16]
         split = self.i_split[split_idx]
         frame_types = [frame_type] * len(split)


@@ -245,13 +245,20 @@ def eval_batch(
     if frame_data.mask_crop is None:
         warnings.warn("mask_crop is None, assuming the whole image is valid.")
 
+    if frame_data.fg_probability is None:
+        warnings.warn("fg_probability is None, assuming the whole image is fg.")
+
     # threshold the masks to make ground truth binary masks
-    # pyre-ignore [58]
-    mask_fg = frame_data.fg_probability >= mask_thr
+    mask_fg = (
+        frame_data.fg_probability >= mask_thr
+        if frame_data.fg_probability is not None
+        # pyre-ignore [16]
+        else torch.ones_like(frame_data.image_rgb[:, :1, ...]).bool()
+    )
     mask_crop = (
         frame_data.mask_crop
         if frame_data.mask_crop is not None
-        # pyre-ignore [6]
         else torch.ones_like(mask_fg)
     )
@@ -259,7 +266,6 @@ def eval_batch(
             # pyre-fixme[6]: Expected `Tensor` for 1st param but got
             #  `Optional[torch.Tensor]`.
             frame_data.image_rgb,
-            # pyre-ignore [6]
             mask_fg,
             bg_color=bg_color,
         )
@@ -275,7 +281,6 @@ def eval_batch(
             # pyre-fixme[6]: Expected `Tensor` for 4th param but got
             #  `Optional[torch.Tensor]`.
             depth_map=frame_data.depth_map,
-            # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
             depth_mask=frame_data.depth_mask[:1],
             visdom_env=visualize_visdom_env,
         )
@@ -284,7 +289,7 @@ def eval_batch(
     results["iou"] = iou(
         cloned_render["mask_render"],
-        mask_fg,  # pyre-ignore [6]
+        mask_fg,
         mask=mask_crop,
     )


@@ -13,8 +13,8 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import lpips
 import torch
-import tqdm
+import tqdm
 from pytorch3d.implicitron.dataset import utils as ds_utils
 from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate


@@ -198,7 +198,6 @@ class Stats(object):
                 if verbose:
                     print(f"Adding {add_log_var}")
                 self.log_vars.append(add_log_var)
-        # self.synchronize_logged_vars(self.log_vars, verbose=verbose)
 
     def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
 
@@ -230,7 +229,6 @@ class Stats(object):
                 elapsed = time.time() - time_start
                 time_per_it = float(elapsed) / float(it + 1)
                 val = time_per_it
-                # self.stats[stat_set]['sec/it'].update(time_per_it,epoch=epoch,n=1)
             else:
                 if stat in preds:
                     try:
@@ -441,7 +439,6 @@ class Stats(object):
         self.log_vars = log_vars  # !!!
 
         for stat_set in stat_sets:
-            reference_stat = list(self.stats[stat_set].keys())[0]
             for stat in log_vars:
                 if stat not in self.stats[stat_set]:
                     if verbose:
@@ -468,12 +465,11 @@ class Stats(object):
                     lastep = self.epoch + 1
                     for ep in range(lastep):
                         self.stats[stat_set][stat].update(default_val, n=1, epoch=ep)
-                    epoch_self = self.stats[stat_set][reference_stat].get_epoch()
                     epoch_generated = self.stats[stat_set][stat].get_epoch()
                     assert (
-                        epoch_self == epoch_generated
+                        epoch_generated == self.epoch + 1
                     ), "bad epoch of synchronized log_var! %d vs %d" % (
-                        epoch_self,
+                        self.epoch + 1,
                         epoch_generated,
                     )


@@ -83,6 +83,7 @@ dataset_map_provider_LlffDatasetMapProvider_args:
   n_known_frames_for_test: null
   path_manager_factory_PathManagerFactory_args:
     silence_logs: true
+  downscale_factor: 4
 dataset_map_provider_RenderedMeshDatasetMapProvider_args:
   num_views: 40
   data_file: null


@@ -69,6 +69,7 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
         provider = LlffDatasetMapProvider(
             base_dir="manifold://co3d/tree/nerf_data/nerf_llff_data/fern",
             object_name="fern",
+            downscale_factor=8,
         )
         dataset_map = provider.get_dataset_map()
         known_matrix = torch.zeros(1, 4, 4)