Move load_stats to TrainingLoop

Summary:
Stats are logically connected to the training loop, not to the model. Hence, `load_stats` moves from the model factory to the training loop.

Also removing `resume`/`resume_epoch` from OptimizerFactory in favor of a single source of truth, the ModelFactory. This removes the need for config consistency checks between the two factories.
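
For orientation, a condensed sketch of the new wiring inside Experiment.run() (it mirrors the experiment.py hunks below and is not standalone-runnable; the keyword sets are abbreviated and the optimizer/scheduler variable names are illustrative). All resume settings are read from the model factory and forwarded explicitly:

model = self.model_factory(accelerator=accelerator, exp_dir=self.exp_dir)
stats = self.training_loop.load_stats(
    log_vars=model.log_vars,
    exp_dir=self.exp_dir,
    resume=self.model_factory.resume,
    resume_epoch=self.model_factory.resume_epoch,
)
optimizer, scheduler = self.optimizer_factory(
    accelerator=accelerator,
    exp_dir=self.exp_dir,
    last_epoch=stats.epoch + 1,  # start_epoch in the diff below
    model=model,
    resume=self.model_factory.resume,
    resume_epoch=self.model_factory.resume_epoch,
)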

Reviewed By: kjchalup

Differential Revision: D38313475

fbshipit-source-id: a1d188a63e28459df381ff98ad8acdcdb14887b7
David Novotny, 2022-08-02 15:40:53 -07:00, committed by Facebook GitHub Bot
parent 760305e044
commit c3f8dad55c
11 changed files with 259 additions and 189 deletions

View File

@@ -3,6 +3,7 @@ defaults:
 - _self_
 exp_dir: ./data/exps/base/
 training_loop_ImplicitronTrainingLoop_args:
+  visdom_port: 8097
   visualize_interval: 0
   max_epochs: 1000
 data_source_ImplicitronDataSource_args:
@@ -22,7 +23,6 @@ data_source_ImplicitronDataSource_args:
     mask_depths: false
     mask_images: false
 model_factory_ImplicitronModelFactory_args:
-  visdom_port: 8097
   model_GenericModel_args:
     loss_weights:
       loss_mask_bce: 1.0
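
User configs and overrides move the same way: the Visdom settings are now keyed under the training loop rather than the model factory. A minimal OmegaConf sketch of building such an override in Python (the value 8097 is just the default port from the config above):

from omegaconf import OmegaConf

# Visdom plotting is now configured on the training loop args.
override = OmegaConf.create(
    {"training_loop_ImplicitronTrainingLoop_args": {"visdom_port": 8097}}
)
print(OmegaConf.to_yaml(override))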

View File

@@ -147,9 +147,6 @@ class Experiment(Configurable): # pyre-ignore: 13
         run_auto_creation(self)
 
     def run(self) -> None:
-        # Make sure the config settings are self-consistent.
-        self._check_config_consistent()
-
         # Initialize the accelerator if desired.
         if no_accelerate:
             accelerator = None
@@ -176,9 +173,11 @@ class Experiment(Configurable): # pyre-ignore: 13
             exp_dir=self.exp_dir,
         )
 
-        stats = self.model_factory.load_stats(
-            exp_dir=self.exp_dir,
+        stats = self.training_loop.load_stats(
             log_vars=model.log_vars,
+            exp_dir=self.exp_dir,
+            resume=self.model_factory.resume,
+            resume_epoch=self.model_factory.resume_epoch,  # pyre-ignore [16]
         )
 
         start_epoch = stats.epoch + 1
@@ -190,6 +189,8 @@ class Experiment(Configurable): # pyre-ignore: 13
             exp_dir=self.exp_dir,
             last_epoch=start_epoch,
             model=model,
+            resume=self.model_factory.resume,
+            resume_epoch=self.model_factory.resume_epoch,
         )
 
         # Wrap all modules in the distributed library
@@ -224,26 +225,6 @@ class Experiment(Configurable): # pyre-ignore: 13
             seed=self.seed,
         )
 
-    def _check_config_consistent(self) -> None:
-        if hasattr(self.optimizer_factory, "resume") and hasattr(
-            self.model_factory, "resume"
-        ):
-            assert (
-                # pyre-ignore [16]
-                not self.optimizer_factory.resume
-                # pyre-ignore [16]
-                or self.model_factory.resume
-            ), "Cannot resume the optimizer without resuming the model."
-        if hasattr(self.optimizer_factory, "resume_epoch") and hasattr(
-            self.model_factory, "resume_epoch"
-        ):
-            assert (
-                # pyre-ignore [16]
-                self.optimizer_factory.resume_epoch
-                # pyre-ignore [16]
-                == self.model_factory.resume_epoch
-            ), "Optimizer and model must resume from the same epoch."
-
 
 def _setup_envvars_for_cluster() -> bool:
     """

View File

@@ -6,13 +6,13 @@
 import logging
 import os
 
-from typing import List, Optional
+from typing import Optional
 
 import torch.optim
 from accelerate import Accelerator
 from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
-from pytorch3d.implicitron.tools import model_io, vis_utils
+from pytorch3d.implicitron.tools import model_io
 from pytorch3d.implicitron.tools.config import (
     registry,
     ReplaceableBase,
@@ -24,6 +24,9 @@ logger = logging.getLogger(__name__)
 
 class ModelFactoryBase(ReplaceableBase):
+
+    resume: bool = True  # resume from the last checkpoint
+
     def __call__(self, **kwargs) -> ImplicitronModelBase:
         """
         Initialize the model (possibly from a previously saved state).
@@ -45,27 +48,22 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
     A factory class that initializes an implicit rendering model.
 
     Members:
-        force_load: If True, throw a FileNotFoundError if `resume` is True but
-            a model checkpoint cannot be found.
         model: An ImplicitronModelBase object.
         resume: If True, attempt to load the last checkpoint from `exp_dir`
             passed to __call__. Failure to do so will return a model with ini-
-            tial weights unless `force_load` is True.
+            tial weights unless `force_resume` is True.
         resume_epoch: If `resume` is True: Resume a model at this epoch, or if
            `resume_epoch` <= 0, then resume from the latest checkpoint.
-        visdom_env: The name of the Visdom environment to use for plotting.
-        visdom_port: The Visdom port.
-        visdom_server: Address of the Visdom server.
+        force_resume: If True, throw a FileNotFoundError if `resume` is True but
+            a model checkpoint cannot be found.
     """
 
-    force_load: bool = False
     model: ImplicitronModelBase
     model_class_type: str = "GenericModel"
-    resume: bool = False
+    resume: bool = True
    resume_epoch: int = -1
-    visdom_env: str = ""
-    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
-    visdom_server: str = "http://127.0.0.1"
+    force_resume: bool = False
 
     def __post_init__(self):
         run_auto_creation(self)
@@ -87,24 +85,27 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
             model: The model with optionally loaded weights from checkpoint
 
         Raise:
-            FileNotFoundError if `force_load` is True but checkpoint not found.
+            FileNotFoundError if `force_resume` is True but checkpoint not found.
         """
         # Determine the network outputs that should be logged
         if hasattr(self.model, "log_vars"):
-            log_vars = list(self.model.log_vars)  # pyre-ignore [6]
+            log_vars = list(self.model.log_vars)
         else:
             log_vars = ["objective"]
 
-        # Retrieve the last checkpoint
         if self.resume_epoch > 0:
+            # Resume from a certain epoch
             model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
+            if not os.path.isfile(model_path):
+                raise ValueError(f"Cannot find model from epoch {self.resume_epoch}.")
         else:
+            # Retrieve the last checkpoint
             model_path = model_io.find_last_checkpoint(exp_dir)
 
         if model_path is not None:
-            logger.info("found previous model %s" % model_path)
-            if self.force_load or self.resume:
-                logger.info(" -> resuming")
+            logger.info(f"Found previous model {model_path}")
+            if self.force_resume or self.resume:
+                logger.info("Resuming.")
 
                 map_location = None
                 if accelerator is not None and not accelerator.is_local_main_process:
@@ -120,81 +121,13 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
                 except RuntimeError as e:
                     logger.error(e)
                     logger.info(
-                        "Cant load state dict in strict mode! -> trying non-strict"
+                        "Cannot load state dict in strict mode! -> trying non-strict"
                     )
                     self.model.load_state_dict(model_state_dict, strict=False)
-                self.model.log_vars = log_vars  # pyre-ignore [16]
+                self.model.log_vars = log_vars
             else:
-                logger.info(" -> but not resuming -> starting from scratch")
-        elif self.force_load:
+                logger.info("Not resuming -> starting from scratch.")
+        elif self.force_resume:
             raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")
 
         return self.model
-
-    def load_stats(
-        self,
-        log_vars: List[str],
-        exp_dir: str,
-        clear_stats: bool = False,
-        **kwargs,
-    ) -> Stats:
-        """
-        Load Stats that correspond to the model's log_vars.
-
-        Args:
-            log_vars: A list of variable names to log. Should be a subset of the
-                `preds` returned by the forward function of the corresponding
-                ImplicitronModelBase instance.
-            exp_dir: Root experiment directory.
-            clear_stats: If True, do not load stats from the checkpoint speci-
-                fied by self.resume and self.resume_epoch; instead, create a
-                fresh stats object.
-            stats: The stats structure (optionally loaded from checkpoint)
-        """
-        # Init the stats struct
-        visdom_env_charts = (
-            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
-        )
-        stats = Stats(
-            # log_vars should be a list, but OmegaConf might load them as ListConfig
-            list(log_vars),
-            plot_file=os.path.join(exp_dir, "train_stats.pdf"),
-            visdom_env=visdom_env_charts,
-            verbose=False,
-            visdom_server=self.visdom_server,
-            visdom_port=self.visdom_port,
-        )
-
-        if self.resume_epoch > 0:
-            model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
-        else:
-            model_path = model_io.find_last_checkpoint(exp_dir)
-
-        if model_path is not None:
-            stats_path = model_io.get_stats_path(model_path)
-            stats_load = model_io.load_stats(stats_path)
-
-            # Determine if stats should be reset
-            if not clear_stats:
-                if stats_load is None:
-                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
-                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
-                    logger.info(f"Estimated resume epoch = {last_epoch}")
-
-                    # Reset the stats struct
-                    for _ in range(last_epoch + 1):
-                        stats.new_epoch()
-                    assert last_epoch == stats.epoch
-                else:
-                    stats = stats_load
-
-                # Update stats properties incase it was reset on load
-                stats.visdom_env = visdom_env_charts
-                stats.visdom_server = self.visdom_server
-                stats.visdom_port = self.visdom_port
-                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
-                stats.synchronize_logged_vars(log_vars)
-            else:
-                logger.info(" -> clearing stats")
-
-        return stats
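
After this change the factory keeps only three checkpoint-related knobs: resume (now defaulting to True), resume_epoch, and force_resume (renamed from force_load). A hedged sketch of toggling them through a composed config, following the resume test added at the bottom of this commit (the config name and IMPLICITRON_CONFIGS_DIR are taken from that test; the values are illustrative):

from hydra import compose, initialize_config_dir

with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
    cfg = compose(config_name="repro_singleseq_nerf_blender", overrides=[])

mf = cfg.model_factory_ImplicitronModelFactory_args
mf.resume = True          # load a checkpoint from exp_dir if one exists
mf.resume_epoch = -1      # <= 0 -> latest checkpoint; > 0 -> that exact epoch
mf.force_resume = False   # True -> raise FileNotFoundError when nothing can be loaded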

View File

@@ -60,11 +60,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         multistep_lr_milestones: With MultiStepLR policy only: list of
             increasing epoch indices at which the learning rate is modified.
         momentum: Momentum factor for SGD optimizer.
-        resume: If True, attempt to load the last checkpoint from `exp_dir`
-            passed to __call__. Failure to do so will return a newly initialized
-            optimizer.
-        resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
-            `resume_epoch` <= 0, then resume from the latest checkpoint.
         weight_decay: The optimizer weight_decay (L2 penalty on model weights).
     """
@@ -76,8 +71,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
     lr_policy: str = "MultiStepLR"
     momentum: float = 0.9
     multistep_lr_milestones: tuple = ()
-    resume: bool = False
-    resume_epoch: int = -1
     weight_decay: float = 0.0
 
     def __post_init__(self):
@@ -89,6 +82,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         model: ImplicitronModelBase,
         accelerator: Optional[Accelerator] = None,
         exp_dir: Optional[str] = None,
+        resume: bool = True,
+        resume_epoch: int = -1,
         **kwargs,
     ) -> Tuple[torch.optim.Optimizer, Any]:
         """
@@ -100,7 +95,10 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             model: The model with optionally loaded weights.
             accelerator: An optional Accelerator instance.
             exp_dir: Root experiment directory.
-
+            resume: If True, attempt to load optimizer checkpoint from exp_dir.
+                Failure to do so will return a newly initialized optimizer.
+            resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
+                `resume_epoch` <= 0, then resume from the latest checkpoint.
         Returns:
             An optimizer module (optionally loaded from a checkpoint) and
             a learning rate scheduler module (should be a subclass of torch.optim's
@@ -131,13 +129,18 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
                 p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
             )
         else:
-            raise ValueError("no such solver type %s" % self.breed)
-        logger.info(" -> solver type = %s" % self.breed)
+            raise ValueError(f"No such solver type {self.breed}")
+        logger.info(f"Solver type = {self.breed}")
 
         # Load state from checkpoint
-        optimizer_state = self._get_optimizer_state(exp_dir, accelerator)
+        optimizer_state = self._get_optimizer_state(
+            exp_dir,
+            accelerator,
+            resume_epoch=resume_epoch,
+            resume=resume,
+        )
         if optimizer_state is not None:
-            logger.info(" -> setting loaded optimizer state")
+            logger.info("Setting loaded optimizer state.")
             optimizer.load_state_dict(optimizer_state)
 
         # Initialize the learning rate scheduler
@@ -169,20 +172,31 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         self,
         exp_dir: Optional[str],
         accelerator: Optional[Accelerator] = None,
+        resume: bool = True,
+        resume_epoch: int = -1,
     ) -> Optional[Dict[str, Any]]:
         """
         Load an optimizer state from a checkpoint.
+
+        resume: If True, attempt to load the last checkpoint from `exp_dir`
+            passed to __call__. Failure to do so will return a newly initialized
+            optimizer.
+        resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
+            `resume_epoch` <= 0, then resume from the latest checkpoint.
         """
-        if exp_dir is None or not self.resume:
+        if exp_dir is None or not resume:
             return None
-        if self.resume_epoch > 0:
-            save_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
+        if resume_epoch > 0:
+            save_path = model_io.get_checkpoint(exp_dir, resume_epoch)
+            if not os.path.isfile(save_path):
+                raise FileNotFoundError(
+                    f"Cannot find optimizer from epoch {resume_epoch}."
+                )
        else:
             save_path = model_io.find_last_checkpoint(exp_dir)
         optimizer_state = None
         if save_path is not None:
-            logger.info(f"Found previous optimizer state {save_path}.")
-            logger.info(" -> resuming")
+            logger.info(f"Found previous optimizer state {save_path} -> resuming.")
             opt_path = model_io.get_optimizer_path(save_path)
 
             if os.path.isfile(opt_path):
@@ -193,5 +207,5 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
                 }
                 optimizer_state = torch.load(opt_path, map_location)
             else:
-                optimizer_state = None
+                raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
         return optimizer_state

View File

@@ -5,8 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+import os
 import time
-from typing import Any, Optional
+from typing import Any, List, Optional
 
 import torch
 from accelerate import Accelerator
@@ -41,6 +42,16 @@ class TrainingLoopBase(ReplaceableBase):
     ) -> None:
         raise NotImplementedError()
 
+    def load_stats(
+        self,
+        log_vars: List[str],
+        exp_dir: str,
+        resume: bool = True,
+        resume_epoch: int = -1,
+        **kwargs,
+    ) -> Stats:
+        raise NotImplementedError()
+
 
 @registry.register
 class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
@@ -64,6 +75,9 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
             logged.
         visualize_interval: The batch interval at which the visualizations
             should be plotted
+        visdom_env: The name of the Visdom environment to use for plotting.
+        visdom_port: The Visdom port.
+        visdom_server: Address of the Visdom server.
     """
 
     # Parameters of the outer training loop.
@@ -77,10 +91,15 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
     test_when_finished: bool = False
     validation_interval: int = 1
 
-    # Parameters of a single training-validation step.
+    # Gradient clipping.
     clip_grad: float = 0.0
+
+    # Visualization/logging parameters.
     metric_print_interval: int = 5
     visualize_interval: int = 1000
+    visdom_env: str = ""
+    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
+    visdom_server: str = "http://127.0.0.1"
 
     def __post_init__(self):
         run_auto_creation(self)
@@ -202,6 +221,81 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
                 "Cannot evaluate and dump results to json, no test data provided."
             )
 
+    def load_stats(
+        self,
+        log_vars: List[str],
+        exp_dir: str,
+        resume: bool = True,
+        resume_epoch: int = -1,
+        **kwargs,
+    ) -> Stats:
+        """
+        Load Stats that correspond to the model's log_vars and resume_epoch.
+
+        Args:
+            log_vars: A list of variable names to log. Should be a subset of the
+                `preds` returned by the forward function of the corresponding
+                ImplicitronModelBase instance.
+            exp_dir: Root experiment directory.
+            resume: If False, do not load stats from the checkpoint speci-
+                fied by resume and resume_epoch; instead, create a fresh stats object.
+            stats: The stats structure (optionally loaded from checkpoint)
+        """
+        # Init the stats struct
+        visdom_env_charts = (
+            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
+        )
+        stats = Stats(
+            # log_vars should be a list, but OmegaConf might load them as ListConfig
+            list(log_vars),
+            visdom_env=visdom_env_charts,
+            verbose=False,
+            visdom_server=self.visdom_server,
+            visdom_port=self.visdom_port,
+        )
+
+        model_path = None
+        if resume:
+            if resume_epoch > 0:
+                model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
+                if not os.path.isfile(model_path):
+                    raise FileNotFoundError(
+                        f"Cannot find stats from epoch {resume_epoch}."
+                    )
+            else:
+                model_path = model_io.find_last_checkpoint(exp_dir)
+
+        if model_path is not None:
+            stats_path = model_io.get_stats_path(model_path)
+            stats_load = model_io.load_stats(stats_path)
+
+            # Determine if stats should be reset
+            if resume:
+                if stats_load is None:
+                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
+                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
+                    logger.info(f"Estimated resume epoch = {last_epoch}")
+
+                    # Reset the stats struct
+                    for _ in range(last_epoch + 1):
+                        stats.new_epoch()
+                    assert last_epoch == stats.epoch
+                else:
+                    logger.info(f"Found previous stats in {stats_path} -> resuming.")
+                    stats = stats_load
+
+                # Update stats properties incase it was reset on load
+                stats.visdom_env = visdom_env_charts
+                stats.visdom_server = self.visdom_server
+                stats.visdom_port = self.visdom_port
+                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
+                stats.synchronize_logged_vars(log_vars)
+            else:
+                logger.info("Clearing stats")
+
+        return stats
+
     def _training_or_validation_epoch(
         self,
         epoch: int,
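
On the calling side, load_stats is what now decides where training restarts. A short sketch of that hand-off, mirroring the Experiment.run() hunk earlier in this commit (training_loop, model and exp_dir are assumed to be already constructed):

stats = training_loop.load_stats(
    log_vars=model.log_vars,
    exp_dir=exp_dir,
    resume=True,       # reuse checkpointed stats from exp_dir when available
    resume_epoch=-1,   # <= 0 -> latest checkpoint; > 0 -> that exact epoch's stats
)
start_epoch = stats.epoch + 1  # training continues right after the last recorded epoch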

View File

@@ -124,14 +124,32 @@ data_source_ImplicitronDataSource_args:
     dataset_length_val: 0
     dataset_length_test: 0
 model_factory_ImplicitronModelFactory_args:
-  force_load: false
+  resume: true
   model_class_type: GenericModel
-  resume: false
   resume_epoch: -1
-  visdom_env: ''
-  visdom_port: 8097
-  visdom_server: http://127.0.0.1
+  force_resume: false
   model_GenericModel_args:
+    log_vars:
+    - loss_rgb_psnr_fg
+    - loss_rgb_psnr
+    - loss_rgb_mse
+    - loss_rgb_huber
+    - loss_depth_abs
+    - loss_depth_abs_fg
+    - loss_mask_neg_iou
+    - loss_mask_bce
+    - loss_mask_beta_prior
+    - loss_eikonal
+    - loss_density_tv
+    - loss_depth_neg_penalty
+    - loss_autodecoder_norm
+    - loss_prev_stage_rgb_mse
+    - loss_prev_stage_rgb_psnr_fg
+    - loss_prev_stage_rgb_psnr
+    - loss_prev_stage_mask_bce
+    - objective
+    - epoch
+    - sec/it
     mask_images: true
     mask_depths: true
     render_image_width: 400
@@ -162,27 +180,6 @@ model_factory_ImplicitronModelFactory_args:
       loss_prev_stage_rgb_mse: 1.0
       loss_mask_bce: 0.0
       loss_prev_stage_mask_bce: 0.0
-    log_vars:
-    - loss_rgb_psnr_fg
-    - loss_rgb_psnr
-    - loss_rgb_mse
-    - loss_rgb_huber
-    - loss_depth_abs
-    - loss_depth_abs_fg
-    - loss_mask_neg_iou
-    - loss_mask_bce
-    - loss_mask_beta_prior
-    - loss_eikonal
-    - loss_density_tv
-    - loss_depth_neg_penalty
-    - loss_autodecoder_norm
-    - loss_prev_stage_rgb_mse
-    - loss_prev_stage_rgb_psnr_fg
-    - loss_prev_stage_rgb_psnr
-    - loss_prev_stage_mask_bce
-    - objective
-    - epoch
-    - sec/it
     global_encoder_HarmonicTimeEncoder_args:
       n_harmonic_functions: 10
       append_input: true
@@ -422,8 +419,6 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
   lr_policy: MultiStepLR
   momentum: 0.9
   multistep_lr_milestones: []
-  resume: false
-  resume_epoch: -1
   weight_decay: 0.0
 training_loop_ImplicitronTrainingLoop_args:
   eval_only: false
@@ -437,6 +432,9 @@ training_loop_ImplicitronTrainingLoop_args:
   clip_grad: 0.0
   metric_print_interval: 5
   visualize_interval: 1000
+  visdom_env: ''
+  visdom_port: 8097
+  visdom_server: http://127.0.0.1
   evaluator_ImplicitronEvaluator_args:
     camera_difficulty_bin_breaks:
     - 0.97

View File

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
+import tempfile
 import unittest
 from pathlib import Path
 
@@ -172,3 +173,48 @@ class TestNerfRepro(unittest.TestCase):
             experiment_runner = experiment.Experiment(**cfg)
             experiment.dump_cfg(cfg)
             experiment_runner.run()
+
+    @unittest.skip("This test checks resuming of the NeRF training.")
+    def test_nerf_blender_resume(self):
+        # Train one train batch of NeRF, then resume for one more batch.
+        # Set env vars BLENDER_DATASET_ROOT and BLENDER_SINGLESEQ_CLASS first!
+        if not interactive_testing_requested():
+            return
+        with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
+            with tempfile.TemporaryDirectory() as exp_dir:
+                cfg = compose(config_name="repro_singleseq_nerf_blender", overrides=[])
+                cfg.exp_dir = exp_dir
+
+                # set dataset len to 1
+                # fmt: off
+                (
+                    cfg
+                    .data_source_ImplicitronDataSource_args
+                    .data_loader_map_provider_SequenceDataLoaderMapProvider_args
+                    .dataset_length_train
+                ) = 1
+                # fmt: on
+
+                # run for one epoch
+                cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 1
+                experiment_runner = experiment.Experiment(**cfg)
+                experiment.dump_cfg(cfg)
+                experiment_runner.run()
+
+                # update num epochs + 2, let the optimizer resume
+                cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 3
+                experiment_runner = experiment.Experiment(**cfg)
+                experiment_runner.run()
+
+                # start from scratch
+                cfg.model_factory_ImplicitronModelFactory_args.resume = False
+                experiment_runner = experiment.Experiment(**cfg)
+                experiment_runner.run()
+
+                # force resume from epoch 1
+                cfg.model_factory_ImplicitronModelFactory_args.resume = True
+                cfg.model_factory_ImplicitronModelFactory_args.force_resume = True
+                cfg.model_factory_ImplicitronModelFactory_args.resume_epoch = 1
+                experiment_runner = experiment.Experiment(**cfg)
+                experiment_runner.run()

View File

@@ -344,7 +344,7 @@ def export_scenes(
 
     # Load the previously trained model
     experiment = Experiment(config)
-    model = experiment.model_factory(force_load=True, load_model_only=True)
+    model = experiment.model_factory(force_resume=True)
     model.cuda()
     model.eval()

View File

@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional
 
 import torch
@@ -45,6 +45,10 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
         optimization.
     """
 
+    # The keys from `preds` (output of ImplicitronModelBase.forward) to be logged in
+    # the training loop.
+    log_vars: List[str] = field(default_factory=lambda: ["objective"])
+
     def __init__(self) -> None:
         super().__init__()
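
The new log_vars field gives every model a default set of logged outputs, and concrete models such as GenericModel override it with the longer list visible in the config diffs above. A hedged sketch of how a subclass would extend it (the subclass name and extra keys are illustrative only):

from dataclasses import field
from typing import List

from pytorch3d.implicitron.models.base_model import ImplicitronModelBase


class MyRGBModel(ImplicitronModelBase):  # hypothetical subclass
    # Extra keys must match entries of the `preds` dict returned by forward().
    log_vars: List[str] = field(
        default_factory=lambda: ["objective", "loss_rgb_mse", "sec/it"]
    )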

View File

@@ -1,3 +1,24 @@
+log_vars:
+- loss_rgb_psnr_fg
+- loss_rgb_psnr
+- loss_rgb_mse
+- loss_rgb_huber
+- loss_depth_abs
+- loss_depth_abs_fg
+- loss_mask_neg_iou
+- loss_mask_bce
+- loss_mask_beta_prior
+- loss_eikonal
+- loss_density_tv
+- loss_depth_neg_penalty
+- loss_autodecoder_norm
+- loss_prev_stage_rgb_mse
+- loss_prev_stage_rgb_psnr_fg
+- loss_prev_stage_rgb_psnr
+- loss_prev_stage_mask_bce
+- objective
+- epoch
+- sec/it
 mask_images: true
 mask_depths: true
 render_image_width: 400
@@ -28,27 +49,6 @@ loss_weights:
   loss_prev_stage_rgb_mse: 1.0
   loss_mask_bce: 0.0
   loss_prev_stage_mask_bce: 0.0
-log_vars:
-- loss_rgb_psnr_fg
-- loss_rgb_psnr
-- loss_rgb_mse
-- loss_rgb_huber
-- loss_depth_abs
-- loss_depth_abs_fg
-- loss_mask_neg_iou
-- loss_mask_bce
-- loss_mask_beta_prior
-- loss_eikonal
-- loss_density_tv
-- loss_depth_neg_penalty
-- loss_autodecoder_norm
-- loss_prev_stage_rgb_mse
-- loss_prev_stage_rgb_psnr_fg
-- loss_prev_stage_rgb_psnr
-- loss_prev_stage_mask_bce
-- objective
-- epoch
-- sec/it
 global_encoder_SequenceAutodecoder_args:
   autodecoder_args:
     encoding_dim: 0

View File

@@ -90,5 +90,5 @@ class TestGenericModel(unittest.TestCase):
             remove_unused_components(instance_args)
             yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
             if DEBUG:
-                (DATA_DIR / "overrides.yaml_").write_text(yaml)
+                (DATA_DIR / "overrides_.yaml").write_text(yaml)
             self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())