Mirror of https://github.com/facebookresearch/pytorch3d.git
	Move load_stats to TrainingLoop
Summary: Stats are logically connected to the training loop, not to the model, so loading them moves to the training loop. This also removes the resume/resume_epoch options from OptimizerFactory in favor of a single place, ModelFactory, which removes the need for the config consistency checks.

Reviewed By: kjchalup

Differential Revision: D38313475

fbshipit-source-id: a1d188a63e28459df381ff98ad8acdcdb14887b7
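The net effect on the training entry point is sketched below. This is a minimal reconstruction from the hunks in this commit; the factory calls are only partially visible in the diff, so their exact argument lists and the tuple unpacking are assumptions.

model = self.model_factory(
    accelerator=accelerator,  # assumed: the full call is not shown in this diff
    exp_dir=self.exp_dir,
)
# Stats are now loaded by the training loop, keyed off the model factory's
# resume settings, rather than by the model factory itself.
stats = self.training_loop.load_stats(
    log_vars=model.log_vars,
    exp_dir=self.exp_dir,
    resume=self.model_factory.resume,
    resume_epoch=self.model_factory.resume_epoch,
)
start_epoch = stats.epoch + 1
# The optimizer factory no longer has resume/resume_epoch fields of its own;
# the values are forwarded from the model factory at the call site.
optimizer, scheduler = self.optimizer_factory(  # assumed call target and unpacking
    accelerator=accelerator,
    exp_dir=self.exp_dir,
    last_epoch=start_epoch,
    model=model,
    resume=self.model_factory.resume,
    resume_epoch=self.model_factory.resume_epoch,
)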
parent 760305e044
commit c3f8dad55c
@@ -3,6 +3,7 @@ defaults:
- _self_
exp_dir: ./data/exps/base/
training_loop_ImplicitronTrainingLoop_args:
  visdom_port: 8097
  visualize_interval: 0
  max_epochs: 1000
data_source_ImplicitronDataSource_args:
@@ -22,7 +23,6 @@ data_source_ImplicitronDataSource_args:
      mask_depths: false
      mask_images: false
model_factory_ImplicitronModelFactory_args:
  visdom_port: 8097
  model_GenericModel_args:
    loss_weights:
      loss_mask_bce: 1.0

@@ -147,9 +147,6 @@ class Experiment(Configurable):  # pyre-ignore: 13
        run_auto_creation(self)

    def run(self) -> None:
        # Make sure the config settings are self-consistent.
        self._check_config_consistent()

        # Initialize the accelerator if desired.
        if no_accelerate:
            accelerator = None
@@ -176,9 +173,11 @@ class Experiment(Configurable):  # pyre-ignore: 13
            exp_dir=self.exp_dir,
        )

        stats = self.model_factory.load_stats(
            exp_dir=self.exp_dir,
        stats = self.training_loop.load_stats(
            log_vars=model.log_vars,
            exp_dir=self.exp_dir,
            resume=self.model_factory.resume,
            resume_epoch=self.model_factory.resume_epoch,  # pyre-ignore [16]
        )
        start_epoch = stats.epoch + 1

@@ -190,6 +189,8 @@ class Experiment(Configurable):  # pyre-ignore: 13
            exp_dir=self.exp_dir,
            last_epoch=start_epoch,
            model=model,
            resume=self.model_factory.resume,
            resume_epoch=self.model_factory.resume_epoch,
        )

        # Wrap all modules in the distributed library
@@ -224,26 +225,6 @@ class Experiment(Configurable):  # pyre-ignore: 13
            seed=self.seed,
        )

    def _check_config_consistent(self) -> None:
        if hasattr(self.optimizer_factory, "resume") and hasattr(
            self.model_factory, "resume"
        ):
            assert (
                # pyre-ignore [16]
                not self.optimizer_factory.resume
                # pyre-ignore [16]
                or self.model_factory.resume
            ), "Cannot resume the optimizer without resuming the model."
        if hasattr(self.optimizer_factory, "resume_epoch") and hasattr(
            self.model_factory, "resume_epoch"
        ):
            assert (
                # pyre-ignore [16]
                self.optimizer_factory.resume_epoch
                # pyre-ignore [16]
                == self.model_factory.resume_epoch
            ), "Optimizer and model must resume from the same epoch."


def _setup_envvars_for_cluster() -> bool:
    """

@@ -6,13 +6,13 @@

import logging
import os
from typing import List, Optional
from typing import Optional

import torch.optim

from accelerate import Accelerator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools import model_io
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
@@ -24,6 +24,9 @@ logger = logging.getLogger(__name__)


class ModelFactoryBase(ReplaceableBase):

    resume: bool = True  # resume from the last checkpoint

    def __call__(self, **kwargs) -> ImplicitronModelBase:
        """
        Initialize the model (possibly from a previously saved state).
@@ -45,27 +48,22 @@ class ImplicitronModelFactory(ModelFactoryBase):  # pyre-ignore [13]
    A factory class that initializes an implicit rendering model.

    Members:
        force_load: If True, throw a FileNotFoundError if `resume` is True but
            a model checkpoint cannot be found.
        model: An ImplicitronModelBase object.
        resume: If True, attempt to load the last checkpoint from `exp_dir`
            passed to __call__. Failure to do so will return a model with ini-
            tial weights unless `force_load` is True.
            tial weights unless `force_resume` is True.
        resume_epoch: If `resume` is True: Resume a model at this epoch, or if
            `resume_epoch` <= 0, then resume from the latest checkpoint.
        visdom_env: The name of the Visdom environment to use for plotting.
        visdom_port: The Visdom port.
        visdom_server: Address of the Visdom server.
        force_resume: If True, throw a FileNotFoundError if `resume` is True but
            a model checkpoint cannot be found.

    """

    force_load: bool = False
    model: ImplicitronModelBase
    model_class_type: str = "GenericModel"
    resume: bool = False
    resume: bool = True
    resume_epoch: int = -1
    visdom_env: str = ""
    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
    visdom_server: str = "http://127.0.0.1"
    force_resume: bool = False

    def __post_init__(self):
        run_auto_creation(self)
@@ -87,24 +85,27 @@ class ImplicitronModelFactory(ModelFactoryBase):  # pyre-ignore [13]
            model: The model with optionally loaded weights from checkpoint

        Raise:
            FileNotFoundError if `force_load` is True but checkpoint not found.
            FileNotFoundError if `force_resume` is True but checkpoint not found.
        """
        # Determine the network outputs that should be logged
        if hasattr(self.model, "log_vars"):
            log_vars = list(self.model.log_vars)  # pyre-ignore [6]
            log_vars = list(self.model.log_vars)
        else:
            log_vars = ["objective"]

        # Retrieve the last checkpoint
        if self.resume_epoch > 0:
            # Resume from a certain epoch
            model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
            if not os.path.isfile(model_path):
                raise ValueError(f"Cannot find model from epoch {self.resume_epoch}.")
        else:
            # Retrieve the last checkpoint
            model_path = model_io.find_last_checkpoint(exp_dir)

        if model_path is not None:
            logger.info("found previous model %s" % model_path)
            if self.force_load or self.resume:
                logger.info("   -> resuming")
            logger.info(f"Found previous model {model_path}")
            if self.force_resume or self.resume:
                logger.info("Resuming.")

                map_location = None
                if accelerator is not None and not accelerator.is_local_main_process:
@@ -120,81 +121,13 @@ class ImplicitronModelFactory(ModelFactoryBase):  # pyre-ignore [13]
                except RuntimeError as e:
                    logger.error(e)
                    logger.info(
                        "Cant load state dict in strict mode! -> trying non-strict"
                        "Cannot load state dict in strict mode! -> trying non-strict"
                    )
                    self.model.load_state_dict(model_state_dict, strict=False)
                self.model.log_vars = log_vars  # pyre-ignore [16]
                self.model.log_vars = log_vars
            else:
                logger.info("   -> but not resuming -> starting from scratch")
        elif self.force_load:
                logger.info("Not resuming -> starting from scratch.")
        elif self.force_resume:
            raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")

        return self.model

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        clear_stats: bool = False,
        **kwargs,
    ) -> Stats:
        """
        Load Stats that correspond to the model's log_vars.

        Args:
            log_vars: A list of variable names to log. Should be a subset of the
                `preds` returned by the forward function of the corresponding
                ImplicitronModelBase instance.
            exp_dir: Root experiment directory.
            clear_stats: If True, do not load stats from the checkpoint speci-
                fied by self.resume and self.resume_epoch; instead, create a
                fresh stats object.

        stats: The stats structure (optionally loaded from checkpoint)
        """
        # Init the stats struct
        visdom_env_charts = (
            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
        )
        stats = Stats(
            # log_vars should be a list, but OmegaConf might load them as ListConfig
            list(log_vars),
            plot_file=os.path.join(exp_dir, "train_stats.pdf"),
            visdom_env=visdom_env_charts,
            verbose=False,
            visdom_server=self.visdom_server,
            visdom_port=self.visdom_port,
        )
        if self.resume_epoch > 0:
            model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
        else:
            model_path = model_io.find_last_checkpoint(exp_dir)

        if model_path is not None:
            stats_path = model_io.get_stats_path(model_path)
            stats_load = model_io.load_stats(stats_path)

            # Determine if stats should be reset
            if not clear_stats:
                if stats_load is None:
                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
                    logger.info(f"Estimated resume epoch = {last_epoch}")

                    # Reset the stats struct
                    for _ in range(last_epoch + 1):
                        stats.new_epoch()
                    assert last_epoch == stats.epoch
                else:
                    stats = stats_load

                # Update stats properties incase it was reset on load
                stats.visdom_env = visdom_env_charts
                stats.visdom_server = self.visdom_server
                stats.visdom_port = self.visdom_port
                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
                stats.synchronize_logged_vars(log_vars)
            else:
                logger.info("   -> clearing stats")

        return stats

@@ -60,11 +60,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
        multistep_lr_milestones: With MultiStepLR policy only: list of
            increasing epoch indices at which the learning rate is modified.
        momentum: Momentum factor for SGD optimizer.
        resume: If True, attempt to load the last checkpoint from `exp_dir`
            passed to __call__. Failure to do so will return a newly initialized
            optimizer.
        resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
            `resume_epoch` <= 0, then resume from the latest checkpoint.
        weight_decay: The optimizer weight_decay (L2 penalty on model weights).
    """

@@ -76,8 +71,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
    lr_policy: str = "MultiStepLR"
    momentum: float = 0.9
    multistep_lr_milestones: tuple = ()
    resume: bool = False
    resume_epoch: int = -1
    weight_decay: float = 0.0

    def __post_init__(self):
@@ -89,6 +82,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
        model: ImplicitronModelBase,
        accelerator: Optional[Accelerator] = None,
        exp_dir: Optional[str] = None,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Tuple[torch.optim.Optimizer, Any]:
        """
@@ -100,7 +95,10 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
            model: The model with optionally loaded weights.
            accelerator: An optional Accelerator instance.
            exp_dir: Root experiment directory.

            resume: If True, attempt to load optimizer checkpoint from exp_dir.
                Failure to do so will return a newly initialized optimizer.
            resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
                `resume_epoch` <= 0, then resume from the latest checkpoint.
        Returns:
            An optimizer module (optionally loaded from a checkpoint) and
            a learning rate scheduler module (should be a subclass of torch.optim's
@@ -131,13 +129,18 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
                p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
            )
        else:
            raise ValueError("no such solver type %s" % self.breed)
        logger.info("  -> solver type = %s" % self.breed)
            raise ValueError(f"No such solver type {self.breed}")
        logger.info(f"Solver type = {self.breed}")

        # Load state from checkpoint
        optimizer_state = self._get_optimizer_state(exp_dir, accelerator)
        optimizer_state = self._get_optimizer_state(
            exp_dir,
            accelerator,
            resume_epoch=resume_epoch,
            resume=resume,
        )
        if optimizer_state is not None:
            logger.info("  -> setting loaded optimizer state")
            logger.info("Setting loaded optimizer state.")
            optimizer.load_state_dict(optimizer_state)

        # Initialize the learning rate scheduler
@@ -169,20 +172,31 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
        self,
        exp_dir: Optional[str],
        accelerator: Optional[Accelerator] = None,
        resume: bool = True,
        resume_epoch: int = -1,
    ) -> Optional[Dict[str, Any]]:
        """
        Load an optimizer state from a checkpoint.

        resume: If True, attempt to load the last checkpoint from `exp_dir`
            passed to __call__. Failure to do so will return a newly initialized
            optimizer.
        resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
            `resume_epoch` <= 0, then resume from the latest checkpoint.
        """
        if exp_dir is None or not self.resume:
        if exp_dir is None or not resume:
            return None
        if self.resume_epoch > 0:
            save_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
        if resume_epoch > 0:
            save_path = model_io.get_checkpoint(exp_dir, resume_epoch)
            if not os.path.isfile(save_path):
                raise FileNotFoundError(
                    f"Cannot find optimizer from epoch {resume_epoch}."
                )
        else:
            save_path = model_io.find_last_checkpoint(exp_dir)
        optimizer_state = None
        if save_path is not None:
            logger.info(f"Found previous optimizer state {save_path}.")
            logger.info("   -> resuming")
            logger.info(f"Found previous optimizer state {save_path} -> resuming.")
            opt_path = model_io.get_optimizer_path(save_path)

            if os.path.isfile(opt_path):
@@ -193,5 +207,5 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
                    }
                optimizer_state = torch.load(opt_path, map_location)
            else:
                optimizer_state = None
                raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
        return optimizer_state

@@ -5,8 +5,9 @@
# LICENSE file in the root directory of this source tree.

import logging
import os
import time
from typing import Any, Optional
from typing import Any, List, Optional

import torch
from accelerate import Accelerator
@@ -41,6 +42,16 @@ class TrainingLoopBase(ReplaceableBase):
    ) -> None:
        raise NotImplementedError()

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        raise NotImplementedError()


@registry.register
class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
@@ -64,6 +75,9 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
            logged.
        visualize_interval: The batch interval at which the visualizations
            should be plotted
        visdom_env: The name of the Visdom environment to use for plotting.
        visdom_port: The Visdom port.
        visdom_server: Address of the Visdom server.
    """

    # Parameters of the outer training loop.
@@ -77,10 +91,15 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
    test_when_finished: bool = False
    validation_interval: int = 1

    # Parameters of a single training-validation step.
    # Gradient clipping.
    clip_grad: float = 0.0

    # Visualization/logging parameters.
    metric_print_interval: int = 5
    visualize_interval: int = 1000
    visdom_env: str = ""
    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
    visdom_server: str = "http://127.0.0.1"

    def __post_init__(self):
        run_auto_creation(self)
@@ -202,6 +221,81 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
                    "Cannot evaluate and dump results to json, no test data provided."
                )

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        """
        Load Stats that correspond to the model's log_vars and resume_epoch.

        Args:
            log_vars: A list of variable names to log. Should be a subset of the
                `preds` returned by the forward function of the corresponding
                ImplicitronModelBase instance.
            exp_dir: Root experiment directory.
            resume: If False, do not load stats from the checkpoint speci-
                fied by resume and resume_epoch; instead, create a fresh stats object.

        stats: The stats structure (optionally loaded from checkpoint)
        """
        # Init the stats struct
        visdom_env_charts = (
            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
        )
        stats = Stats(
            # log_vars should be a list, but OmegaConf might load them as ListConfig
            list(log_vars),
            visdom_env=visdom_env_charts,
            verbose=False,
            visdom_server=self.visdom_server,
            visdom_port=self.visdom_port,
        )

        model_path = None
        if resume:
            if resume_epoch > 0:
                model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
                if not os.path.isfile(model_path):
                    raise FileNotFoundError(
                        f"Cannot find stats from epoch {resume_epoch}."
                    )
            else:
                model_path = model_io.find_last_checkpoint(exp_dir)

        if model_path is not None:
            stats_path = model_io.get_stats_path(model_path)
            stats_load = model_io.load_stats(stats_path)

            # Determine if stats should be reset
            if resume:
                if stats_load is None:
                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
                    logger.info(f"Estimated resume epoch = {last_epoch}")

                    # Reset the stats struct
                    for _ in range(last_epoch + 1):
                        stats.new_epoch()
                    assert last_epoch == stats.epoch
                else:
                    logger.info(f"Found previous stats in {stats_path} -> resuming.")
                    stats = stats_load

                # Update stats properties incase it was reset on load
                stats.visdom_env = visdom_env_charts
                stats.visdom_server = self.visdom_server
                stats.visdom_port = self.visdom_port
                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
                stats.synchronize_logged_vars(log_vars)
            else:
                logger.info("Clearing stats")

        return stats

    def _training_or_validation_epoch(
        self,
        epoch: int,

@@ -124,14 +124,32 @@ data_source_ImplicitronDataSource_args:
    dataset_length_val: 0
    dataset_length_test: 0
model_factory_ImplicitronModelFactory_args:
  force_load: false
  resume: true
  model_class_type: GenericModel
  resume: false
  resume_epoch: -1
  visdom_env: ''
  visdom_port: 8097
  visdom_server: http://127.0.0.1
  force_resume: false
  model_GenericModel_args:
    log_vars:
    - loss_rgb_psnr_fg
    - loss_rgb_psnr
    - loss_rgb_mse
    - loss_rgb_huber
    - loss_depth_abs
    - loss_depth_abs_fg
    - loss_mask_neg_iou
    - loss_mask_bce
    - loss_mask_beta_prior
    - loss_eikonal
    - loss_density_tv
    - loss_depth_neg_penalty
    - loss_autodecoder_norm
    - loss_prev_stage_rgb_mse
    - loss_prev_stage_rgb_psnr_fg
    - loss_prev_stage_rgb_psnr
    - loss_prev_stage_mask_bce
    - objective
    - epoch
    - sec/it
    mask_images: true
    mask_depths: true
    render_image_width: 400
@@ -162,27 +180,6 @@ model_factory_ImplicitronModelFactory_args:
      loss_prev_stage_rgb_mse: 1.0
      loss_mask_bce: 0.0
      loss_prev_stage_mask_bce: 0.0
    log_vars:
    - loss_rgb_psnr_fg
    - loss_rgb_psnr
    - loss_rgb_mse
    - loss_rgb_huber
    - loss_depth_abs
    - loss_depth_abs_fg
    - loss_mask_neg_iou
    - loss_mask_bce
    - loss_mask_beta_prior
    - loss_eikonal
    - loss_density_tv
    - loss_depth_neg_penalty
    - loss_autodecoder_norm
    - loss_prev_stage_rgb_mse
    - loss_prev_stage_rgb_psnr_fg
    - loss_prev_stage_rgb_psnr
    - loss_prev_stage_mask_bce
    - objective
    - epoch
    - sec/it
    global_encoder_HarmonicTimeEncoder_args:
      n_harmonic_functions: 10
      append_input: true
@@ -422,8 +419,6 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
  lr_policy: MultiStepLR
  momentum: 0.9
  multistep_lr_milestones: []
  resume: false
  resume_epoch: -1
  weight_decay: 0.0
training_loop_ImplicitronTrainingLoop_args:
  eval_only: false
@@ -437,6 +432,9 @@ training_loop_ImplicitronTrainingLoop_args:
  clip_grad: 0.0
  metric_print_interval: 5
  visualize_interval: 1000
  visdom_env: ''
  visdom_port: 8097
  visdom_server: http://127.0.0.1
  evaluator_ImplicitronEvaluator_args:
    camera_difficulty_bin_breaks:
    - 0.97

@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import os
import tempfile
import unittest
from pathlib import Path

@@ -172,3 +173,48 @@ class TestNerfRepro(unittest.TestCase):
            experiment_runner = experiment.Experiment(**cfg)
            experiment.dump_cfg(cfg)
            experiment_runner.run()

    @unittest.skip("This test checks resuming of the NeRF training.")
    def test_nerf_blender_resume(self):
        # Train one train batch of NeRF, then resume for one more batch.
        # Set env vars BLENDER_DATASET_ROOT and BLENDER_SINGLESEQ_CLASS first!
        if not interactive_testing_requested():
            return
        with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
            with tempfile.TemporaryDirectory() as exp_dir:
                cfg = compose(config_name="repro_singleseq_nerf_blender", overrides=[])
                cfg.exp_dir = exp_dir

                # set dataset len to 1

                # fmt: off
                (
                    cfg
                    .data_source_ImplicitronDataSource_args
                    .data_loader_map_provider_SequenceDataLoaderMapProvider_args
                    .dataset_length_train
                ) = 1
                # fmt: on

                # run for one epoch
                cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 1
                experiment_runner = experiment.Experiment(**cfg)
                experiment.dump_cfg(cfg)
                experiment_runner.run()

                # update num epochs + 2, let the optimizer resume
                cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 3
                experiment_runner = experiment.Experiment(**cfg)
                experiment_runner.run()

                # start from scratch
                cfg.model_factory_ImplicitronModelFactory_args.resume = False
                experiment_runner = experiment.Experiment(**cfg)
                experiment_runner.run()

                # force resume from epoch 1
                cfg.model_factory_ImplicitronModelFactory_args.resume = True
                cfg.model_factory_ImplicitronModelFactory_args.force_resume = True
                cfg.model_factory_ImplicitronModelFactory_args.resume_epoch = 1
                experiment_runner = experiment.Experiment(**cfg)
                experiment_runner.run()

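As a usage note, here is a hedged sketch of how resuming is driven from the config after this change, mirroring the skipped resume test above (cfg is an OmegaConf config composed as in that test; the option names come from the yaml defaults in this commit):

# All resume behaviour now lives under the model factory args; the optimizer
# factory no longer exposes resume / resume_epoch fields of its own.
cfg.model_factory_ImplicitronModelFactory_args.resume = True
cfg.model_factory_ImplicitronModelFactory_args.resume_epoch = 1   # <= 0 means: use the latest checkpoint
cfg.model_factory_ImplicitronModelFactory_args.force_resume = True  # fail loudly if no checkpoint exists
experiment_runner = experiment.Experiment(**cfg)
experiment_runner.run()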
@@ -344,7 +344,7 @@ def export_scenes(

    # Load the previously trained model
    experiment = Experiment(config)
    model = experiment.model_factory(force_load=True, load_model_only=True)
    model = experiment.model_factory(force_resume=True)
    model.cuda()
    model.eval()


@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import torch
@@ -45,6 +45,10 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
    optimization.
    """

    # The keys from `preds` (output of ImplicitronModelBase.forward) to be logged in
    # the training loop.
    log_vars: List[str] = field(default_factory=lambda: ["objective"])

    def __init__(self) -> None:
        super().__init__()


@@ -1,3 +1,24 @@
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
mask_images: true
mask_depths: true
render_image_width: 400
@@ -28,27 +49,6 @@ loss_weights:
  loss_prev_stage_rgb_mse: 1.0
  loss_mask_bce: 0.0
  loss_prev_stage_mask_bce: 0.0
log_vars:
- loss_rgb_psnr_fg
- loss_rgb_psnr
- loss_rgb_mse
- loss_rgb_huber
- loss_depth_abs
- loss_depth_abs_fg
- loss_mask_neg_iou
- loss_mask_bce
- loss_mask_beta_prior
- loss_eikonal
- loss_density_tv
- loss_depth_neg_penalty
- loss_autodecoder_norm
- loss_prev_stage_rgb_mse
- loss_prev_stage_rgb_psnr_fg
- loss_prev_stage_rgb_psnr
- loss_prev_stage_mask_bce
- objective
- epoch
- sec/it
global_encoder_SequenceAutodecoder_args:
  autodecoder_args:
    encoding_dim: 0

@@ -90,5 +90,5 @@ class TestGenericModel(unittest.TestCase):
        remove_unused_components(instance_args)
        yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
        if DEBUG:
            (DATA_DIR / "overrides.yaml_").write_text(yaml)
            (DATA_DIR / "overrides_.yaml").write_text(yaml)
        self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())