mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Mods and bugfixes for LLFF and Blender repros
Summary: LLFF (and most/all non-synth datasets) will have no background/foreground distinction. Add support for data with no fg mask. Also, we had a bug in stats loading, like this: * Load stats * One of the stats has a history of length 0 * That's fine, e.g. maybe it's fg_error but the dataset has no notion of fg/bg. So leave it as len 0 * Check whether all the stats have the same history length as an arbitrarily chosen "reference-stat" * Ooops the reference-stat happened to be the stat with length 0 * assert (legit_stat_len == reference_stat_len (=0)) ---> failed assert Also some minor fixes (from Jeremy's other diff) to support LLFF Reviewed By: davnov134 Differential Revision: D38475272 fbshipit-source-id: 5b35ac86d1d5239759f537621f41a3aa4eb3bd68
This commit is contained in:
		
							parent
							
								
									624bc5a274
								
							
						
					
					
						commit
						c83ec3555d
					
				@ -1,7 +1,7 @@
 | 
			
		||||
defaults:
 | 
			
		||||
- repro_singleseq_base
 | 
			
		||||
- _self_
 | 
			
		||||
exp_dir: "./data/nerf_blender_publ/${oc.env:BLENDER_SINGLESEQ_CLASS}"
 | 
			
		||||
exp_dir: "./data/nerf_blender_repro/${oc.env:BLENDER_SINGLESEQ_CLASS}"
 | 
			
		||||
data_source_ImplicitronDataSource_args:
 | 
			
		||||
  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
 | 
			
		||||
    dataset_length_train: 100
 | 
			
		||||
@ -16,17 +16,18 @@ data_source_ImplicitronDataSource_args:
 | 
			
		||||
 | 
			
		||||
model_factory_ImplicitronModelFactory_args:
 | 
			
		||||
  model_GenericModel_args:
 | 
			
		||||
    raysampler_AdaptiveRaySampler_args:
 | 
			
		||||
    mask_images: false
 | 
			
		||||
    raysampler_class_type: NearFarRaySampler
 | 
			
		||||
    raysampler_NearFarRaySampler_args:
 | 
			
		||||
      n_rays_per_image_sampled_from_mask: 4096
 | 
			
		||||
      scene_extent: 2.0
 | 
			
		||||
      min_depth: 2
 | 
			
		||||
      max_depth: 6
 | 
			
		||||
    renderer_MultiPassEmissionAbsorptionRenderer_args:
 | 
			
		||||
      density_noise_std_train: 1.0
 | 
			
		||||
      n_pts_per_ray_fine_training: 128
 | 
			
		||||
      n_pts_per_ray_fine_evaluation: 128
 | 
			
		||||
      raymarcher_EmissionAbsorptionRaymarcher_args:
 | 
			
		||||
        blend_output: true
 | 
			
		||||
        bg_color:
 | 
			
		||||
        - 1.0
 | 
			
		||||
        blend_output: false
 | 
			
		||||
    loss_weights:
 | 
			
		||||
      loss_rgb_mse: 1.0
 | 
			
		||||
      loss_prev_stage_rgb_mse: 1.0
 | 
			
		||||
@ -35,11 +36,11 @@ model_factory_ImplicitronModelFactory_args:
 | 
			
		||||
      loss_autodecoder_norm: 0.00
 | 
			
		||||
 | 
			
		||||
optimizer_factory_ImplicitronOptimizerFactory_args:
 | 
			
		||||
  exponential_lr_step_size: 3001
 | 
			
		||||
  exponential_lr_step_size: 2500
 | 
			
		||||
  lr_policy: Exponential
 | 
			
		||||
 | 
			
		||||
training_loop_ImplicitronTrainingLoop_args:
 | 
			
		||||
  max_epochs: 3001
 | 
			
		||||
  max_epochs: 2000
 | 
			
		||||
  metric_print_interval: 100
 | 
			
		||||
  store_checkpoints_purge: 3
 | 
			
		||||
  test_when_finished: true
 | 
			
		||||
 | 
			
		||||
@ -249,6 +249,7 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
 | 
			
		||||
        stats = Stats(
 | 
			
		||||
            # log_vars should be a list, but OmegaConf might load them as ListConfig
 | 
			
		||||
            list(log_vars),
 | 
			
		||||
            plot_file=os.path.join(exp_dir, "train_stats.pdf"),
 | 
			
		||||
            visdom_env=visdom_env_charts,
 | 
			
		||||
            verbose=False,
 | 
			
		||||
            visdom_server=self.visdom_server,
 | 
			
		||||
 | 
			
		||||
@ -95,6 +95,7 @@ data_source_ImplicitronDataSource_args:
 | 
			
		||||
    n_known_frames_for_test: null
 | 
			
		||||
    path_manager_factory_PathManagerFactory_args:
 | 
			
		||||
      silence_logs: true
 | 
			
		||||
    downscale_factor: 4
 | 
			
		||||
  dataset_map_provider_RenderedMeshDatasetMapProvider_args:
 | 
			
		||||
    num_views: 40
 | 
			
		||||
    data_file: null
 | 
			
		||||
 | 
			
		||||
@ -162,7 +162,7 @@ class TestExperiment(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestNerfRepro(unittest.TestCase):
 | 
			
		||||
    @unittest.skip("This runs full NeRF training on Blender data.")
 | 
			
		||||
    @unittest.skip("This test runs full blender training.")
 | 
			
		||||
    def test_nerf_blender(self):
 | 
			
		||||
        # Train vanilla NERF.
 | 
			
		||||
        # Set env vars BLENDER_DATASET_ROOT and BLENDER_SINGLESEQ_CLASS first!
 | 
			
		||||
@ -174,6 +174,22 @@ class TestNerfRepro(unittest.TestCase):
 | 
			
		||||
            experiment.dump_cfg(cfg)
 | 
			
		||||
            experiment_runner.run()
 | 
			
		||||
 | 
			
		||||
    @unittest.skip("This test runs full llff training.")
 | 
			
		||||
    def test_nerf_llff(self):
 | 
			
		||||
        # Train vanilla NERF.
 | 
			
		||||
        # Set env vars LLFF_DATASET_ROOT and LLFF_SINGLESEQ_CLASS first!
 | 
			
		||||
        LLFF_SINGLESEQ_CLASS = os.environ["LLFF_SINGLESEQ_CLASS"]
 | 
			
		||||
        if not interactive_testing_requested():
 | 
			
		||||
            return
 | 
			
		||||
        with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
 | 
			
		||||
            cfg = compose(
 | 
			
		||||
                config_name=f"repro_singleseq_nerf_llff_{LLFF_SINGLESEQ_CLASS}",
 | 
			
		||||
                overrides=[],
 | 
			
		||||
            )
 | 
			
		||||
            experiment_runner = experiment.Experiment(**cfg)
 | 
			
		||||
            experiment.dump_cfg(cfg)
 | 
			
		||||
            experiment_runner.run()
 | 
			
		||||
 | 
			
		||||
    @unittest.skip("This test checks resuming of the NeRF training.")
 | 
			
		||||
    def test_nerf_blender_resume(self):
 | 
			
		||||
        # Train one train batch of NeRF, then resume for one more batch.
 | 
			
		||||
 | 
			
		||||
@ -32,17 +32,21 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
 | 
			
		||||
            and test datasets, and this many random training frames are added to
 | 
			
		||||
            each test batch. If not set, test batches each contain just a single
 | 
			
		||||
            testing frame.
 | 
			
		||||
        downscale_factor: determines image sizes.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    downscale_factor: int = 4
 | 
			
		||||
 | 
			
		||||
    def _load_data(self) -> None:
 | 
			
		||||
        path_manager = self.path_manager_factory.get()
 | 
			
		||||
        images, poses, _ = load_llff_data(
 | 
			
		||||
            self.base_dir, factor=8, path_manager=path_manager
 | 
			
		||||
            self.base_dir, factor=self.downscale_factor, path_manager=path_manager
 | 
			
		||||
        )
 | 
			
		||||
        hwf = poses[0, :3, -1]
 | 
			
		||||
        poses = poses[:, :3, :4]
 | 
			
		||||
 | 
			
		||||
        i_test = np.arange(images.shape[0])[::8]
 | 
			
		||||
        llffhold = 8
 | 
			
		||||
        i_test = np.arange(images.shape[0])[::llffhold]
 | 
			
		||||
        i_test_index = set(i_test.tolist())
 | 
			
		||||
        i_train = np.array(
 | 
			
		||||
            [i for i in np.arange(images.shape[0]) if i not in i_test_index]
 | 
			
		||||
 | 
			
		||||
@ -27,6 +27,7 @@ from .utils import DATASET_TYPE_KNOWN, DATASET_TYPE_UNKNOWN
 | 
			
		||||
_SINGLE_SEQUENCE_NAME: str = "one_sequence"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@expand_args_fields
 | 
			
		||||
class SingleSceneDataset(DatasetBase, Configurable):
 | 
			
		||||
    """
 | 
			
		||||
    A dataset from images from a single scene.
 | 
			
		||||
@ -110,7 +111,6 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
 | 
			
		||||
    def _get_dataset(
 | 
			
		||||
        self, split_idx: int, frame_type: str, set_eval_batches: bool = False
 | 
			
		||||
    ) -> SingleSceneDataset:
 | 
			
		||||
        expand_args_fields(SingleSceneDataset)
 | 
			
		||||
        # pyre-ignore[16]
 | 
			
		||||
        split = self.i_split[split_idx]
 | 
			
		||||
        frame_types = [frame_type] * len(split)
 | 
			
		||||
 | 
			
		||||
@ -245,13 +245,20 @@ def eval_batch(
 | 
			
		||||
    if frame_data.mask_crop is None:
 | 
			
		||||
        warnings.warn("mask_crop is None, assuming the whole image is valid.")
 | 
			
		||||
 | 
			
		||||
    if frame_data.fg_probability is None:
 | 
			
		||||
        warnings.warn("fg_probability is None, assuming the whole image is fg.")
 | 
			
		||||
 | 
			
		||||
    # threshold the masks to make ground truth binary masks
 | 
			
		||||
    # pyre-ignore [58]
 | 
			
		||||
    mask_fg = frame_data.fg_probability >= mask_thr
 | 
			
		||||
    mask_fg = (
 | 
			
		||||
        frame_data.fg_probability >= mask_thr
 | 
			
		||||
        if frame_data.fg_probability is not None
 | 
			
		||||
        # pyre-ignore [16]
 | 
			
		||||
        else torch.ones_like(frame_data.image_rgb[:, :1, ...]).bool()
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    mask_crop = (
 | 
			
		||||
        frame_data.mask_crop
 | 
			
		||||
        if frame_data.mask_crop is not None
 | 
			
		||||
        # pyre-ignore [6]
 | 
			
		||||
        else torch.ones_like(mask_fg)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
@ -259,7 +266,6 @@ def eval_batch(
 | 
			
		||||
        # pyre-fixme[6]: Expected `Tensor` for 1st param but got
 | 
			
		||||
        #  `Optional[torch.Tensor]`.
 | 
			
		||||
        frame_data.image_rgb,
 | 
			
		||||
        # pyre-ignore [6]
 | 
			
		||||
        mask_fg,
 | 
			
		||||
        bg_color=bg_color,
 | 
			
		||||
    )
 | 
			
		||||
@ -275,7 +281,6 @@ def eval_batch(
 | 
			
		||||
            # pyre-fixme[6]: Expected `Tensor` for 4th param but got
 | 
			
		||||
            #  `Optional[torch.Tensor]`.
 | 
			
		||||
            depth_map=frame_data.depth_map,
 | 
			
		||||
            # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
 | 
			
		||||
            depth_mask=frame_data.depth_mask[:1],
 | 
			
		||||
            visdom_env=visualize_visdom_env,
 | 
			
		||||
        )
 | 
			
		||||
@ -284,7 +289,7 @@ def eval_batch(
 | 
			
		||||
 | 
			
		||||
    results["iou"] = iou(
 | 
			
		||||
        cloned_render["mask_render"],
 | 
			
		||||
        mask_fg,  # pyre-ignore [6]
 | 
			
		||||
        mask_fg,
 | 
			
		||||
        mask=mask_crop,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -13,8 +13,8 @@ from typing import Any, Dict, List, Optional, Tuple
 | 
			
		||||
 | 
			
		||||
import lpips
 | 
			
		||||
import torch
 | 
			
		||||
import tqdm
 | 
			
		||||
 | 
			
		||||
import tqdm
 | 
			
		||||
from pytorch3d.implicitron.dataset import utils as ds_utils
 | 
			
		||||
 | 
			
		||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
 | 
			
		||||
 | 
			
		||||
@ -198,7 +198,6 @@ class Stats(object):
 | 
			
		||||
                if verbose:
 | 
			
		||||
                    print(f"Adding {add_log_var}")
 | 
			
		||||
                self.log_vars.append(add_log_var)
 | 
			
		||||
        # self.synchronize_logged_vars(self.log_vars, verbose=verbose)
 | 
			
		||||
 | 
			
		||||
    def update(self, preds, time_start=None, freeze_iter=False, stat_set="train"):
 | 
			
		||||
 | 
			
		||||
@ -230,7 +229,6 @@ class Stats(object):
 | 
			
		||||
                    elapsed = time.time() - time_start
 | 
			
		||||
                time_per_it = float(elapsed) / float(it + 1)
 | 
			
		||||
                val = time_per_it
 | 
			
		||||
                # self.stats[stat_set]['sec/it'].update(time_per_it,epoch=epoch,n=1)
 | 
			
		||||
            else:
 | 
			
		||||
                if stat in preds:
 | 
			
		||||
                    try:
 | 
			
		||||
@ -441,7 +439,6 @@ class Stats(object):
 | 
			
		||||
        self.log_vars = log_vars  # !!!
 | 
			
		||||
 | 
			
		||||
        for stat_set in stat_sets:
 | 
			
		||||
            reference_stat = list(self.stats[stat_set].keys())[0]
 | 
			
		||||
            for stat in log_vars:
 | 
			
		||||
                if stat not in self.stats[stat_set]:
 | 
			
		||||
                    if verbose:
 | 
			
		||||
@ -468,12 +465,11 @@ class Stats(object):
 | 
			
		||||
                lastep = self.epoch + 1
 | 
			
		||||
                for ep in range(lastep):
 | 
			
		||||
                    self.stats[stat_set][stat].update(default_val, n=1, epoch=ep)
 | 
			
		||||
                epoch_self = self.stats[stat_set][reference_stat].get_epoch()
 | 
			
		||||
                epoch_generated = self.stats[stat_set][stat].get_epoch()
 | 
			
		||||
                assert (
 | 
			
		||||
                    epoch_self == epoch_generated
 | 
			
		||||
                    epoch_generated == self.epoch + 1
 | 
			
		||||
                ), "bad epoch of synchronized log_var! %d vs %d" % (
 | 
			
		||||
                    epoch_self,
 | 
			
		||||
                    self.epoch + 1,
 | 
			
		||||
                    epoch_generated,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -83,6 +83,7 @@ dataset_map_provider_LlffDatasetMapProvider_args:
 | 
			
		||||
  n_known_frames_for_test: null
 | 
			
		||||
  path_manager_factory_PathManagerFactory_args:
 | 
			
		||||
    silence_logs: true
 | 
			
		||||
  downscale_factor: 4
 | 
			
		||||
dataset_map_provider_RenderedMeshDatasetMapProvider_args:
 | 
			
		||||
  num_views: 40
 | 
			
		||||
  data_file: null
 | 
			
		||||
 | 
			
		||||
@ -69,6 +69,7 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        provider = LlffDatasetMapProvider(
 | 
			
		||||
            base_dir="manifold://co3d/tree/nerf_data/nerf_llff_data/fern",
 | 
			
		||||
            object_name="fern",
 | 
			
		||||
            downscale_factor=8,
 | 
			
		||||
        )
 | 
			
		||||
        dataset_map = provider.get_dataset_map()
 | 
			
		||||
        known_matrix = torch.zeros(1, 4, 4)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user